diff --git a/shimmer/modules/domain.py b/shimmer/modules/domain.py
index 8833ed32..9536fc46 100644
--- a/shimmer/modules/domain.py
+++ b/shimmer/modules/domain.py
@@ -93,19 +93,24 @@ def decode(self, z: torch.Tensor) -> Any:
         """
         raise NotImplementedError

-    def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
+    def compute_loss(
+        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
+    ) -> LossOutput:
         """
         Generic loss computation the modality.

         Args:
             pred (`torch.Tensor`): prediction of the model
             target (`torch.Tensor`): target tensor
+            raw_target (`Any`): raw data from the input
         Results:
             `LossOutput`: LossOuput with training loss and additional metrics.
         """
         raise NotImplementedError

-    def compute_dcy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
+    def compute_dcy_loss(
+        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
+    ) -> LossOutput:
         """
         Computes the loss for a demi-cycle. Override if the demi-cycle loss is
         different that the generic loss.
@@ -113,12 +118,15 @@ def compute_dcy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutp
         Args:
             pred (`torch.Tensor`): prediction of the model
             target (`torch.Tensor`): target tensor
+            raw_target (`Any`): raw data from the input
         Results:
             `LossOutput`: LossOuput with training loss and additional metrics.
         """
-        return self.compute_loss(pred, target)
+        return self.compute_loss(pred, target, raw_target)

-    def compute_cy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
+    def compute_cy_loss(
+        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
+    ) -> LossOutput:
         """
         Computes the loss for a cycle. Override if the cycle loss is
         different that the generic loss.
@@ -126,12 +134,15 @@ def compute_cy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutpu
         Args:
             pred (`torch.Tensor`): prediction of the model
             target (`torch.Tensor`): target tensor
+            raw_target (`Any`): raw data from the input
         Results:
             `LossOutput`: LossOuput with training loss and additional metrics.
         """
-        return self.compute_loss(pred, target)
+        return self.compute_loss(pred, target, raw_target)

-    def compute_tr_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
+    def compute_tr_loss(
+        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
+    ) -> LossOutput:
         """
         Computes the loss for a translation. Override if the translation loss is
         different that the generic loss.
@@ -139,13 +150,14 @@ def compute_tr_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutpu
         Args:
             pred (`torch.Tensor`): prediction of the model
             target (`torch.Tensor`): target tensor
+            raw_target (`Any`): raw data from the input
         Results:
             `LossOutput`: LossOuput with training loss and additional metrics.
         """
-        return self.compute_loss(pred, target)
+        return self.compute_loss(pred, target, raw_target)

     def compute_broadcast_loss(
-        self, pred: torch.Tensor, target: torch.Tensor
+        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
     ) -> LossOutput:
         """
         Computes the loss for a broadcast (fusion). Override if the broadcast loss is
@@ -154,7 +166,12 @@ def compute_broadcast_loss(
         Args:
             pred (`torch.Tensor`): prediction of the model
             target (`torch.Tensor`): target tensor
+            raw_target (`Any`): raw data from the input
         Results:
             `LossOutput`: LossOuput with training loss and additional metrics.
""" - return self.compute_loss(pred, target) + return self.compute_loss(pred, target, raw_target) + + +class End2EndDomainModule(DomainModule): + pass diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py index 8def76c5..f74657a0 100644 --- a/shimmer/modules/global_workspace.py +++ b/shimmer/modules/global_workspace.py @@ -9,7 +9,7 @@ from torch.optim.lr_scheduler import OneCycleLR from shimmer.modules.contrastive_loss import ContrastiveLoss, ContrastiveLossType -from shimmer.modules.domain import DomainModule +from shimmer.modules.domain import DomainModule, End2EndDomainModule from shimmer.modules.gw_module import ( GWModule, GWModuleBase, @@ -482,7 +482,7 @@ def generic_step(self, batch: RawDomainGroupsT, mode: ModelModeT) -> torch.Tenso domain_latents = self.encode_domains(batch) batch_size = groups_batch_size(domain_latents) - loss_output = self.loss_mod.step(domain_latents, mode) + loss_output = self.loss_mod.step(batch, domain_latents, mode) for name, metric in loss_output.all.items(): self.log( @@ -572,6 +572,10 @@ def freeze_domain_modules( The output is casted as `dict[str, DomainModule]` type for better auto-completion, but is actually a torch `ModuleDict`. + .. note:: + Instances of `End2EndDomainModule` are not frozen as they should be trained + alongside the GW. + Args: domain_mods (`Mapping[str, DomainModule]`): mapping of domain modules to freeze @@ -580,7 +584,8 @@ def freeze_domain_modules( """ for mod in domain_mods.values(): - mod.freeze() + if not isinstance(mod, End2EndDomainModule): + mod.freeze() # Cast for better auto-completion at the expense of ModuleDict return cast(dict[str, DomainModule], ModuleDict(domain_mods)) diff --git a/shimmer/modules/losses.py b/shimmer/modules/losses.py index 08273a93..9e650d20 100644 --- a/shimmer/modules/losses.py +++ b/shimmer/modules/losses.py @@ -13,7 +13,7 @@ GWModuleBayesian, ) from shimmer.modules.selection import SelectionBase -from shimmer.types import LatentsDomainGroupsT, ModelModeT +from shimmer.types import LatentsDomainGroupsT, ModelModeT, RawDomainGroupsT class GWLossesBase(torch.nn.Module, ABC): @@ -26,6 +26,7 @@ class GWLossesBase(torch.nn.Module, ABC): @abstractmethod def step( self, + raw_data: RawDomainGroupsT, domain_latents: LatentsDomainGroupsT, mode: ModelModeT, ) -> LossOutput: @@ -33,6 +34,7 @@ def step( Computes the losses. Args: + raw_data (`RawDomainGroupsT`): raw input data domain_latents (`LatentsDomainGroupsT`): All latent groups mode (`Literal["train", "val", "test", "val/ood", "test/ood"]`): model mode Returns: @@ -46,6 +48,7 @@ def demi_cycle_loss( selection_mod: SelectionBase, domain_mods: Mapping[str, DomainModule], latent_domains: LatentsDomainGroupsT, + raw_data: RawDomainGroupsT, ) -> dict[str, torch.Tensor]: """ Computes the demi-cycle loss. @@ -62,6 +65,7 @@ def demi_cycle_loss( domain_mods (`Mapping[str, DomainModule]`): the domain modules latent_domains (`shimmer.types.LatentsDomainGroupsT`): the latent unimodal groups + raw_data (`RawDomainGroupsT`): raw input data Returns: `dict[str, torch.Tensor]`: a dict of metrics. 
diff --git a/shimmer/modules/losses.py b/shimmer/modules/losses.py
index 08273a93..9e650d20 100644
--- a/shimmer/modules/losses.py
+++ b/shimmer/modules/losses.py
@@ -13,7 +13,7 @@
     GWModuleBayesian,
 )
 from shimmer.modules.selection import SelectionBase
-from shimmer.types import LatentsDomainGroupsT, ModelModeT
+from shimmer.types import LatentsDomainGroupsT, ModelModeT, RawDomainGroupsT


 class GWLossesBase(torch.nn.Module, ABC):
@@ -26,6 +26,7 @@ class GWLossesBase(torch.nn.Module, ABC):
     @abstractmethod
     def step(
         self,
+        raw_data: RawDomainGroupsT,
         domain_latents: LatentsDomainGroupsT,
         mode: ModelModeT,
     ) -> LossOutput:
@@ -33,6 +34,7 @@ def step(
         Computes the losses.

         Args:
+            raw_data (`RawDomainGroupsT`): raw input data
             domain_latents (`LatentsDomainGroupsT`): All latent groups
             mode (`Literal["train", "val", "test", "val/ood", "test/ood"]`): model mode
         Returns:
@@ -46,6 +48,7 @@ def demi_cycle_loss(
     selection_mod: SelectionBase,
     domain_mods: Mapping[str, DomainModule],
     latent_domains: LatentsDomainGroupsT,
+    raw_data: RawDomainGroupsT,
 ) -> dict[str, torch.Tensor]:
     """
     Computes the demi-cycle loss.
@@ -62,6 +65,7 @@
         domain_mods (`Mapping[str, DomainModule]`): the domain modules
         latent_domains (`shimmer.types.LatentsDomainGroupsT`): the latent
             unimodal groups
+        raw_data (`RawDomainGroupsT`): raw input data

     Returns:
         `dict[str, torch.Tensor]`: a dict of metrics.
@@ -77,7 +81,11 @@ def demi_cycle_loss(
         x_recons = gw_mod.decode(
             gw_mod.encode_and_fuse(latents, selection_mod), domains={domain_name}
         )[domain_name]
-        loss_output = domain_mod.compute_dcy_loss(x_recons, latents[domain_name])
+        loss_output = domain_mod.compute_dcy_loss(
+            x_recons,
+            latents[domain_name],
+            raw_data[domains][domain_name],
+        )
         losses[f"demi_cycle_{domain_name}"] = loss_output.loss
         metrics.update(
             {f"demi_cycle_{domain_name}_{k}": v for k, v in loss_output.metrics.items()}
@@ -92,6 +100,7 @@ def cycle_loss(
     selection_mod: SelectionBase,
     domain_mods: Mapping[str, DomainModule],
     latent_domains: LatentsDomainGroupsT,
+    raw_data: RawDomainGroupsT,
 ) -> dict[str, torch.Tensor]:
     """
     Computes the cycle loss.
@@ -109,6 +118,7 @@ def cycle_loss(
         selection_mod (`shimmer.modules.selection.SelectionBase`): Selection mod to use
         domain_mods (`Mapping[str, DomainModule]`): the domain modules
         latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+        raw_data (`RawDomainGroupsT`): raw input data

     Returns:
         `dict[str, torch.Tensor]`: a dict of metrics.
@@ -137,6 +147,7 @@ def cycle_loss(
             loss_output = domain_mod.compute_cy_loss(
                 x_recons[domain_name_source],
                 latents_source[domain_name_source],
+                raw_data[domains_source][domain_name_source],
             )
             metrics.update(
                 {f"cycle_{loss_name}_{k}": v for k, v in loss_output.metrics.items()}
@@ -152,6 +163,7 @@ def translation_loss(
     selection_mod: SelectionBase,
     domain_mods: Mapping[str, DomainModule],
     latent_domains: LatentsDomainGroupsT,
+    raw_data: RawDomainGroupsT,
 ) -> dict[str, torch.Tensor]:
     """
     Computes the translation loss.
@@ -169,6 +181,7 @@ def translation_loss(
         gw_mod (`GWModuleBase`): The GWModule to use
         domain_mods (`Mapping[str, DomainModule]`): the domain modules
         latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+        raw_data (`RawDomainGroupsT`): raw input data

     Returns:
         `dict[str, torch.Tensor]`: a dict of metrics.
@@ -199,6 +212,7 @@ def translation_loss(
             loss_output = mod.compute_tr_loss(
                 prediction,
                 latents[domain_name_target],
+                raw_data[domains][domain_name_target],
             )
             losses[f"translation_{loss_name}"] = loss_output.loss
             metrics.update(
@@ -388,7 +402,7 @@ def __init__(
         self.contrastive_fn = contrastive_fn

     def demi_cycle_loss(
-        self, latent_domains: LatentsDomainGroupsT
+        self, latent_domains: LatentsDomainGroupsT, raw_data: RawDomainGroupsT
     ) -> dict[str, torch.Tensor]:
         """
         Computes the demi-cycle loss.
@@ -397,16 +411,17 @@

         Args:
             latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+            raw_data (`RawDomainGroupsT`): raw input data

         Returns:
             `dict[str, torch.Tensor]`: a dict of metrics.
         """
         return demi_cycle_loss(
-            self.gw_mod, self.selection_mod, self.domain_mods, latent_domains
+            self.gw_mod, self.selection_mod, self.domain_mods, latent_domains, raw_data
         )

     def cycle_loss(
-        self, latent_domains: LatentsDomainGroupsT
+        self, latent_domains: LatentsDomainGroupsT, raw_data: RawDomainGroupsT
     ) -> dict[str, torch.Tensor]:
         """
         Computes the cycle loss.
@@ -415,16 +430,17 @@

         Args:
             latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+            raw_data (`RawDomainGroupsT`): raw input data

         Returns:
             `dict[str, torch.Tensor]`: a dict of metrics.
""" return cycle_loss( - self.gw_mod, self.selection_mod, self.domain_mods, latent_domains + self.gw_mod, self.selection_mod, self.domain_mods, latent_domains, raw_data ) def translation_loss( - self, latent_domains: LatentsDomainGroupsT + self, latent_domains: LatentsDomainGroupsT, raw_data: RawDomainGroupsT ) -> dict[str, torch.Tensor]: """ Computes the translation loss. @@ -433,12 +449,13 @@ def translation_loss( Args: latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups + raw_data (`RawDomainGroupsT`): raw input data Returns: `dict[str, torch.Tensor]`: a dict of metrics. """ return translation_loss( - self.gw_mod, self.selection_mod, self.domain_mods, latent_domains + self.gw_mod, self.selection_mod, self.domain_mods, latent_domains, raw_data ) def contrastive_loss( @@ -458,7 +475,10 @@ def contrastive_loss( return contrastive_loss(self.gw_mod, latent_domains, self.contrastive_fn) def step( - self, domain_latents: LatentsDomainGroupsT, mode: ModelModeT + self, + raw_data: RawDomainGroupsT, + domain_latents: LatentsDomainGroupsT, + mode: ModelModeT, ) -> LossOutput: """ Computes and returns the losses @@ -470,6 +490,7 @@ def step( - Contrastive metrics (see `GWLosses.contrastive_loss`) Args: + raw_data (`RawDomainGroupsT`): raw input data domain_latents (`LatentsDomainGroupsT`): All latent groups mode (`ModelModeT`): model mode Returns: @@ -477,9 +498,9 @@ def step( """ metrics: dict[str, torch.Tensor] = {} - metrics.update(self.demi_cycle_loss(domain_latents)) - metrics.update(self.cycle_loss(domain_latents)) - metrics.update(self.translation_loss(domain_latents)) + metrics.update(self.demi_cycle_loss(domain_latents, raw_data)) + metrics.update(self.cycle_loss(domain_latents, raw_data)) + metrics.update(self.translation_loss(domain_latents, raw_data)) metrics.update(self.contrastive_loss(domain_latents)) loss = torch.stack( @@ -516,6 +537,7 @@ def broadcast_loss( selection_mod: SelectionBase, domain_mods: Mapping[str, DomainModule], latent_domains: LatentsDomainGroupsT, + raw_data: RawDomainGroupsT, ) -> dict[str, torch.Tensor]: """ Computes broadcast loss including demi-cycle, cycle, and translation losses. @@ -525,6 +547,7 @@ def broadcast_loss( selection_mod (`shimmer.modules.selection.SelectionBase`): Selection mod to use domain_mods (`Mapping[str, DomainModule]`): the domain modules latent_domains: The latent domain representations. + raw_data (`RawDomainGroupsT`): raw input data Returns: A dictionary with the total loss and additional metrics. 
@@ -565,7 +588,9 @@ def broadcast_loss(
             if domain not in group_domains:  # if we don't have ground truth
                 continue
             ground_truth = latents[domain]
-            loss_output = domain_mods[domain].compute_loss(pred, ground_truth)
+            loss_output = domain_mods[domain].compute_loss(
+                pred, ground_truth, raw_data[group_domains][domain]
+            )
             loss_label = f"from_{selected_group_label}_to_{domain}"
             losses[loss_label + "_loss"] = loss_output.loss
             metrics.update(
@@ -602,7 +627,9 @@ def broadcast_loss(
         for domain in selected_latents:
             re_ground_truth = latents[domain]
             re_loss_output = domain_mods[domain].compute_loss(
-                re_decoded_latents[domain], re_ground_truth
+                re_decoded_latents[domain],
+                re_ground_truth,
+                raw_data[group_domains][domain],
             )
             loss_label = (
                 f"from_{selected_group_label}_"
@@ -710,19 +737,23 @@ def contrastive_loss(
         return contrastive_loss(self.gw_mod, latent_domains, self.contrastive_fn)

     def broadcast_loss(
-        self, latent_domains: LatentsDomainGroupsT
+        self, latent_domains: LatentsDomainGroupsT, raw_data: RawDomainGroupsT
     ) -> dict[str, torch.Tensor]:
         return broadcast_loss(
-            self.gw_mod, self.selection_mod, self.domain_mods, latent_domains
+            self.gw_mod, self.selection_mod, self.domain_mods, latent_domains, raw_data
         )

     def step(
-        self, domain_latents: LatentsDomainGroupsT, mode: ModelModeT
+        self,
+        raw_data: RawDomainGroupsT,
+        domain_latents: LatentsDomainGroupsT,
+        mode: ModelModeT,
     ) -> LossOutput:
         """
         Performs a step of loss computation.

         Args:
+            raw_data (`RawDomainGroupsT`): raw input data
             domain_latents: Latent representations for all domains.
             mode: The mode in which the model is currently operating.

@@ -733,7 +764,7 @@
         metrics: dict[str, torch.Tensor] = {}

         metrics.update(self.contrastive_loss(domain_latents))
-        metrics.update(self.broadcast_loss(domain_latents))
+        metrics.update(self.broadcast_loss(domain_latents, raw_data))

         loss = torch.stack(
             [
@@ -824,19 +855,23 @@ def contrastive_loss(
         return contrastive_loss(self.gw_mod, latent_domains, self.contrastive_fn)

     def broadcast_loss(
-        self, latent_domains: LatentsDomainGroupsT
+        self, latent_domains: LatentsDomainGroupsT, raw_data: RawDomainGroupsT
     ) -> dict[str, torch.Tensor]:
         return broadcast_loss(
-            self.gw_mod, self.selection_mod, self.domain_mods, latent_domains
+            self.gw_mod, self.selection_mod, self.domain_mods, latent_domains, raw_data
         )

     def step(
-        self, domain_latents: LatentsDomainGroupsT, mode: ModelModeT
+        self,
+        raw_data: RawDomainGroupsT,
+        domain_latents: LatentsDomainGroupsT,
+        mode: ModelModeT,
     ) -> LossOutput:
         """
         Performs a step of loss computation.

         Args:
+            raw_data (`RawDomainGroupsT`): raw input data
             domain_latents: Latent representations for all domains.
             mode: The mode in which the model is currently operating.

@@ -847,7 +882,7 @@
         metrics: dict[str, torch.Tensor] = {}

         metrics.update(self.contrastive_loss(domain_latents))
-        metrics.update(self.broadcast_loss(domain_latents))
+        metrics.update(self.broadcast_loss(domain_latents, raw_data))

         loss = torch.stack(
             [
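To illustrate how the raw data threaded through the loss functions can be consumed, here is a sketch of a domain module that overrides one of the specialised losses and compares its decoded prediction against the raw sample that `demi_cycle_loss` now passes in as `raw_data[domains][domain_name]`. This is not part of the patch: the class name, decoder shape, and the extra input-space term are assumptions, and it presumes the domain's raw samples are tensors.

```python
from typing import Any

import torch
import torch.nn.functional as F

from shimmer.modules.domain import DomainModule, LossOutput


class ImageDomain(DomainModule):  # hypothetical domain with a tensor raw space
    def __init__(self, latent_dim: int = 16) -> None:
        super().__init__(latent_dim)
        self.decoder = torch.nn.Linear(latent_dim, 3 * 32 * 32)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        # Maps a latent back to a flattened 3x32x32 image tensor.
        return self.decoder(z)

    def compute_loss(
        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
    ) -> LossOutput:
        # Default behaviour: latent-space reconstruction, as before the patch.
        return LossOutput(loss=F.mse_loss(pred, target))

    def compute_dcy_loss(
        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
    ) -> LossOutput:
        # Demi-cycle predictions can now also be scored in input space by
        # decoding the predicted latent and comparing it with the raw sample.
        latent_term = F.mse_loss(pred, target)
        input_term = F.mse_loss(self.decode(pred), raw_target)
        return LossOutput(
            loss=latent_term + input_term,
            metrics={"input_space_mse": input_term.detach()},
        )
```

The same indexing convention appears throughout the diff (`raw_data[domains_source][domain_name_source]` in `cycle_loss`, `raw_data[group_domains][domain]` in `broadcast_loss`), and the training step supplies it by forwarding the batch via `self.loss_mod.step(batch, domain_latents, mode)` as shown in the `generic_step` change above.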