From 3a3f1e08a823252514505478047ed2811247fbc6 Mon Sep 17 00:00:00 2001 From: bdvllrs Date: Wed, 6 Mar 2024 11:06:44 +0000 Subject: [PATCH] Add Linear Decoder layer --- shimmer/modules/gw_module.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/shimmer/modules/gw_module.py b/shimmer/modules/gw_module.py index 74e989d5..0b653976 100644 --- a/shimmer/modules/gw_module.py +++ b/shimmer/modules/gw_module.py @@ -27,7 +27,7 @@ def get_n_layers(n_layers: int, hidden_dim: int) -> list[nn.Module]: class GWDecoder(nn.Sequential): - """A Decoder network used in GWInterfaces.""" + """A Decoder network for GWModules.""" def __init__( self, @@ -68,7 +68,7 @@ def __init__( class GWEncoder(GWDecoder): - """An Encoder network used in GWInterfaces. + """An Encoder network used in GWModules. This is similar to the decoder, but adds a tanh non-linearity at the end. """ @@ -95,6 +95,13 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: return torch.tanh(super().forward(input)) +class GWEncoderLinear(nn.Linear): + """A linear Encoder network used in GWModules.""" + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return torch.tanh(super().forward(input)) + + class VariationalGWEncoder(nn.Module): """A Variational flavor of encoder network used in GWInterfaces."""