
Commit

Put type definitions in a new file to avoid cyclic imports
bdvllrs committed Mar 1, 2024
1 parent 85ba6fc commit 7110736
Showing 6 changed files with 87 additions and 53 deletions.
22 changes: 18 additions & 4 deletions shimmer/__init__.py
@@ -30,12 +30,20 @@
from shimmer.modules.losses import (
GWLosses,
GWLossesBase,
LatentsDomainGroupT,
LatentsT,
LossCoefs,
VariationalGWLosses,
VariationalLossCoefs,
)
from shimmer.types import (
LatentsDomainGroupDT,
LatentsDomainGroupsDT,
LatentsDomainGroupsT,
LatentsDomainGroupT,
RawDomainGroupDT,
RawDomainGroupsDT,
RawDomainGroupsT,
RawDomainGroupT,
)
from shimmer.version import __version__

__all__ = [
@@ -60,10 +68,8 @@
"ContrastiveLossWithUncertainty",
"contrastive_loss",
"contrastive_loss_with_uncertainty",
"LatentsT",
"LossCoefs",
"VariationalLossCoefs",
"LatentsDomainGroupT",
"GWLosses",
"GWLossesBase",
"VariationalGWLosses",
@@ -74,4 +80,12 @@
"GWPredictions",
"pretrained_global_workspace",
"RepeatedDataset",
"LatentsDomainGroupDT",
"LatentsDomainGroupsDT",
"LatentsDomainGroupsT",
"LatentsDomainGroupT",
"RawDomainGroupDT",
"RawDomainGroupsDT",
"RawDomainGroupsT",
"RawDomainGroupT",
]
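Because the aliases are now re-exported from the package root, downstream code can import them directly from `shimmer`. A minimal sketch; the `detach_group` helper is hypothetical and only illustrates how the Mapping-based alias accepts any mapping while the `DT` variant pins the return type to a plain dict:

import torch

from shimmer import LatentsDomainGroupDT, LatentsDomainGroupT


def detach_group(group: LatentsDomainGroupT) -> LatentsDomainGroupDT:
    # Hypothetical helper: accept any mapping of domain name -> latent tensor,
    # return a concrete dict with the tensors detached from the graph.
    return {domain: latent.detach() for domain, latent in group.items()}


group = {"vision": torch.randn(2, 8), "text": torch.randn(2, 8)}
detached = detach_group(group)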
4 changes: 2 additions & 2 deletions shimmer/modules/__init__.py
@@ -29,8 +29,8 @@
from shimmer.modules.losses import (
GWLosses,
GWLossesBase,
LatentsDomainGroupsT,
LatentsDomainGroupT,
LatentsT,
LossCoefs,
VariationalGWLosses,
VariationalLossCoefs,
@@ -55,7 +55,7 @@
"ContrastiveLossWithUncertainty",
"contrastive_loss",
"contrastive_loss_with_uncertainty",
"LatentsT",
"LatentsDomainGroupsT",
"LossCoefs",
"VariationalLossCoefs",
"LatentsDomainGroupT",
8 changes: 5 additions & 3 deletions shimmer/modules/global_workspace.py
@@ -26,15 +26,17 @@
GWLosses,
GWLossesBase,
GWLossesFusion,
LossCoefs,
VariationalGWLosses,
VariationalLossCoefs,
)
from shimmer.types import (
LatentsDomainGroupsDT,
LatentsDomainGroupsT,
LatentsDomainGroupT,
LossCoefs,
RawDomainGroupsDT,
RawDomainGroupsT,
RawDomainGroupT,
VariationalGWLosses,
VariationalLossCoefs,
)


16 changes: 15 additions & 1 deletion shimmer/modules/gw_module.py
@@ -6,8 +6,8 @@
from torch import nn

from shimmer.modules.domain import DomainModule
from shimmer.modules.losses import LatentsDomainGroupT
from shimmer.modules.vae import reparameterize
from shimmer.types import LatentsDomainGroupT


def get_n_layers(n_layers: int, hidden_dim: int) -> list[nn.Module]:
@@ -381,6 +381,17 @@ def __init__(
decoder_hidden_dim: int,
decoder_n_layers: int,
) -> None:
"""
Initialize the interface.
Args:
domain_module (`DomainModule`): Domain module to link.
workspace_dim (`int`): dimension of the GW.
encoder_hidden_dim (`int`): `hidden_dim` used for `GWEncoder`.
encoder_n_layers (`int`): `n_layers` used for `GWEncoder`.
decoder_hidden_dim (`int`): `hidden_dim` used for `GWDecoder`.
decoder_n_layers (`int`): `n_layers` used for `GWDecoder`.
"""
super().__init__(domain_module, workspace_dim)

self.encoder = GWEncoder(
@@ -389,12 +400,15 @@ def __init__(
workspace_dim,
encoder_n_layers,
)
"""The interface encoder"""

self.decoder = GWDecoder(
workspace_dim,
decoder_hidden_dim,
domain_module.latent_dim,
decoder_n_layers,
)
"""The interface decoder"""

def encode(self, x: torch.Tensor) -> torch.Tensor:
return self.encoder(x)
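For reference, the encoder/decoder pair assembled above can be exercised on its own. A minimal sketch, assuming `GWEncoder` and `GWDecoder` are ordinary `nn.Module` MLPs taking the argument order visible in this hunk (input dim, hidden dim, output dim, number of layers); the sizes below are illustrative only:

import torch

from shimmer.modules.gw_module import GWDecoder, GWEncoder

latent_dim, hidden_dim, workspace_dim, n_layers = 12, 32, 16, 2  # illustrative sizes

encoder = GWEncoder(latent_dim, hidden_dim, workspace_dim, n_layers)
decoder = GWDecoder(workspace_dim, hidden_dim, latent_dim, n_layers)

x = torch.randn(4, latent_dim)  # a batch of unimodal latent vectors
gw_state = encoder(x)           # project into the global workspace
x_hat = decoder(gw_state)       # map back to the domain's latent space
print(gw_state.shape, x_hat.shape)  # expected (4, 16) and (4, 12) under these assumptions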
45 changes: 2 additions & 43 deletions shimmer/modules/losses.py
@@ -1,6 +1,5 @@
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import Any, TypedDict
from typing import TypedDict

import torch
import torch.nn.functional as F
@@ -9,47 +8,7 @@
from shimmer.modules.domain import DomainModule, LossOutput
from shimmer.modules.gw_module import GWModule, GWModuleBase, VariationalGWModule
from shimmer.modules.vae import kl_divergence_loss

RawDomainGroupT = Mapping[str, Any]
"""Matched raw unimodal data from multiple domains.
Keys of the mapping are domains names."""

RawDomainGroupDT = dict[str, Any]
"""Matched raw unimodal data from multiple domains.
Keys of the dict are domains names.
This is a more specific version of `RawDomainGroupT` used in method's outputs."""

LatentsDomainGroupT = Mapping[str, torch.Tensor]
"""Matched unimodal latent representations from multiple domains.
Keys of the mapping are domains names."""

LatentsDomainGroupDT = dict[str, torch.Tensor]
"""Matched unimodal latent representations from multiple domains.
Keys of the dict are domains names.
This is a more specific version of `LatentsDomainGroupT` used in method's outputs."""

LatentsDomainGroupsT = Mapping[frozenset[str], LatentsDomainGroupT]
"""Mapping of `LatentsDomainGroupT`. Keys are frozenset of domains matched in the group.
Each group is independent and contains different data (unpaired)."""

LatentsDomainGroupsDT = dict[frozenset[str], LatentsDomainGroupDT]
"""Mapping of `LatentsDomainGroupDT`.
Keys are frozenset of domains matched in the group.
Each group is independent and contains different data (unpaired).
This is a more specific version of `LatentsDomainGroupsT` used in method's outputs."""

RawDomainGroupsT = Mapping[frozenset[str], RawDomainGroupT]
"""Mapping of `RawDomainGroupT`. Keys are frozenset of domains matched in the group.
Each group is independent and contains different data (unpaired)."""

RawDomainGroupsDT = dict[frozenset[str], RawDomainGroupDT]
"""Mapping of `RawDomainGroupT`. Keys are frozenset of domains matched in the group.
Each group is independent and contains different data (unpaired).
This is a more specific version of `RawDomainGroupsT` used in method's outputs."""
from shimmer.types import LatentsDomainGroupsT


class GWLossesBase(torch.nn.Module, ABC):
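This hunk is the other half of the cycle the commit removes: losses.py imports `GWModule` and friends from gw_module.py, while gw_module.py previously imported `LatentsDomainGroupT` back from losses.py. With the aliases in shimmer/types.py, which depends only on the standard library and torch, both modules can import them without re-entering shimmer.modules. A minimal sketch of the new import pattern:

# Both losses.py and gw_module.py now resolve their shared aliases the same way,
# so neither module needs to import the other just for type annotations.
from shimmer.types import LatentsDomainGroupsT, LatentsDomainGroupT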
45 changes: 45 additions & 0 deletions shimmer/types.py
@@ -0,0 +1,45 @@
from collections.abc import Mapping
from typing import Any

import torch

RawDomainGroupT = Mapping[str, Any]
"""Matched raw unimodal data from multiple domains.
Keys of the mapping are domain names."""

RawDomainGroupDT = dict[str, Any]
"""Matched raw unimodal data from multiple domains.
Keys of the dict are domain names.
This is a more specific version of `RawDomainGroupT` used in method's outputs."""

LatentsDomainGroupT = Mapping[str, torch.Tensor]
"""Matched unimodal latent representations from multiple domains.
Keys of the mapping are domain names."""

LatentsDomainGroupDT = dict[str, torch.Tensor]
"""Matched unimodal latent representations from multiple domains.
Keys of the dict are domain names.
This is a more specific version of `LatentsDomainGroupT` used in method's outputs."""

LatentsDomainGroupsT = Mapping[frozenset[str], LatentsDomainGroupT]
"""Mapping of `LatentsDomainGroupT`. Keys are frozenset of domains matched in the group.
Each group is independent and contains different data (unpaired)."""

LatentsDomainGroupsDT = dict[frozenset[str], LatentsDomainGroupDT]
"""Mapping of `LatentsDomainGroupDT`.
Keys are frozensets of the domains matched in the group.
Each group is independent and contains different data (unpaired).
This is a more specific version of `LatentsDomainGroupsT` used in method's outputs."""

RawDomainGroupsT = Mapping[frozenset[str], RawDomainGroupT]
"""Mapping of `RawDomainGroupT`. Keys are frozenset of domains matched in the group.
Each group is independent and contains different data (unpaired)."""

RawDomainGroupsDT = dict[frozenset[str], RawDomainGroupDT]
"""Mapping of `RawDomainGroupT`. Keys are frozenset of domains matched in the group.
Each group is independent and contains different data (unpaired).
This is a more specific version of `RawDomainGroupsT` used in method's outputs."""
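To make the group aliases concrete, here is a minimal sketch of the shape of a `LatentsDomainGroupsT` value; the domain names, tensor sizes, and the `count_domains` helper are made up for illustration:

import torch

from shimmer.types import LatentsDomainGroupsDT, LatentsDomainGroupsT

# Two independent groups: one where vision and text latents are matched,
# one containing vision latents alone.
latent_groups: LatentsDomainGroupsDT = {
    frozenset({"vision", "text"}): {
        "vision": torch.randn(8, 16),
        "text": torch.randn(8, 16),
    },
    frozenset({"vision"}): {"vision": torch.randn(8, 16)},
}


def count_domains(groups: LatentsDomainGroupsT) -> int:
    # Hypothetical helper: total number of (group, domain) pairs.
    return sum(len(group) for group in groups.values())


print(count_domains(latent_groups))  # 3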
