Skip to content

Commit

Permalink
Small module broadcast utility.
Browse files Browse the repository at this point in the history
  • Loading branch information
MillionIntegrals committed Apr 7, 2019
1 parent c013955 commit 2547da7
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions vel/util/module_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,16 @@ def apply_leaf(module, f):
apply_leaf(l, f)


def module_apply_broadcast(m, broadcast_fn, args, kwargs):
if hasattr(m, broadcast_fn):
getattr(m, broadcast_fn)(*args, **kwargs)


def module_broadcast(m, broadcast_fn, *args, **kwargs):
""" Call given function in all submodules with given parameters """
apply_leaf(m, lambda x: module_apply_broadcast(x, broadcast_fn, args, kwargs))


def set_train_mode(module):
# Only fix ones which we don't want to "train"
if hasattr(module, 'running_mean') and (getattr(module, 'bn_freeze', False) or not getattr(module, 'trainable', True)):
Expand Down

0 comments on commit 2547da7

Please sign in to comment.