Skip to content

Commit

Permalink
add method for end2end domain modules to have an extra loss on inputs during GW training
Browse files Browse the repository at this point in the history
  • Loading branch information
bdvllrs committed Sep 20, 2024
1 parent 9976918 commit 9811d43
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 2 deletions.
13 changes: 12 additions & 1 deletion shimmer/modules/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,4 +188,15 @@ def compute_fused_loss(


class End2EndDomainModule(DomainModule):
    """Domain module that is trained end-to-end during global-workspace training.

    Subclasses may override `compute_domain_loss` to contribute an extra
    unimodal loss on their raw inputs while the GW is being trained.
    """

    def compute_domain_loss(self, domain: Any) -> LossOutput | None:
        """
        Compute the unimodal domain loss.

        Args:
            domain (`Any`): domain input

        Returns:
            `LossOutput | None`: `LossOutput` with training loss and additional
                metrics. If `None` is returned, this loss will be ignored and
                will not participate in the total loss.
        """
        # Default implementation opts out; subclasses override to opt in.
        return None
34 changes: 33 additions & 1 deletion shimmer/modules/global_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from torch.optim.lr_scheduler import LRScheduler, OneCycleLR

from shimmer.modules.contrastive_loss import ContrastiveLoss, ContrastiveLossType
from shimmer.modules.domain import DomainModule, End2EndDomainModule
from shimmer.modules.domain import DomainModule, End2EndDomainModule, LossOutput
from shimmer.modules.gw_module import (
GWModule,
GWModuleBase,
Expand Down Expand Up @@ -484,6 +484,24 @@ def decode_domains(self, latents_domain: LatentsDomainGroupsT) -> RawDomainGroup
for domains, latents in latents_domain.items()
}

def unimodal_losses(self, batch: RawDomainGroupsT) -> LossOutput | None:
metrics: dict[str, torch.Tensor] = {}
losses: list[torch.Tensor] = []
for group_domain_names, domain_group in batch.items():
if len(group_domain_names) > 1:
continue
for domain_name, domain in domain_group.items():
domain_mod = self.domain_mods[domain_name]
if isinstance(domain_mod, End2EndDomainModule):
loss = domain_mod.compute_domain_loss(domain)
if loss is not None:
for name, metric in loss.metrics.items():
metrics[f"{domain_name}/{name}"] = metric
losses.append(loss.loss)
if not len(losses):
return None
return LossOutput(loss=torch.stack(losses, dim=0).sum(), metrics=metrics)

def generic_step(self, batch: RawDomainGroupsT, mode: ModelModeT) -> STEP_OUTPUT:
"""
The generic step used in `training_step`, `validation_step` and
Expand All @@ -509,6 +527,20 @@ def generic_step(self, batch: RawDomainGroupsT, mode: ModelModeT) -> STEP_OUTPUT
add_dataloader_idx=False,
)

total_loss = loss_output.loss

unimodal_losses = self.unimodal_losses(batch)
if unimodal_losses is not None:
for name, metric in unimodal_losses.all.items():
self.log(
f"{mode}/{name}",
metric,
batch_size=batch_size,
add_dataloader_idx=False,
)

total_loss += unimodal_losses.loss

return loss_output.loss

def validation_step( # type: ignore
Expand Down

0 comments on commit 9811d43

Please sign in to comment.