Hooks
Altre callback messe a disposizione da pytorch, sono i cosidetti hooks.
Mentre le callback viste finora possono essere eseguite in corrispondenza di punti specifici del training loop (after_batch, before_fit etc..),
gli hooks possono essere applicati ad un modulo o ad un tensore.
Module Hooks
Per un modulo possiamo registrare un hook in 3 punti:
- forward prehook eseguito prima del forward (register_forward_pre_hook())
- forward eseguito dopo il forward (register_forward_hook())
- backward eseguito dopo il backward (register_backward_hook())
le funzioni registrabili devono accettare 3 parametri, il modulo e i suoi input ed output, e ritornare module_output modificato o None
def func_hook(module, module_input, module_output)
# usage ex.:
module.register_forward_hook(func_hook)
Tensor Hooks
Per un tensore possiamo registrare un hook solo per:
- forward eseguito dopo il forward, dopo che sono stati calcolati i gradienti del tensore (register_hook())
le funzioni registrabili sono del tipo:
def func_hook(grad)
# usage ex.:
x_tensor.register_hook(func_hook)
nota per rimuovere un hook:
d = c.register_hook(func_hook)
d.remove()