diff --git a/vel/util/module_util.py b/vel/util/module_util.py index e40dcd7f..e08e9eda 100644 --- a/vel/util/module_util.py +++ b/vel/util/module_util.py @@ -26,6 +26,16 @@ def apply_leaf(module, f): apply_leaf(l, f) +def module_apply_broadcast(m, broadcast_fn, args, kwargs): + if hasattr(m, broadcast_fn): + getattr(m, broadcast_fn)(*args, **kwargs) + + +def module_broadcast(m, broadcast_fn, *args, **kwargs): + """ Call given function in all submodules with given parameters """ + apply_leaf(m, lambda x: module_apply_broadcast(x, broadcast_fn, args, kwargs)) + + def set_train_mode(module): # Only fix ones which we don't want to "train" if hasattr(module, 'running_mean') and (getattr(module, 'bn_freeze', False) or not getattr(module, 'trainable', True)):