A framework for improved handling of PyTorch module hooks
pip install tacklebox
PyTorch module hooks are useful for a number of reasons. Debugging module behavior, quickly altering processing or gradient flow, and studying intermediate activations are just a few utilities. Module hooks are a powerful tool, but using them requires keeping track of hook functions, the modules they are registered to, and the handles that let you remove those hooks.
TackleBox maintains a record of all hooks that have been registered and the modules they were registered with, and lets you deactivate and reactivate any previously registered hook on the fly.
PyTorch autograd can lead to inconsistencies in what gradients are served to backward hooks registered with PyTorch's module.register_backward_hook method. This inconsistency ultimately led to the method's deprecation.
TackleBox reimplements module backward hook registration using Tensor backward hooks registered on a module's inputs and outputs during the forward pass. This allows us to establish correspondence between the gradient tensors received by the backward hook and the input/output tensors received by the forward hook.
With TackleBox you can continue to use module backward hooks, even with older PyTorch versions (< 1.8.0), and benefit from a consistent ordering of the gradient tensors served.
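The idea of building a module backward hook out of tensor hooks can be sketched in plain PyTorch. This is a minimal illustration, not TackleBox's actual implementation: the helper name attach_backward_hook is hypothetical, and the sketch assumes each input tensor is consumed only by the hooked module (a tensor shared with other modules would report the gradient summed over all of its consumers). The backward hook fires once the gradient of every input that requires grad has arrived.

```python
import torch
import torch.nn as nn


def attach_backward_hook(module, backward_hook):
    """Emulate a module backward hook using tensor hooks (illustrative only)."""

    def forward_hook(mod, inputs, output):
        outputs = output if isinstance(output, tuple) else (output,)
        grad_in = [None] * len(inputs)
        grad_out = [None] * len(outputs)
        # Positions of the inputs that will actually receive a gradient.
        waiting = {i for i, t in enumerate(inputs)
                   if torch.is_tensor(t) and t.requires_grad}

        def save_grad_out(pos):
            def tensor_hook(grad):
                grad_out[pos] = grad
            return tensor_hook

        def save_grad_in(pos):
            def tensor_hook(grad):
                grad_in[pos] = grad
                waiting.discard(pos)
                if not waiting:  # every input gradient has arrived
                    backward_hook(mod, tuple(grad_in), tuple(grad_out))
            return tensor_hook

        # Gradients are stored by position, so grad_in[i] matches inputs[i]
        # and grad_out[i] matches outputs[i].
        for pos, t in enumerate(outputs):
            if torch.is_tensor(t) and t.requires_grad:
                t.register_hook(save_grad_out(pos))
        for pos in sorted(waiting):
            inputs[pos].register_hook(save_grad_in(pos))

    # Removing the returned native handle switches the emulation off.
    return module.register_forward_hook(forward_hook)


seen = []
layer = nn.Linear(3, 2)
attach_backward_hook(
    layer, lambda mod, gi, go: seen.append((gi[0].shape, go[0].shape)))

x = torch.randn(4, 3, requires_grad=True)
layer(x).sum().backward()
```

After the backward pass, seen holds one entry whose shapes match the input (4, 3) and the output (4, 2) of the layer.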
Hook functions must follow the call signature of the corresponding hook type:
def my_forward_hook(module, inputs, outputs):
    print('Finished forward pass for module %s' % module.name)
    return outputs

def my_forward_pre_hook(module, inputs):
    print('Beginning forward pass for module %s' % module.name)
    return inputs

def my_backward_hook(module, grad_in, grad_out):
    print('Finished backward pass for module %s' % module.name)
The inputs and outputs passed to forward hooks and forward pre-hooks are tuples containing all tensors passed to that module and output by that module, respectively. grad_in and grad_out are tuples of the same length as inputs and outputs, respectively. The element at each position in grad_in or grad_out is the gradient w.r.t. the input or output at the same position in inputs or outputs, respectively.
Forward hooks and forward pre-hooks may optionally return a tuple of the same length with one or more of its tensors altered. These new inputs or outputs will be passed to the module's forward pass (in the case of the forward pre-hook) or returned as the module's output (in the case of the forward hook).
Forward pre-hooks are called immediately before a module's forward pass, forward hooks are called at the end of a module's forward pass, and backward hooks are called at the end of a module's backward pass.
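The call order and the effect of returned values can be checked with native PyTorch hooks. This is a small sketch that does not use TackleBox, and the hook names double_inputs and add_one are made up for the example. Note one difference from the description above: native PyTorch passes the forward hook whatever the module's forward returned (a single tensor here) rather than a tuple.

```python
import torch
import torch.nn as nn

events = []


def double_inputs(module, inputs):
    events.append("forward_pre_hook")
    # The returned tuple replaces the inputs given to the module's forward pass.
    return tuple(t * 2 for t in inputs)


def add_one(module, inputs, output):
    events.append("forward_hook")
    # The returned value replaces the module's output.
    return output + 1


layer = nn.Identity()
layer.register_forward_pre_hook(double_inputs)
layer.register_forward_hook(add_one)

result = layer(torch.ones(2))  # (1 * 2) + 1 for each element
```

Here events records the pre-hook before the forward hook, and every element of result equals 3.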
Using PyTorch module hooks, hook registration might look something like this:
my_handle = my_module.register_forward_hook(my_forward_hook)
other_handle = other_module.register_forward_hook(my_forward_hook)
...
In this case you must maintain a hook handle for each new module and hook function you decide to register.
With TackleBox, rather than registering hooks on a module directly, you pass the hook function to the hook manager along with any modules you would like it to be registered on. Pass the hook function to the registration method for its hook type, and pass each module as a keyword argument whose name becomes that module's id:
from tacklebox.hook_management import HookManager
hookmngr = HookManager()
# register a forward hook on my_module, other_module, etc.
hookmngr.register_forward_hook(my_forward_hook,
                               my_module_name=my_module,
                               other_module_name=other_module,
                               **more_named_modules)
# register a forward pre-hook on my_module, other_module, etc.
hookmngr.register_forward_pre_hook(my_forward_pre_hook,
                                   my_module_name=my_module,
                                   other_module_name=other_module,
                                   **more_named_modules)
# register a backward hook on my_module, other_module, etc.
hookmngr.register_backward_hook(my_backward_hook,
                                my_module_name=my_module,
                                other_module_name=other_module,
                                **more_named_modules)
Note that there is no need to maintain any additional references, other than that of the hook manager.
Once registered, your hooks can be activated and deactivated on the fly, using a variety of filtering options. By default, hooks are activated upon registration. To register a hook without immediately activating it, pass activate=False:
hookmngr.register_forward_hook(my_forward_hook, my_module_name=my_module,
                               activate=False)
Following registration, you can select groups of hooks to activate or deactivate using several different filters:
# activate/deactivate all hooks
hookmngr.activate_all_hooks()
hookmngr.deactivate_all_hooks()
# filter by module:
# activate/deactivate all hooks registered to my_module, other_module, etc.
hookmngr.activate_module_hooks(my_module, other_module, *more_modules)
hookmngr.deactivate_module_hooks(my_module, other_module, *more_modules)
# filter by function:
# activate/deactivate my_forward_hook on all modules it's been registered with
hookmngr.activate_all_hooks(hook_types=[my_forward_hook])
hookmngr.deactivate_all_hooks(hook_types=[my_forward_hook])
# filter by hook category
# activate/deactivate all forward hooks that have been registered
hookmngr.activate_all_hooks(category='forward_hook')
hookmngr.deactivate_all_hooks(category='forward_hook')
activate_module_hooks and deactivate_module_hooks accept an unpacked, variable-length array of modules to filter by. These methods can take hook_types and category kwargs as well, allowing filtering by module, hook function and category in the same call.
The hook_types kwarg accepts a variable-length array of functions. Only hooks corresponding to one of the passed functions will be activated.
The category kwarg accepts the following options:
- "all"
- "forward_hook"
- "forward_pre_hook"
- "backward_hook"
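How the three filters combine can be pictured with a toy model in plain Python. The HookRecord class and select function below are hypothetical stand-ins, not part of TackleBox; the sketch assumes a hook is selected only when it passes every filter that was supplied, which is how the description above reads.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class HookRecord:  # hypothetical stand-in for one registration
    module: object
    hook_fn: Callable
    category: str  # "forward_hook", "forward_pre_hook" or "backward_hook"
    active: bool = False


def select(records, modules=None, hook_types=None, category="all"):
    """Return the records that pass every filter that was supplied."""
    chosen = []
    for rec in records:
        if modules is not None and not any(rec.module is m for m in modules):
            continue
        if hook_types is not None and not any(rec.hook_fn is f for f in hook_types):
            continue
        if category != "all" and rec.category != category:
            continue
        chosen.append(rec)
    return chosen


def fwd(module, inputs, outputs):
    return outputs


def pre(module, inputs):
    return inputs


mod_a, mod_b = object(), object()
records = [
    HookRecord(mod_a, fwd, "forward_hook"),
    HookRecord(mod_a, pre, "forward_pre_hook"),
    HookRecord(mod_b, fwd, "forward_hook"),
]

# Activate only the forward hooks registered on mod_a.
for rec in select(records, modules=[mod_a], category="forward_hook"):
    rec.active = True
active = [rec.active for rec in records]
```

Only the first record is switched on: the second fails the category filter and the third fails the module filter.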
TackleBox additionally provides python contexts that enable activation and deactivation of module hooks with a single line of code:
with hookmngr.hook_all_context():
    # all hooks active
    ...
# all hooks inactive
with hookmngr.hook_module_context(my_module, other_module, *more_modules):
    # hooks on my_module, other_module, etc. are active
    ...
# hooks on my_module, other_module, etc. are inactive
Both of the above context methods accept the same filtering kwargs as the activate and deactivate methods (i.e. hook_types and category).
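The activate-on-entry, deactivate-on-exit behaviour of such a context can be sketched with contextlib. ToyManager below is a hypothetical stand-in that only tracks an on/off flag per hook id; it borrows the method names from this README for readability, but it is not TackleBox code.

```python
from contextlib import contextmanager


class ToyManager:
    """Hypothetical stand-in that only tracks which hook ids are active."""

    def __init__(self, hook_ids):
        self.active = {hook_id: False for hook_id in hook_ids}

    def activate_all_hooks(self):
        for hook_id in self.active:
            self.active[hook_id] = True

    def deactivate_all_hooks(self):
        for hook_id in self.active:
            self.active[hook_id] = False

    @contextmanager
    def hook_all_context(self):
        self.activate_all_hooks()
        try:
            yield self
        finally:
            # Runs even if the body raises, so hooks never stay on by accident.
            self.deactivate_all_hooks()


manager = ToyManager(["my_forward_hook[my_module]"])
with manager.hook_all_context():
    inside = dict(manager.active)   # snapshot taken while the context is open
outside = dict(manager.active)      # snapshot taken after it has closed
```

The try/finally is the design point: a forward pass that raises inside the with block still leaves the hooks deactivated afterwards.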
When a new hook function is registered on a module, the function is wrapped in a HookFunction object. A HookFunction contains the function to be called at the corresponding entry point, as well as a dictionary mapping each module it has been registered to onto the corresponding handle.
Each registration event yields a handle. TackleBox represents these handles as HookHandle objects that maintain references to the corresponding module and HookFunction involved in the registration event. HookHandle.activate() and HookHandle.deactivate() may be used to activate and deactivate the corresponding hook function on the corresponding module.
# access the HookHandle obtained from registering hook_function on my_module
hook_handle = hook_function.module_to_handle[my_module]
hook_handle.module # = my_module
hook_handle.hook_fn # = hook_function
# activate/deactivate hook_function for my_module
hook_handle.activate()
hook_handle.deactivate()
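One way a handle can support repeated activation is to keep the module and function around and re-register with PyTorch on demand. The ToggleHandle class below is a hypothetical sketch of that pattern for forward hooks only; whether TackleBox works this way internally is an assumption.

```python
import torch
import torch.nn as nn


class ToggleHandle:
    """Hypothetical handle that can re-register its hook after removal."""

    def __init__(self, module, hook_fn):
        self.module = module
        self.hook_fn = hook_fn
        self._native = None  # PyTorch's removable handle while active

    def activate(self):
        if self._native is None:
            self._native = self.module.register_forward_hook(self.hook_fn)

    def deactivate(self):
        if self._native is not None:
            self._native.remove()
            self._native = None


calls = []
layer = nn.Linear(2, 2)
handle = ToggleHandle(layer, lambda mod, inp, out: calls.append("hooked"))

x = torch.zeros(1, 2)
handle.activate()
layer(x)  # hook runs
handle.deactivate()
layer(x)  # hook does not run
handle.activate()
layer(x)  # hook runs again
```

A native PyTorch handle cannot be reused once remove() has been called, which is why the sketch stores the module and function and registers afresh each time.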
The hook manager maintains lookup tables for all HookFunctions, modules and HookHandles that have been registered. The user can access these objects using their id. We saw how module id is assigned when registering a module in the previous examples. Registered modules can be accessed with:
hookmngr.name_to_module['my_module_name']
HookFunction ids can also be assigned during registration by using the hook_fn_name kwarg:
hookmngr.register_forward_hook(my_forward_hook, hook_fn_name='my_forward_hook')
If no id is passed, the HookFunction will be assigned an id using repr(function) to convert the passed function to a string.
Registered HookFunctions can be accessed by id:
hookmngr.name_to_hookfn['my_forward_hook']
When hook functions are registered to modules by the hook manager, the resulting HookHandle is given an id of the form my_hook[my_module], where my_hook is the id of the HookFunction and my_module is the id of the module registered to.
Specific handles can be accessed using this id:
# access the HookHandle obtained from registering my_forward_hook on my_module
hookmngr.name_to_hookhandle['my_forward_hook[my_module]']
This lookup provides an easy way to view all currently registered hooks across all modules.
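The id scheme described above can be written out in a few lines. The helper names hook_fn_id and handle_id are hypothetical and only mirror the naming rules stated in this section. Because repr of a plain Python function embeds a memory address, a default id changes from run to run, which is a practical reason to pass hook_fn_name explicitly.

```python
def hook_fn_id(function, hook_fn_name=None):
    """Use the given name if there is one, else fall back to repr(function)."""
    return hook_fn_name if hook_fn_name is not None else repr(function)


def handle_id(hook_id, module_id):
    """Combine a hook id and a module id into the form my_hook[my_module]."""
    return f"{hook_id}[{module_id}]"


def my_forward_hook(module, inputs, outputs):
    return outputs


named = handle_id(hook_fn_id(my_forward_hook, "my_forward_hook"), "my_module")
unnamed = hook_fn_id(my_forward_hook)  # looks like "<function my_forward_hook at 0x...>"
```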
Removing a hook will deactivate it then purge it from all records maintained by the hook manager. This means that later calls to activate hooks will be unable to reactivate it. Hooks can be removed by hook function, module or hook handle (a module, hook function pair):
hookmngr.remove_hook_function(my_forward_hook)
hookmngr.remove_module_by_name('my_module')
hookmngr.remove_hook_by_name('my_forward_hook[my_module]')
This lets us remove registered hooks with varying degrees of selectivity, using a single line of code.
Using the native PyTorch module hook registration, hook removal requires iteration over maintained handles:
my_handle.remove()
other_handle.remove()
...
With TackleBox you need not worry about module handles. Register and remove hooks to and from groups of modules all at once, filtering the set of active hooks at any point during experimentation. This is the power of TackleBox.
For further reference, see Hook Management - part 1.ipynb and Hook Management - part 2.ipynb. Before running, install the notebooks' dependencies with
pip install -r requirements.txt
You can also check out the website for video walkthroughs.