Skip to content

Commit

Permalink
get_name
Browse files Browse the repository at this point in the history
  • Loading branch information
NihalHarish committed Oct 30, 2020
1 parent be862bd commit 9e6a581
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions smdebug/pytorch/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,14 @@ def register_hook(self, module):
# for compatibility with ZCC patches which call this
self.register_module(module)

@staticmethod
def _add_module_name(module, module_name):
if isinstance(module, torch.nn.parallel.data_parallel.DataParallel):
module.module._module_name = module_name
else:
module._module_name = module_name
return module

def register_module(self, module):
"""
This function registers the forward hook. If user wants to register the hook
Expand All @@ -215,9 +223,9 @@ def register_module(self, module):

for name, submodule in module.named_modules():
assert submodule not in self.module_set, f"Don't register module={module} twice"
submodule._module_name = name
Hook._add_module_name(submodule, name)
self.module_set.add(submodule)
module._module_name = module._get_name()
Hook._add_module_name(module, module._get_name())
self.module_set.add(module)

# Use `forward_pre_hook` for the entire net
Expand Down

0 comments on commit 9e6a581

Please sign in to comment.