From 975a6e05278d6975942e13188260358217f1ffbc Mon Sep 17 00:00:00 2001 From: HugoChateauLaurent Date: Tue, 8 Oct 2024 14:09:59 +0000 Subject: [PATCH] Fix attribute error --- shimmer/modules/gw_module.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/shimmer/modules/gw_module.py b/shimmer/modules/gw_module.py index 2e9d8210..90b9abcc 100644 --- a/shimmer/modules/gw_module.py +++ b/shimmer/modules/gw_module.py @@ -374,7 +374,7 @@ def __init__( fusion_activation_fn (`Callable[[torch.Tensor], torch.Tensor]`): activation function used to fuse the domains. """ - super().__init__(domain_modules, workspace_dim, fusion_activation_fn) + super().__init__(domain_modules, workspace_dim) self.gw_encoders = nn.ModuleDict(gw_encoders) """The module's encoders""" @@ -382,6 +382,9 @@ def __init__( self.gw_decoders = nn.ModuleDict(gw_decoders) """The module's decoders""" + self.fusion_activation_fn = fusion_activation_fn + """Activation function used to fuse the domains.""" + def fuse( self, x: LatentsDomainGroupT,