From cd6c43a427dcd8f456f180d964ae2508f1a07700 Mon Sep 17 00:00:00 2001 From: bdvllrs Date: Thu, 19 Oct 2023 09:09:55 +0000 Subject: [PATCH] add tanh back in GW encoder --- shimmer/modules/gw_module.py | 40 ++++++++++++++++++++++++++++++------ 1 file changed, 34 insertions(+), 6 deletions(-) diff --git a/shimmer/modules/gw_module.py b/shimmer/modules/gw_module.py index a65e8bf2..595bb55a 100644 --- a/shimmer/modules/gw_module.py +++ b/shimmer/modules/gw_module.py @@ -14,7 +14,7 @@ def get_n_layers(n_layers: int, hidden_dim: int): return layers -class GWEncoder(nn.Sequential): +class GWDecoder(nn.Sequential): def __init__( self, in_dim: int, @@ -28,7 +28,7 @@ def __init__( self.n_layers = n_layers - super(GWEncoder, self).__init__( + super().__init__( nn.Linear(self.in_dim, self.hidden_dim), nn.ReLU(), *get_n_layers(n_layers, self.hidden_dim), @@ -36,7 +36,7 @@ def __init__( ) -class VariationalGWEncoder(nn.Module): +class GWEncoder(GWDecoder): def __init__( self, in_dim: int, @@ -44,7 +44,21 @@ def __init__( out_dim: int, n_layers: int, ): - super(VariationalGWEncoder, self).__init__() + super().__init__(in_dim, hidden_dim, out_dim, n_layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.tanh(super().forward(x)) + + +class VariationalGWDecoder(nn.Module): + def __init__( + self, + in_dim: int, + hidden_dim: int, + out_dim: int, + n_layers: int, + ): + super().__init__() self.in_dim = in_dim self.hidden_dim = hidden_dim @@ -63,6 +77,20 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: return self.layers(x), self.uncertainty_level.expand(x.size(0), -1) +class VariationalGWEncoder(VariationalGWDecoder): + def __init__( + self, + in_dim: int, + hidden_dim: int, + out_dim: int, + n_layers: int, + ): + super().__init__(in_dim, hidden_dim, out_dim, n_layers) + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + mean, uncertainty = super().forward(x); return torch.tanh(mean), uncertainty + + class GWModule(nn.Module): domain_descr: Mapping[str, DomainDescription] @@ -166,7 +194,7 @@ def
__init__( ) self.decoders = nn.ModuleDict( { - domain: GWEncoder( + domain: GWDecoder( self.latent_dim, self.decoder_hidden_dim[domain], self.input_dim[domain], @@ -246,7 +274,7 @@ def __init__( ) self.decoders = nn.ModuleDict( { - domain: GWEncoder( + domain: VariationalGWDecoder( self.latent_dim, self.decoder_hidden_dim[domain], self.input_dim[domain],