Custom fusion_activation_fn
HugoChateauLaurent committed Oct 4, 2024
1 parent 5b69f6d commit 987f811
Showing 2 changed files with 26 additions and 12 deletions.
17 changes: 13 additions & 4 deletions shimmer/modules/global_workspace.py
@@ -1,4 +1,4 @@
from collections.abc import Iterable, Mapping
from collections.abc import Callable, Iterable, Mapping
from enum import Enum, auto
from pathlib import Path
from typing import Any, Generic, TypedDict, TypeVar, cast
@@ -663,6 +663,7 @@ def __init__(
scheduler: LRScheduler
| None
| OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT,
fusion_activation_fn: Callable[[torch.Tensor], torch.Tensor] = torch.tanh,
) -> None:
"""
Initializes a Global Workspace
@@ -689,10 +690,12 @@ def __init__(
contrastive losses.
scheduler: The scheduler to use for training. If None is explicitly given,
no scheduler will be used. Defaults to use OneCycleScheduler
fusion_activation_fn (`Callable[[torch.Tensor], torch.Tensor]`): activation
function to fuse the domains.
"""
domain_mods = freeze_domain_modules(domain_mods)

gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders)
gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders, fusion_activation_fn)
if contrastive_loss is None:
contrastive_loss = ContrastiveLoss(
torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale
@@ -736,6 +739,7 @@ def __init__(
scheduler: LRScheduler
| None
| OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT,
fusion_activation_fn: Callable[[torch.Tensor], torch.Tensor] = torch.tanh,
) -> None:
"""
Initializes a Global Workspace
@@ -764,9 +768,11 @@ def __init__(
contrastive losses.
scheduler: The scheduler to use for training. If None is explicitly given,
no scheduler will be used. Defaults to use OneCycleScheduler
fusion_activation_fn (`Callable[[torch.Tensor], torch.Tensor]`): activation
function to fuse the domains.
"""
domain_mods = freeze_domain_modules(domain_mods)
gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders)
gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders, fusion_activation_fn)

if contrastive_loss is None:
contrastive_loss = ContrastiveLoss(
@@ -800,6 +806,7 @@ def pretrained_global_workspace(
scheduler: LRScheduler
| None
| OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT,
fusion_activation_fn: Callable[[torch.Tensor], torch.Tensor] = torch.tanh,
**kwargs,
) -> GlobalWorkspace2Domains:
"""
@@ -823,6 +830,8 @@ def pretrained_global_workspace(
contrastive losses.
scheduler: The scheduler to use for training. If None is explicitly given,
no scheduler will be used. Defaults to use OneCycleScheduler
fusion_activation_fn (`Callable[[torch.Tensor], torch.Tensor]`): activation
function to fuse the domains.
**kwargs: additional arguments to pass to
`GlobalWorkspace.load_from_checkpoint`.
@@ -833,7 +842,7 @@
`TypeError`: if loaded type is not `GlobalWorkspace`.
"""
domain_mods = freeze_domain_modules(domain_mods)
gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders)
gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders, fusion_activation_fn)
selection_mod = SingleDomainSelection()
loss_mod = GWLosses2Domains(
gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_fn
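With this change the fusion non-linearity becomes a constructor argument instead of a hard-coded `torch.tanh`. The sketch below is illustrative only: it shows the kinds of callables that satisfy the new `fusion_activation_fn: Callable[[torch.Tensor], torch.Tensor]` contract and compares them against the old default; the specific alternatives are not part of the commit.

```python
import torch


def scaled_tanh(t: torch.Tensor) -> torch.Tensor:
    """Example custom fusion activation: a gentler squashing than plain tanh."""
    return 2.0 * torch.tanh(t / 2.0)


# Any Callable[[torch.Tensor], torch.Tensor] can now be passed as
# fusion_activation_fn=... to the constructors above; omitting it keeps the
# previous behaviour (torch.tanh).
candidates = {
    "default (old behaviour)": torch.tanh,
    "identity (no squashing)": lambda t: t,
    "relu": torch.nn.functional.relu,
    "scaled tanh": scaled_tanh,
}

x = torch.linspace(-3.0, 3.0, 7)
for name, fn in candidates.items():
    print(f"{name}: {fn(x)}")
```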
21 changes: 13 additions & 8 deletions shimmer/modules/gw_module.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping
from collections.abc import Callable, Iterable, Mapping
from typing import TypedDict

import torch
@@ -180,11 +180,7 @@ def __init__(


class GWEncoder(GWDecoder):
"""
An Encoder network used in GWModules.
This is similar to the decoder, but adds a tanh non-linearity at the end.
"""
"""An Encoder network used in GWModules."""

def __init__(
self,
@@ -246,6 +242,7 @@ def __init__(
self,
domain_mods: Mapping[str, DomainModule],
workspace_dim: int,
fusion_activation_fn: Callable[[torch.Tensor], torch.Tensor] = torch.tanh,
*args,
**kwargs,
) -> None:
@@ -255,6 +252,8 @@ def __init__(
Args:
domain_modules (`Mapping[str, DomainModule]`): the domain modules.
workspace_dim (`int`): dimension of the GW.
fusion_activation_fn (`Callable[[torch.Tensor], torch.Tensor]`): activation
function used to fuse the domains.
"""
super().__init__()

@@ -264,6 +263,9 @@ def __init__(
self.workspace_dim = workspace_dim
"""Dimension of the GW"""

self.fusion_activation_fn = fusion_activation_fn
"""Activation function used to fuse the domains"""

@abstractmethod
def fuse(
self, x: LatentsDomainGroupT, selection_scores: Mapping[str, torch.Tensor]
@@ -361,6 +363,7 @@ def __init__(
workspace_dim: int,
gw_encoders: Mapping[str, nn.Module],
gw_decoders: Mapping[str, nn.Module],
fusion_activation_fn: Callable[[torch.Tensor], torch.Tensor] = torch.tanh,
) -> None:
"""
Initializes the GWModule.
@@ -374,8 +377,10 @@ def __init__(
gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
name to a torch.nn.Module class that decodes a
GW representation to a unimodal latent representation.
fusion_activation_fn (`Callable[[torch.Tensor], torch.Tensor]`): activation
function used to fuse the domains.
"""
super().__init__(domain_modules, workspace_dim)
super().__init__(domain_modules, workspace_dim, fusion_activation_fn)

self.gw_encoders = nn.ModuleDict(gw_encoders)
"""The module's encoders"""
@@ -398,7 +403,7 @@ def fuse(
Returns:
`torch.Tensor`: The merged representation.
"""
return torch.tanh(
return self.fusion_activation_fn(
torch.sum(
torch.stack(
[
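For reference, here is a standalone sketch of the fusion step after this change: per-domain latents are weighted by their selection scores, the weighted latents are summed across domains, and the now-configurable activation is applied to the sum. The domain names, shapes, toy values, and the exact form of the score weighting (elided in the diff) are assumptions for illustration; only the stack-sum-activation pattern mirrors the rewritten `fuse` above.

```python
import torch

# Previously hard-coded as torch.tanh; now any Callable[[Tensor], Tensor].
fusion_activation_fn = torch.nn.functional.relu

batch, workspace_dim = 4, 12
encoded = {  # pre-encoded GW latents, one per domain (toy values)
    "v": torch.randn(batch, workspace_dim),
    "t": torch.randn(batch, workspace_dim),
}
selection_scores = {  # per-domain weights, shape (batch,)
    "v": torch.full((batch,), 0.5),
    "t": torch.full((batch,), 0.5),
}

# Weight each domain latent, stack to (n_domains, batch, dim), sum over
# domains, then apply the configurable fusion activation.
fused = fusion_activation_fn(
    torch.sum(
        torch.stack(
            [selection_scores[d].unsqueeze(1) * encoded[d] for d in encoded]
        ),
        dim=0,
    )
)
print(fused.shape)  # torch.Size([4, 12])
```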
