Commit

Remove GWEncoderLinear (#163)
Removing the GWEncoderLinear class because it is not used.
HugoChateauLaurent authored Oct 8, 2024
1 parent 81753a0 commit fd1d65c
Showing 4 changed files with 28 additions and 22 deletions.
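The removed class was a small wrapper (a linear layer followed by torch.tanh), shown in the shimmer/modules/gw_module.py hunk below. If any downstream code still depends on it, an equivalent drop-in can be kept locally; this is a sketch mirroring the removed code, not part of the shimmer API:

import torch
from torch import nn


class GWEncoderLinear(nn.Linear):
    """Local, out-of-tree copy of the removed class: a linear encoder followed by tanh."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return torch.tanh(super().forward(input))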
2 changes: 0 additions & 2 deletions shimmer/__init__.py
@@ -24,7 +24,6 @@
from shimmer.modules.gw_module import (
GWDecoder,
GWEncoder,
- GWEncoderLinear,
GWModule,
GWModuleBase,
GWModulePrediction,
@@ -77,7 +76,6 @@
"DomainModule",
"GWDecoder",
"GWEncoder",
"GWEncoderLinear",
"GWModuleBase",
"GWModule",
"GWModulePrediction",
2 changes: 0 additions & 2 deletions shimmer/modules/__init__.py
@@ -18,7 +18,6 @@
from shimmer.modules.gw_module import (
GWDecoder,
GWEncoder,
- GWEncoderLinear,
GWModule,
GWModuleBase,
GWModulePrediction,
@@ -56,7 +55,6 @@
"DomainModule",
"GWDecoder",
"GWEncoder",
"GWEncoderLinear",
"GWModuleBase",
"GWModule",
"GWModulePrediction",
23 changes: 19 additions & 4 deletions shimmer/modules/global_workspace.py
@@ -1,4 +1,4 @@
- from collections.abc import Iterable, Mapping
+ from collections.abc import Callable, Iterable, Mapping
from enum import Enum, auto
from pathlib import Path
from typing import Any, Generic, TypedDict, TypeVar, cast
@@ -663,6 +663,7 @@ def __init__(
scheduler: LRScheduler
| None
| OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT,
+ fusion_activation_fn: Callable[[torch.Tensor], torch.Tensor] = torch.tanh,
) -> None:
"""
Initializes a Global Workspace
@@ -689,10 +690,14 @@ def __init__(
contrastive losses.
scheduler: The scheduler to use for training. If None is explicitly given,
no scheduler will be used. Defaults to OneCycleScheduler
+ fusion_activation_fn (`Callable[[torch.Tensor], torch.Tensor]`): activation
+ function to fuse the domains.
"""
domain_mods = freeze_domain_modules(domain_mods)

- gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders)
+ gw_mod = GWModule(
+ domain_mods, workspace_dim, gw_encoders, gw_decoders, fusion_activation_fn
+ )
if contrastive_loss is None:
contrastive_loss = ContrastiveLoss(
torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale
@@ -736,6 +741,7 @@ def __init__(
scheduler: LRScheduler
| None
| OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT,
+ fusion_activation_fn: Callable[[torch.Tensor], torch.Tensor] = torch.tanh,
) -> None:
"""
Initializes a Global Workspace
@@ -764,9 +770,13 @@ def __init__(
contrastive losses.
scheduler: The scheduler to use for training. If None is explicitly given,
no scheduler will be used. Defaults to OneCycleScheduler
+ fusion_activation_fn (`Callable[[torch.Tensor], torch.Tensor]`): activation
+ function to fuse the domains.
"""
domain_mods = freeze_domain_modules(domain_mods)
- gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders)
+ gw_mod = GWModule(
+ domain_mods, workspace_dim, gw_encoders, gw_decoders, fusion_activation_fn
+ )

if contrastive_loss is None:
contrastive_loss = ContrastiveLoss(
@@ -800,6 +810,7 @@ def pretrained_global_workspace(
scheduler: LRScheduler
| None
| OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT,
+ fusion_activation_fn: Callable[[torch.Tensor], torch.Tensor] = torch.tanh,
**kwargs,
) -> GlobalWorkspace2Domains:
"""
@@ -823,6 +834,8 @@
contrastive losses.
scheduler: The scheduler to use for training. If None is explicitly given,
no scheduler will be used. Defaults to OneCycleScheduler
+ fusion_activation_fn (`Callable[[torch.Tensor], torch.Tensor]`): activation
+ function to fuse the domains.
**kwargs: additional arguments to pass to
`GlobalWorkspace.load_from_checkpoint`.
@@ -833,7 +846,9 @@
`TypeError`: if loaded type is not `GlobalWorkspace`.
"""
domain_mods = freeze_domain_modules(domain_mods)
- gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders)
+ gw_mod = GWModule(
+ domain_mods, workspace_dim, gw_encoders, gw_decoders, fusion_activation_fn
+ )
selection_mod = SingleDomainSelection()
loss_mod = GWLosses2Domains(
gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_fn
23 changes: 9 additions & 14 deletions shimmer/modules/gw_module.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
- from collections.abc import Iterable, Mapping
+ from collections.abc import Callable, Iterable, Mapping
from typing import TypedDict

import torch
@@ -180,11 +180,7 @@ def __init__(


class GWEncoder(GWDecoder):
"""
An Encoder network used in GWModules.
This is similar to the decoder, but adds a tanh non-linearity at the end.
"""
"""An Encoder network used in GWModules."""

def __init__(
self,
@@ -209,13 +205,6 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
return super().forward(input)


- class GWEncoderLinear(nn.Linear):
- """A linear Encoder network used in GWModules."""
-
- def forward(self, input: torch.Tensor) -> torch.Tensor:
- return torch.tanh(super().forward(input))


class GWModulePrediction(TypedDict):
"""TypedDict of the output given when calling `GlobalWorkspaceBase.predict`"""

@@ -368,6 +357,7 @@ def __init__(
workspace_dim: int,
gw_encoders: Mapping[str, nn.Module],
gw_decoders: Mapping[str, nn.Module],
+ fusion_activation_fn: Callable[[torch.Tensor], torch.Tensor] = torch.tanh,
) -> None:
"""
Initializes the GWModule.
@@ -381,6 +371,8 @@ def __init__(
gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
name to a torch.nn.Module class that decodes a
GW representation to a unimodal latent representation.
+ fusion_activation_fn (`Callable[[torch.Tensor], torch.Tensor]`): activation
+ function used to fuse the domains.
"""
super().__init__(domain_modules, workspace_dim)

@@ -390,6 +382,9 @@ def __init__(
self.gw_decoders = nn.ModuleDict(gw_decoders)
"""The module's decoders"""

+ self.fusion_activation_fn = fusion_activation_fn
+ """Activation function used to fuse the domains."""

def fuse(
self,
x: LatentsDomainGroupT,
@@ -405,7 +400,7 @@ def fuse(
Returns:
`torch.Tensor`: The merged representation.
"""
- return torch.tanh(
+ return self.fusion_activation_fn(
torch.sum(
torch.stack(
[
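For context, the change above makes the fusion non-linearity configurable instead of hard-coding torch.tanh. A standalone sketch of the idea, with toy domain names and selection scores rather than shimmer's exact fuse implementation:

import torch

def fuse(encoded, selection_scores, fusion_activation_fn=torch.tanh):
    # Weighted sum of the pre-fusion domain latents, then the configurable activation.
    return fusion_activation_fn(
        torch.sum(
            torch.stack(
                [selection_scores[d].unsqueeze(1) * encoded[d] for d in encoded]
            ),
            dim=0,
        )
    )

encoded = {"v": torch.randn(2, 12), "t": torch.randn(2, 12)}             # toy pre-fusion latents
scores = {"v": torch.tensor([0.6, 0.5]), "t": torch.tensor([0.4, 0.5])}  # toy selection scores
fused_default = fuse(encoded, scores)                 # default behaviour: torch.tanh
fused_custom = fuse(encoded, scores, torch.sigmoid)   # any Callable[[Tensor], Tensor]

The same callable is what GWModule and the global workspace constructors in global_workspace.py now accept through the new fusion_activation_fn keyword argument.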
