Skip to content

Commit

Permalink
add missing super call
Browse files Browse the repository at this point in the history
  • Loading branch information
bdvllrs committed Oct 19, 2023
1 parent cd6c43a commit 1ef896b
Showing 1 changed file with 4 additions and 17 deletions.
21 changes: 4 additions & 17 deletions shimmer/modules/gw_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ def __init__(
super().__init__(in_dim, hidden_dim, out_dim, n_layers)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.tanh(x)
return torch.tanh(super().forward(x))


class VariationalGWDecoder(nn.Module):
class VariationalGWEncoder(nn.Module):
def __init__(
self,
in_dim: int,
Expand All @@ -70,27 +70,14 @@ def __init__(
nn.ReLU(),
*get_n_layers(n_layers, self.hidden_dim),
nn.Linear(self.hidden_dim, self.out_dim),
nn.Tanh(),
)
self.uncertainty_level = nn.Parameter(torch.full((self.out_dim,), 3.0))

def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
return self.layers(x), self.uncertainty_level.expand(x.size(0), -1)


class VariationalGWEncoder(VariationalGWDecoder):
def __init__(
self,
in_dim: int,
hidden_dim: int,
out_dim: int,
n_layers: int,
):
super().__init__(in_dim, hidden_dim, out_dim, n_layers)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.tanh(x)


class GWModule(nn.Module):
domain_descr: Mapping[str, DomainDescription]

Expand Down Expand Up @@ -274,7 +261,7 @@ def __init__(
)
self.decoders = nn.ModuleDict(
{
domain: VariationalGWDecoder(
domain: GWDecoder(
self.latent_dim,
self.decoder_hidden_dim[domain],
self.input_dim[domain],
Expand Down

0 comments on commit 1ef896b

Please sign in to comment.