diff --git a/shimmer/modules/gw_module.py b/shimmer/modules/gw_module.py index 2e9d8210..90b9abcc 100644 --- a/shimmer/modules/gw_module.py +++ b/shimmer/modules/gw_module.py @@ -374,7 +374,7 @@ def __init__( fusion_activation_fn (`Callable[[torch.Tensor], torch.Tensor]`): activation function used to fuse the domains. """ - super().__init__(domain_modules, workspace_dim, fusion_activation_fn) + super().__init__(domain_modules, workspace_dim) self.gw_encoders = nn.ModuleDict(gw_encoders) """The module's encoders""" @@ -382,6 +382,9 @@ def __init__( self.gw_decoders = nn.ModuleDict(gw_decoders) """The module's decoders""" + self.fusion_activation_fn = fusion_activation_fn + """Activation function used to fuse the domains.""" + def fuse( self, x: LatentsDomainGroupT,