Skip to content

Commit

Permalink
Systematic coeffs fusion (#48)
Browse files Browse the repository at this point in the history
* draft version

* add typing for metrics and losses dicts

* take out sample_scaling_factors

* generate permutations function

* took out gwmodulefusion (gwmodule already does fusion)

* full broadcast function with logging

* reformat for ruff (super ugly now)

* first test for new broadcast

* fixed errors

* test works + mypy type checks

* fixed to run properly with simple shapes

* random attention returns scores on right device now

* partitions -- permutations

* partitions function fixes

* Docstring fixes

* Remove vim's swp files

* Reformat imports

* Type generate_partitions function

* Do not use forward

* Add var for sum(partition) and len(partition) makes code more understandable

* Use strict=True

* Reformatting and ruff fixes

* Add selection_temperature as a GlobalWorkspaceFusion parameter

* Log all metrics

* Add loss coefs to fusion model

* move tanh to postfusion

* Add loss_coefs to broadcast tests

* uniform output shape for randomselection and singledomainselection

* ruff fixes

---------

Co-authored-by: bdvllrs <[email protected]>
  • Loading branch information
RolandBERTINJOHANNET and bdvllrs authored Apr 11, 2024
1 parent f65d962 commit fc34d38
Show file tree
Hide file tree
Showing 9 changed files with 346 additions and 339 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ dmypy.json
# Cython debug symbols
cython_debug/

# vim
*.swp

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
Expand Down
2 changes: 2 additions & 0 deletions shimmer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
GWModuleWithUncertainty,
)
from shimmer.modules.losses import (
BroadcastLossCoefs,
GWLosses,
GWLossesBase,
GWLossesWithUncertainty,
Expand Down Expand Up @@ -82,6 +83,7 @@
"contrastive_loss",
"ContrastiveLoss",
"LossCoefs",
"BroadcastLossCoefs",
"GWLossesBase",
"GWLosses",
"GWLossesWithUncertainty",
Expand Down
2 changes: 2 additions & 0 deletions shimmer/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
GWModuleWithUncertainty,
)
from shimmer.modules.losses import (
BroadcastLossCoefs,
GWLosses,
GWLossesBase,
GWLossesWithUncertainty,
Expand Down Expand Up @@ -73,6 +74,7 @@
"contrastive_loss_with_uncertainty",
"ContrastiveLossWithUncertainty",
"LossCoefs",
"BroadcastLossCoefs",
"GWLossesBase",
"GWLosses",
"GWLossesWithUncertainty",
Expand Down
22 changes: 16 additions & 6 deletions shimmer/modules/global_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,21 @@
from shimmer.modules.gw_module import (
GWModule,
GWModuleBase,
GWModuleFusion,
GWModuleWithUncertainty,
)
from shimmer.modules.losses import (
BroadcastLossCoefs,
GWLosses,
GWLossesBase,
GWLossesFusion,
GWLossesWithUncertainty,
LossCoefs,
)
from shimmer.modules.selection import SelectionBase, SingleDomainSelection
from shimmer.modules.selection import (
RandomSelection,
SelectionBase,
SingleDomainSelection,
)
from shimmer.modules.utils import batch_cycles, batch_demi_cycles, batch_translations
from shimmer.types import (
LatentsDomainGroupsDT,
Expand Down Expand Up @@ -651,6 +655,8 @@ def __init__(
gw_encoders: Mapping[str, Module],
gw_decoders: Mapping[str, Module],
workspace_dim: int,
loss_coefs: BroadcastLossCoefs,
selection_temperature: float = 0.2,
optim_lr: float = 1e-3,
optim_weight_decay: float = 0.0,
scheduler_args: SchedulerArgs | None = None,
Expand All @@ -671,6 +677,9 @@ def __init__(
name to a `torch.nn.Module` class whose role is to decode a
GW representation into a unimodal latent representation.
workspace_dim (`int`): dimension of the GW.
loss_coefs (`BroadcastLossCoefs`): loss coefs for the losses.
selection_temperature (`float`): temperature value for the RandomSelection
module.
optim_lr (`float`): learning rate
optim_weight_decay (`float`): weight decay
scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
Expand All @@ -681,16 +690,17 @@ def __init__(
contrastive losses.
"""
domain_mods = freeze_domain_modules(domain_mods)
gw_mod = GWModuleFusion(domain_mods, workspace_dim, gw_encoders, gw_decoders)
gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders)

if contrastive_loss is None:
contrastive_loss = ContrastiveLoss(
torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale
)

# TODO: use the correction selection module
selection_mod = SingleDomainSelection()
loss_mod = GWLossesFusion(gw_mod, selection_mod, domain_mods, contrastive_loss)
selection_mod = RandomSelection(selection_temperature)
loss_mod = GWLossesFusion(
gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_loss
)

super().__init__(
gw_mod,
Expand Down
117 changes: 14 additions & 103 deletions shimmer/modules/gw_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __init__(
super().__init__(in_dim, hidden_dim, out_dim, n_layers)

def forward(self, input: torch.Tensor) -> torch.Tensor:
return torch.tanh(super().forward(input))
return super().forward(input)


class GWEncoderLinear(nn.Linear):
Expand Down Expand Up @@ -252,14 +252,16 @@ def fuse(
Returns:
`torch.Tensor`: The merged representation.
"""
return torch.sum(
torch.stack(
[
selection_scores[domain].unsqueeze(1) * x[domain]
for domain in selection_scores
]
),
dim=0,
return torch.tanh(
torch.sum(
torch.stack(
[
selection_scores[domain].unsqueeze(1) * x[domain]
for domain in selection_scores
]
),
dim=0,
)
)

def encode(self, x: LatentsDomainGroupT) -> LatentsDomainGroupT:
Expand Down Expand Up @@ -364,7 +366,9 @@ def _fuse_and_scores(
coef = final_scores.sum(dim=0)
final_scores = final_scores / coef

return torch.sum(final_scores * torch.stack(domains), dim=0), final_scores
return torch.tanh(
torch.sum(final_scores * torch.stack(domains), dim=0)
), final_scores

def fuse(
self,
Expand Down Expand Up @@ -406,96 +410,3 @@ def fuse(
`torch.Tensor`: The merged representation.
"""
return self._fuse_and_scores(x, selection_scores)[0]


class GWModuleFusion(GWModuleBase):
    """
    GWModule used for fusion.
    """

    def __init__(
        self,
        domain_modules: Mapping[str, DomainModule],
        workspace_dim: int,
        gw_encoders: Mapping[str, nn.Module],
        gw_decoders: Mapping[str, nn.Module],
    ) -> None:
        """
        Initializes the GWModule Fusion.

        Args:
            domain_modules (`Mapping[str, DomainModule]`): the domain modules.
            workspace_dim (`int`): dimension of the GW.
            gw_encoders (`Mapping[str, torch.nn.Module]`): mapping from each
                domain name to a `torch.nn.Module` that encodes a unimodal
                latent representation into a (pre-fusion) GW representation.
            gw_decoders (`Mapping[str, torch.nn.Module]`): mapping from each
                domain name to a `torch.nn.Module` that decodes a GW
                representation into a unimodal latent representation.
        """
        super().__init__(domain_modules, workspace_dim)

        # Wrapped in ModuleDicts so the per-domain encoders/decoders are
        # registered as submodules (parameters, device moves, state_dict).
        self.gw_encoders = nn.ModuleDict(gw_encoders)
        self.gw_decoders = nn.ModuleDict(gw_decoders)

    def fuse(
        self,
        x: LatentsDomainGroupT,
        selection_scores: Mapping[str, torch.Tensor],
    ) -> torch.Tensor:
        """
        Merge function used to combine domains.

        Args:
            x (`LatentsDomainGroupT`): the group of latent representations.
            selection_scores (`Mapping[str, torch.Tensor]`): attention scores
                weighting each domain's contribution to the fused result.

        Returns:
            `torch.Tensor`: the merged representation.
        """
        weighted = [
            selection_scores[name].unsqueeze(1) * x[name]
            for name in selection_scores
        ]
        # Sum the score-weighted pre-fusion representations across domains.
        return torch.stack(weighted).sum(dim=0)

    def encode(self, x: LatentsDomainGroupT) -> LatentsDomainGroupT:
        """
        Encode the unimodal latent representations `x` into the pre-fusion GW
        representations.

        Args:
            x (`LatentsDomainGroupT`): the group of latent representations.

        Returns:
            `LatentsDomainGroupT`: pre-fusion GW representation per domain.
        """
        encoded: dict[str, torch.Tensor] = {}
        for name, latent in x.items():
            encoded[name] = self.gw_encoders[name](latent)
        return encoded

    def decode(
        self, z: torch.Tensor, domains: Iterable[str] | None = None
    ) -> LatentsDomainGroupDT:
        """
        Decodes a GW representation to multiple domains.

        Args:
            z (`torch.Tensor`): the GW representation.
            domains (`Iterable[str] | None`): the domains to decode to.
                Defaults to every domain in `gw_decoders`.

        Returns:
            `LatentsDomainGroupDT`: decoded unimodal representations.
        """
        # NOTE: a falsy (e.g. empty) `domains` also falls back to all
        # decoders — the `or` truthiness check is deliberate.
        return {
            name: self.gw_decoders[name](z)
            for name in (domains or self.gw_decoders.keys())
        }
Loading

0 comments on commit fc34d38

Please sign in to comment.