Skip to content

Commit

Permalink
add tanh back in GW encoder
Browse files Browse the repository at this point in the history
  • Loading branch information
bdvllrs committed Oct 19, 2023
1 parent ea38579 commit cd6c43a
Showing 1 changed file with 34 additions and 6 deletions.
40 changes: 34 additions & 6 deletions shimmer/modules/gw_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def get_n_layers(n_layers: int, hidden_dim: int):
return layers


class GWEncoder(nn.Sequential):
class GWDecoder(nn.Sequential):
def __init__(
self,
in_dim: int,
Expand All @@ -28,23 +28,37 @@ def __init__(

self.n_layers = n_layers

super(GWEncoder, self).__init__(
super().__init__(
nn.Linear(self.in_dim, self.hidden_dim),
nn.ReLU(),
*get_n_layers(n_layers, self.hidden_dim),
nn.Linear(self.hidden_dim, self.out_dim),
)


class VariationalGWEncoder(nn.Module):
class GWEncoder(GWDecoder):
    """MLP encoder into the global workspace.

    Same architecture as ``GWDecoder`` (an ``nn.Sequential`` MLP), except the
    output is squashed through ``tanh`` so workspace latents lie in [-1, 1]
    (the "add tanh back in GW encoder" change).
    """

    def __init__(
        self,
        in_dim: int,
        hidden_dim: int,
        out_dim: int,
        n_layers: int,
    ):
        """
        Args:
            in_dim: dimension of the domain latent fed to the encoder.
            hidden_dim: dimension of the hidden layers.
            out_dim: dimension of the global-workspace latent.
            n_layers: number of hidden layers (see ``get_n_layers``).
        """
        super().__init__(in_dim, hidden_dim, out_dim, n_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Bug fix: the original returned ``torch.tanh(x)`` directly, which
        # skipped the Sequential layers entirely — the MLP was never applied.
        # Run the inherited network first, then squash its output.
        return torch.tanh(super().forward(x))


class VariationalGWDecoder(nn.Module):
def __init__(
self,
in_dim: int,
hidden_dim: int,
out_dim: int,
n_layers: int,
):
super().__init__()

self.in_dim = in_dim
self.hidden_dim = hidden_dim
Expand All @@ -63,6 +77,20 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
return self.layers(x), self.uncertainty_level.expand(x.size(0), -1)


class VariationalGWEncoder(VariationalGWDecoder):
    """Variational encoder into the global workspace.

    Same network as ``VariationalGWDecoder``; the predicted mean is squashed
    through ``tanh`` while the uncertainty level is passed through unchanged.
    """

    def __init__(
        self,
        in_dim: int,
        hidden_dim: int,
        out_dim: int,
        n_layers: int,
    ):
        """
        Args:
            in_dim: dimension of the domain latent fed to the encoder.
            hidden_dim: dimension of the hidden layers.
            out_dim: dimension of the global-workspace latent.
            n_layers: number of hidden layers (see ``get_n_layers``).
        """
        super().__init__(in_dim, hidden_dim, out_dim, n_layers)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Bug fix: the original returned ``torch.tanh(x)`` directly, skipping
        # the network. The parent's forward returns
        # ``(self.layers(x), uncertainty_level)``, so unpack it, squash the
        # mean, and forward the uncertainty unchanged. The return annotation
        # is corrected to match the parent's tuple return.
        mean, uncertainty = super().forward(x)
        return torch.tanh(mean), uncertainty


class GWModule(nn.Module):
domain_descr: Mapping[str, DomainDescription]

Expand Down Expand Up @@ -166,7 +194,7 @@ def __init__(
)
self.decoders = nn.ModuleDict(
{
domain: GWEncoder(
domain: GWDecoder(
self.latent_dim,
self.decoder_hidden_dim[domain],
self.input_dim[domain],
Expand Down Expand Up @@ -246,7 +274,7 @@ def __init__(
)
self.decoders = nn.ModuleDict(
{
domain: GWEncoder(
domain: VariationalGWDecoder(
self.latent_dim,
self.decoder_hidden_dim[domain],
self.input_dim[domain],
Expand Down

0 comments on commit cd6c43a

Please sign in to comment.