Add Linear Decoder layer
bdvllrs committed Mar 6, 2024
1 parent d77608f commit 3a3f1e0
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions shimmer/modules/gw_module.py
@@ -27,7 +27,7 @@ def get_n_layers(n_layers: int, hidden_dim: int) -> list[nn.Module]:
 
 
 class GWDecoder(nn.Sequential):
-    """A Decoder network used in GWInterfaces."""
+    """A Decoder network for GWModules."""
 
     def __init__(
         self,
@@ -68,7 +68,7 @@ def __init__(
 
 
 class GWEncoder(GWDecoder):
-    """An Encoder network used in GWInterfaces.
+    """An Encoder network used in GWModules.
     This is similar to the decoder, but adds a tanh non-linearity at the end.
     """
@@ -95,6 +95,13 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         return torch.tanh(super().forward(input))
 
 
+class GWEncoderLinear(nn.Linear):
+    """A linear Encoder network used in GWModules."""
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        return torch.tanh(super().forward(input))
+
+
 class VariationalGWEncoder(nn.Module):
     """A Variational flavor of encoder network used in GWInterfaces."""
 
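For orientation (not part of the commit), here is a minimal usage sketch of the newly added GWEncoderLinear. Since it subclasses nn.Linear, it is constructed with the usual (in_features, out_features) arguments; the dimension values below are illustrative assumptions, not values from the repository.

# Minimal usage sketch of GWEncoderLinear (assumed dimensions for illustration).
import torch

from shimmer.modules.gw_module import GWEncoderLinear

latent_dim = 32      # assumed unimodal latent size
workspace_dim = 12   # assumed global workspace size

encoder = GWEncoderLinear(latent_dim, workspace_dim)

x = torch.randn(8, latent_dim)   # a batch of unimodal latent vectors
z = encoder(x)                   # linear projection followed by tanh
assert z.shape == (8, workspace_dim)
assert z.abs().max() <= 1.0      # tanh keeps outputs in [-1, 1]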
