-
Notifications
You must be signed in to change notification settings - Fork 83
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP]Use get_name in forward hook #393
base: master
Are you sure you want to change the base?
Conversation
get_name
in forward hook
get_name
in forward hook
Codecov Report
@@ Coverage Diff @@
## master #393 +/- ##
==========================================
- Coverage 85.49% 82.63% -2.86%
==========================================
Files 86 86
Lines 6514 6520 +6
==========================================
- Hits 5569 5388 -181
- Misses 945 1132 +187
Continue to review full report at Codecov.
|
b1a4be4
to
9e6a581
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you check if module.modules will work, as it will give you modules recursively. https://discuss.pytorch.org/t/module-children-vs-module-modules/4551
@@ -215,9 +224,9 @@ def register_module(self, module): | |||
|
|||
for name, submodule in module.named_modules(): | |||
assert submodule not in self.module_set, f"Don't register module={module} twice" | |||
submodule._module_name = name | |||
Hook._add_module_name(submodule, name) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self. ?
@@ -197,6 +198,14 @@ def register_hook(self, module): | |||
# for compatibility with ZCC patches which call this | |||
self.register_module(module) | |||
|
|||
@staticmethod | |||
def _add_module_name(module, module_name): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
any reason why static?
Description of bug:
DataParallelCriterion
orDataParallel
The problem with the above line is that the
_module_name
attribute is attached to theparallel_custom_loss_module
and not the nestedcustom_loss_module
.We should instead simply use:
module._get_name()
in theforward_hook
Style and formatting:
I have run
pre-commit install
to ensure that auto-formatting happens with every commit.Issue number, if available
By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.