
Commit

Add raw_data as parameters for compute_loss and do not freeze End2EndDomainModules
bdvllrs committed Sep 19, 2024
1 parent 3e78b27 commit b092e1e
Showing 3 changed files with 91 additions and 34 deletions.
35 changes: 26 additions & 9 deletions shimmer/modules/domain.py
```diff
@@ -93,59 +93,71 @@ def decode(self, z: torch.Tensor) -> Any:
         """
         raise NotImplementedError
 
-    def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
+    def compute_loss(
+        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
+    ) -> LossOutput:
         """
         Generic loss computation for the modality.
         Args:
             pred (`torch.Tensor`): prediction of the model
             target (`torch.Tensor`): target tensor
+            raw_target (`Any`): raw data from the input
         Results:
             `LossOutput`: LossOutput with training loss and additional metrics.
         """
         raise NotImplementedError
 
-    def compute_dcy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
+    def compute_dcy_loss(
+        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
+    ) -> LossOutput:
         """
         Computes the loss for a demi-cycle. Override if the demi-cycle loss is
         different from the generic loss.
         Args:
             pred (`torch.Tensor`): prediction of the model
             target (`torch.Tensor`): target tensor
+            raw_target (`Any`): raw data from the input
         Results:
             `LossOutput`: LossOutput with training loss and additional metrics.
         """
-        return self.compute_loss(pred, target)
+        return self.compute_loss(pred, target, raw_target)
 
-    def compute_cy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
+    def compute_cy_loss(
+        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
+    ) -> LossOutput:
         """
         Computes the loss for a cycle. Override if the cycle loss is
         different from the generic loss.
         Args:
             pred (`torch.Tensor`): prediction of the model
             target (`torch.Tensor`): target tensor
+            raw_target (`Any`): raw data from the input
         Results:
             `LossOutput`: LossOutput with training loss and additional metrics.
         """
-        return self.compute_loss(pred, target)
+        return self.compute_loss(pred, target, raw_target)
 
-    def compute_tr_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
+    def compute_tr_loss(
+        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
+    ) -> LossOutput:
         """
         Computes the loss for a translation. Override if the translation loss is
         different from the generic loss.
         Args:
             pred (`torch.Tensor`): prediction of the model
             target (`torch.Tensor`): target tensor
+            raw_target (`Any`): raw data from the input
         Results:
             `LossOutput`: LossOutput with training loss and additional metrics.
         """
-        return self.compute_loss(pred, target)
+        return self.compute_loss(pred, target, raw_target)
 
     def compute_broadcast_loss(
-        self, pred: torch.Tensor, target: torch.Tensor
+        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
     ) -> LossOutput:
         """
         Computes the loss for a broadcast (fusion). Override if the broadcast loss is
@@ -154,7 +166,12 @@ def compute_broadcast_loss(
         Args:
             pred (`torch.Tensor`): prediction of the model
             target (`torch.Tensor`): target tensor
+            raw_target (`Any`): raw data from the input
         Results:
             `LossOutput`: LossOutput with training loss and additional metrics.
         """
-        return self.compute_loss(pred, target)
+        return self.compute_loss(pred, target, raw_target)
+
+
+class End2EndDomainModule(DomainModule):
+    pass
```
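With this change, every loss hook on a `DomainModule` also receives the raw (pre-encoding) data of the target domain. A minimal sketch of an override under the new signature (the `MSEDomain` class, its `latent_dim`, and the extra metric are illustrative assumptions; `LossOutput(loss, metrics)` construction is presumed from the docstrings, not shown in this diff):

```python
from typing import Any

import torch
import torch.nn.functional as F

from shimmer.modules.domain import DomainModule, LossOutput


class MSEDomain(DomainModule):
    """Hypothetical domain whose loss also inspects the raw batch."""

    def __init__(self) -> None:
        super().__init__(latent_dim=12)  # latent_dim assumed, as elsewhere in shimmer

    def compute_loss(
        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
    ) -> LossOutput:
        # The training loss still compares latent vectors, as before.
        loss = F.mse_loss(pred, target)
        # raw_target exposes the original (pre-encoding) data; here it only
        # feeds an extra logged metric computed in raw space.
        metrics: dict[str, torch.Tensor] = {}
        if isinstance(raw_target, torch.Tensor):
            metrics["raw_target_mean"] = raw_target.mean()
        return LossOutput(loss, metrics)
```

Because the base `compute_dcy_loss`, `compute_cy_loss`, `compute_tr_loss`, and `compute_broadcast_loss` now all forward `raw_target` to `compute_loss`, overriding `compute_loss` alone covers every loss type unless one of them needs special handling.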
11 changes: 8 additions & 3 deletions shimmer/modules/global_workspace.py
```diff
@@ -9,7 +9,7 @@
 from torch.optim.lr_scheduler import OneCycleLR
 
 from shimmer.modules.contrastive_loss import ContrastiveLoss, ContrastiveLossType
-from shimmer.modules.domain import DomainModule
+from shimmer.modules.domain import DomainModule, End2EndDomainModule
 from shimmer.modules.gw_module import (
     GWModule,
     GWModuleBase,
@@ -482,7 +482,7 @@ def generic_step(self, batch: RawDomainGroupsT, mode: ModelModeT) -> torch.Tensor:
         domain_latents = self.encode_domains(batch)
         batch_size = groups_batch_size(domain_latents)
 
-        loss_output = self.loss_mod.step(domain_latents, mode)
+        loss_output = self.loss_mod.step(batch, domain_latents, mode)
 
         for name, metric in loss_output.all.items():
             self.log(
@@ -572,6 +572,10 @@ def freeze_domain_modules(
     The output is cast as `dict[str, DomainModule]` type for better
     auto-completion, but is actually a torch `ModuleDict`.
 
+    .. note::
+        Instances of `End2EndDomainModule` are not frozen, as they should be
+        trained alongside the GW.
+
     Args:
         domain_mods (`Mapping[str, DomainModule]`): mapping of domain modules to freeze
@@ -580,7 +584,8 @@
     """
 
     for mod in domain_mods.values():
-        mod.freeze()
+        if not isinstance(mod, End2EndDomainModule):
+            mod.freeze()
     # Cast for better auto-completion at the expense of ModuleDict
     return cast(dict[str, DomainModule], ModuleDict(domain_mods))
```
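The practical effect of the new `isinstance` check: a module subclassing `End2EndDomainModule` keeps its parameters trainable when handed to the global workspace, while plain `DomainModule`s are frozen as before. A small sketch of that behaviour (both domain classes are hypothetical; `DomainModule.__init__(latent_dim=...)` and Lightning's `freeze()` are assumed, consistent with the code above):

```python
from torch import nn

from shimmer.modules.domain import DomainModule, End2EndDomainModule


class FrozenDomain(DomainModule):
    """Hypothetical pretrained domain: frozen when attached to the GW."""

    def __init__(self) -> None:
        super().__init__(latent_dim=8)
        self.encoder = nn.Linear(3, 8)


class TrainedWithGW(End2EndDomainModule):
    """Hypothetical end-to-end domain: left trainable by the new check."""

    def __init__(self) -> None:
        super().__init__(latent_dim=8)
        self.encoder = nn.Linear(3, 8)


domain_mods = {"vision": FrozenDomain(), "text": TrainedWithGW()}

# Mirrors the updated freeze_domain_modules logic:
for mod in domain_mods.values():
    if not isinstance(mod, End2EndDomainModule):
        mod.freeze()  # Lightning's freeze(): eval() mode + requires_grad=False

assert not any(p.requires_grad for p in domain_mods["vision"].parameters())
assert all(p.requires_grad for p in domain_mods["text"].parameters())
```

Since `End2EndDomainModule` is, for now, just a marker subclass (its body is `pass`), opting in is only a matter of inheriting from it.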


