From cd98e216593eb2edb407231c39a6fae9476df0ca Mon Sep 17 00:00:00 2001 From: bdvllrs Date: Fri, 28 Jun 2024 17:31:22 +0200 Subject: [PATCH] Add broadcasts and cycled broadcast to GW forward (#104) * Add batch_broadcasts utils * GlobalWorkspace forwards now return broadcasts * Add forward for gw_mod * fix: circular import * Do broadcast cycles only if inverse is non empty --- docs/make.py | 1 - shimmer/__init__.py | 22 ++- shimmer/modules/__init__.py | 22 ++- shimmer/modules/global_workspace.py | 273 ++++++++++++++++------------ shimmer/modules/gw_module.py | 157 +++++++++++++++- shimmer/modules/utils.py | 163 ----------------- 6 files changed, 338 insertions(+), 300 deletions(-) delete mode 100644 shimmer/modules/utils.py diff --git a/docs/make.py b/docs/make.py index 641c3ebb..100a34e2 100644 --- a/docs/make.py +++ b/docs/make.py @@ -14,7 +14,6 @@ "shimmer.modules.contrastive_loss", "shimmer.dataset", "shimmer.modules.vae", - "shimmer.modules.utils", "shimmer.utils", "shimmer.cli.ckpt_migration", ] diff --git a/shimmer/__init__.py b/shimmer/__init__.py index ac6afc2d..c3197ac9 100644 --- a/shimmer/__init__.py +++ b/shimmer/__init__.py @@ -15,8 +15,11 @@ GlobalWorkspace2Domains, GlobalWorkspaceBase, GlobalWorkspaceBayesian, - GWPredictions, SchedulerArgs, + batch_broadcasts, + batch_cycles, + batch_demi_cycles, + batch_translations, pretrained_global_workspace, ) from shimmer.modules.gw_module import ( @@ -26,6 +29,11 @@ GWModule, GWModuleBase, GWModuleBayesian, + GWModulePrediction, + broadcast, + broadcast_cycles, + cycle, + translation, ) from shimmer.modules.losses import ( BroadcastLossCoefs, @@ -39,13 +47,6 @@ SelectionBase, SingleDomainSelection, ) -from shimmer.modules.utils import ( - batch_cycles, - batch_demi_cycles, - batch_translations, - cycle, - translation, -) from shimmer.types import ( LatentsDomainGroupDT, LatentsDomainGroupsDT, @@ -72,7 +73,6 @@ "RawDomainGroupT", "ModelModeT", "SchedulerArgs", - "GWPredictions", "GlobalWorkspaceBase", 
"GlobalWorkspace2Domains", "GlobalWorkspaceBayesian", @@ -85,6 +85,7 @@ "GWModuleBase", "GWModule", "GWModuleBayesian", + "GWModulePrediction", "ContrastiveLossType", "contrastive_loss", "ContrastiveLoss", @@ -97,6 +98,9 @@ "batch_cycles", "batch_demi_cycles", "batch_translations", + "batch_broadcasts", + "broadcast", + "broadcast_cycles", "cycle", "translation", "MIGRATION_DIR", diff --git a/shimmer/modules/__init__.py b/shimmer/modules/__init__.py index 00b2fa91..7d3a18ec 100644 --- a/shimmer/modules/__init__.py +++ b/shimmer/modules/__init__.py @@ -10,8 +10,11 @@ GlobalWorkspace2Domains, GlobalWorkspaceBase, GlobalWorkspaceBayesian, - GWPredictions, SchedulerArgs, + batch_broadcasts, + batch_cycles, + batch_demi_cycles, + batch_translations, pretrained_global_workspace, ) from shimmer.modules.gw_module import ( @@ -21,6 +24,11 @@ GWModule, GWModuleBase, GWModuleBayesian, + GWModulePrediction, + broadcast, + broadcast_cycles, + cycle, + translation, ) from shimmer.modules.losses import ( BroadcastLossCoefs, @@ -34,13 +42,6 @@ SelectionBase, SingleDomainSelection, ) -from shimmer.modules.utils import ( - batch_cycles, - batch_demi_cycles, - batch_translations, - cycle, - translation, -) from shimmer.modules.vae import ( VAE, VAEDecoder, @@ -52,7 +53,6 @@ __all__ = [ "SchedulerArgs", - "GWPredictions", "GlobalWorkspaceBase", "GlobalWorkspace2Domains", "GlobalWorkspaceBayesian", @@ -65,6 +65,7 @@ "GWModuleBase", "GWModule", "GWModuleBayesian", + "GWModulePrediction", "ContrastiveLossType", "ContrastiveLossBayesianType", "contrastive_loss", @@ -84,6 +85,9 @@ "batch_cycles", "batch_demi_cycles", "batch_translations", + "batch_broadcasts", + "broadcast", + "broadcast_cycles", "cycle", "translation", "RandomSelection", diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py index 6ee666b2..8def76c5 100644 --- a/shimmer/modules/global_workspace.py +++ b/shimmer/modules/global_workspace.py @@ -14,6 +14,10 @@ GWModule, GWModuleBase, 
GWModuleBayesian, + GWModulePrediction, + broadcast_cycles, + cycle, + translation, ) from shimmer.modules.losses import ( BroadcastLossCoefs, @@ -29,7 +33,6 @@ SelectionBase, SingleDomainSelection, ) -from shimmer.modules.utils import batch_cycles, batch_demi_cycles, batch_translations from shimmer.types import ( LatentsDomainGroupsDT, LatentsDomainGroupsT, @@ -54,18 +57,155 @@ class SchedulerArgs(TypedDict, total=False): class GWPredictionsBase(TypedDict): """TypedDict of the output given when calling `GlobalWorkspaceBase.predict`""" - states: dict[str, torch.Tensor] + states: dict[frozenset[str], torch.Tensor] """ GW state representation from domain groups with only one domain. The key represent the domain's name. """ + broadcasts: dict[frozenset[str], dict[str, torch.Tensor]] + """ + broadcasts predictions of the model for each domain. It contains demi-cycles, + translations, and fused. + """ + + cycles: dict[frozenset[str], dict[str, torch.Tensor]] + """ + Cycle predictions of the model from one domain through another one. + """ + _T_gw_mod = TypeVar("_T_gw_mod", bound=GWModuleBase) _T_selection_mod = TypeVar("_T_selection_mod", bound=SelectionBase) _T_loss_mod = TypeVar("_T_loss_mod", bound=GWLossesBase) +def batch_demi_cycles( + gw_mod: GWModuleBase, + selection_mod: SelectionBase, + latent_domains: LatentsDomainGroupsT, +) -> dict[str, torch.Tensor]: + """ + Computes demi-cycles of a batch of groups of domains. + + Args: + gw_mod (`GWModuleBase`): the GWModuleBase + selection_mod (`SelectionBase`): selection module + latent_domains (`LatentsT`): the batch of groups of domains + + Returns: + `dict[str, torch.Tensor]`: demi-cycles predictions for each domain. 
+ """ + predictions: dict[str, torch.Tensor] = {} + for domains, latents in latent_domains.items(): + if len(domains) > 1: + continue + domain_name = list(domains)[0] + z = translation(gw_mod, selection_mod, latents, to=domain_name) + predictions[domain_name] = z + return predictions + + +def batch_cycles( + gw_mod: GWModuleBase, + selection_mod: SelectionBase, + latent_domains: LatentsDomainGroupsT, + through_domains: Iterable[str], +) -> dict[tuple[str, str], torch.Tensor]: + """ + Computes cycles of a batch of groups of domains. + + Args: + gw_mod (`GWModuleBase`): GWModule to use for the cycle + selection_mod (`SelectionBase`): selection module + latent_domains (`LatentsT`): the batch of groups of domains + through_domains (`Iterable[str]`): iterable of domain names to cycle through. + Each domain will be done separately. + + Returns: + `dict[tuple[str, str], torch.Tensor]`: cycles predictions for each + couple of (start domain, intermediary domain). + """ + predictions: dict[tuple[str, str], torch.Tensor] = {} + for domains_source, latents_source in latent_domains.items(): + if len(domains_source) > 1: + continue + domain_name_source = next(iter(domains_source)) + for domain_name_through in through_domains: + if domain_name_source == domain_name_through: + continue + z = cycle( + gw_mod, selection_mod, latents_source, through=domain_name_through + ) + domains = (domain_name_source, domain_name_through) + predictions[domains] = z[domain_name_source] + return predictions + + +def batch_translations( + gw_mod: GWModuleBase, + selection_mod: SelectionBase, + latent_domains: LatentsDomainGroupsT, +) -> dict[tuple[str, str], torch.Tensor]: + """ + Computes translations of a batch of groups of domains. 
+ + Args: + gw_mod (`GWModuleBase`): GWModule to do the translation + selection_mod (`SelectionBase`): selection module + latent_domains (`LatentsT`): the batch of groups of domains + + Returns: + `dict[tuple[str, str], torch.Tensor]`: translation predictions for each + couple of (start domain, target domain). + """ + predictions: dict[tuple[str, str], torch.Tensor] = {} + for domains, latents in latent_domains.items(): + if len(domains) < 2: + continue + for domain_name_source in domains: + for domain_name_target in domains: + if domain_name_source == domain_name_target: + continue + prediction = translation( + gw_mod, + selection_mod, + {domain_name_source: latents[domain_name_source]}, + to=domain_name_target, + ) + predictions[(domain_name_source, domain_name_target)] = prediction + return predictions + + +def batch_broadcasts( + gw_mod: GWModuleBase, + selection_mod: SelectionBase, + latent_domains: LatentsDomainGroupsT, +) -> tuple[ + dict[frozenset[str], dict[str, torch.Tensor]], + dict[frozenset[str], dict[str, torch.Tensor]], +]: + """ + Computes all possible broadcast of a batch for each group of domains. 
+ + Args: + gw_mod (`GWModuleBase`): the GWModuleBase + selection_mod (`SelectionBase`): selection module + latent_domains (`LatentsT`): the batch of groups of domains + + Returns: + `tuple[dict[frozenset[str], dict[str, torch.Tensor]], + dict[frozenset[str], dict[str, torch.Tensor]], ]`: broadcast predictions + for each domain.""" + predictions: dict[frozenset[str], dict[str, torch.Tensor]] = {} + cycles: dict[frozenset[str], dict[str, torch.Tensor]] = {} + for domains, latents in latent_domains.items(): + pred_broadcast, pred_cycles = broadcast_cycles(gw_mod, selection_mod, latents) + predictions[domains] = pred_broadcast + cycles[domains] = pred_cycles + return predictions, cycles + + class GlobalWorkspaceBase( Generic[_T_gw_mod, _T_selection_mod, _T_loss_mod], LightningModule ): @@ -219,12 +359,22 @@ def forward( # type: ignore Returns: `GWPredictionsBase`: the predictions on the batch. """ + states: dict[frozenset[str], torch.Tensor] = {} + broadcasts: dict[frozenset[str], dict[str, torch.Tensor]] = {} + cycles: dict[frozenset[str], dict[str, torch.Tensor]] = {} + for domain_group, latent_group in latent_domains.items(): + predictions = cast( + GWModulePrediction, self.gw_mod(latent_group, self.selection_mod) + ) + states[domain_group] = predictions["states"] + broadcasts[domain_group] = predictions["broadcasts"] + cycles[domain_group] = predictions["cycles"] - return GWPredictionsBase(states=self.batch_gw_states(latent_domains)) + return GWPredictionsBase(states=states, broadcasts=broadcasts, cycles=cycles) def batch_gw_states( self, latent_domains: LatentsDomainGroupsT - ) -> dict[str, torch.Tensor]: + ) -> dict[frozenset[str], torch.Tensor]: """ Comptues GW states of a batch of groups of domains. @@ -234,15 +384,10 @@ def batch_gw_states( Returns: `dict[str, torch.Tensor]`: states for each domain. 
""" - predictions: dict[str, torch.Tensor] = {} + predictions: dict[frozenset[str], torch.Tensor] = {} for domains, latents in latent_domains.items(): - if len(domains) > 1: - continue - domain_name = list(domains)[0] - z = self.gw_mod.encode_and_fuse( - latents, selection_module=self.selection_mod - ) - predictions[domain_name] = z + z = self.gw_mod.encode_and_fuse(latents, self.selection_mod) + predictions[domains] = z return predictions def encode_domain(self, domain: Any, name: str) -> torch.Tensor: @@ -440,31 +585,6 @@ def freeze_domain_modules( return cast(dict[str, DomainModule], ModuleDict(domain_mods)) -class GWPredictions(GWPredictionsBase): - """TypedDict of the output given when calling `GlobalWorkspaceBase.predict`""" - - demi_cycles: dict[str, torch.Tensor] - """ - Demi-cycle predictions of the model for each domain. Only computed on domain - groups with only one domain. - """ - - cycles: dict[tuple[str, str], torch.Tensor] - """ - Cycle predictions of the model from one domain through another one. - Only computed on domain groups with more than one domain. - The keys are tuple with start domain and intermediary domain. - """ - - translations: dict[tuple[str, str], torch.Tensor] - """ - Translation predictions of the model from one domain through another one. - - Only computed on domain groups with more than one domain. - The keys are tuples with start domain and target domain. - """ - - class GlobalWorkspace2Domains( GlobalWorkspaceBase[GWModule, SingleDomainSelection, GWLosses2Domains] ): @@ -533,32 +653,6 @@ def __init__( scheduler_args, ) - def forward( # type: ignore - self, - latent_domains: LatentsDomainGroupsT, - ) -> GWPredictions: - """ - Computes demi-cycles, cycles, and translations. - - Args: - latent_domains (`LatentsT`): Groups of domains for the computation. - - Returns: - `GWPredictions`: the predictions on the batch. 
- """ - return GWPredictions( - demi_cycles=batch_demi_cycles( - self.gw_mod, self.selection_mod, latent_domains - ), - cycles=batch_cycles( - self.gw_mod, self.selection_mod, latent_domains, self.domain_mods.keys() - ), - translations=batch_translations( - self.gw_mod, self.selection_mod, latent_domains - ), - **super().forward(latent_domains), - ) - class GlobalWorkspace(GlobalWorkspaceBase[GWModule, RandomSelection, GWLosses]): """The 2-domain fusion (with broadcast loss) flavor of GlobalWorkspaceBase. @@ -629,33 +723,6 @@ def __init__( scheduler_args, ) - def forward( # type: ignore - self, - latent_domains: LatentsDomainGroupsT, - ) -> GWPredictions: - """ - Computes demi-cycles, cycles, and translations. - - Args: - latent_domains (`LatentsT`): Groups of domains for the computation. - - Returns: - `GWPredictions`: the predictions on the batch. - """ - return GWPredictions( - demi_cycles=batch_demi_cycles( - self.gw_mod, self.selection_mod, latent_domains - ), - cycles=batch_cycles( - self.gw_mod, self.selection_mod, latent_domains, self.domain_mods.keys() - ), - translations=batch_translations( - self.gw_mod, self.selection_mod, latent_domains - ), - # TODO: add other combinations - **super().forward(latent_domains), - ) - class GlobalWorkspaceBayesian( GlobalWorkspaceBase[GWModuleBayesian, FixedSharedSelection, GWLossesBayesian] @@ -751,32 +818,6 @@ def __init__( scheduler_args, ) - def forward( # type: ignore - self, - latent_domains: LatentsDomainGroupsT, - ) -> GWPredictions: - """ - Computes demi-cycles, cycles, and translations. - - Args: - latent_domains (`LatentsT`): Groups of domains for the computation. - - Returns: - `GWPredictions`: the predictions on the batch. 
- """ - return GWPredictions( - demi_cycles=batch_demi_cycles( - self.gw_mod, self.selection_mod, latent_domains - ), - cycles=batch_cycles( - self.gw_mod, self.selection_mod, latent_domains, self.domain_mods.keys() - ), - translations=batch_translations( - self.gw_mod, self.selection_mod, latent_domains - ), - **super().forward(latent_domains), - ) - def pretrained_global_workspace( checkpoint_path: str | Path, diff --git a/shimmer/modules/gw_module.py b/shimmer/modules/gw_module.py index 3dd1b600..489b8125 100644 --- a/shimmer/modules/gw_module.py +++ b/shimmer/modules/gw_module.py @@ -1,13 +1,122 @@ from abc import ABC, abstractmethod from collections.abc import Iterable, Mapping -from typing import cast +from typing import TypedDict, cast import torch from torch import nn from shimmer.modules.domain import DomainModule from shimmer.modules.selection import SelectionBase -from shimmer.types import LatentsDomainGroupDT, LatentsDomainGroupT +from shimmer.types import ( + LatentsDomainGroupDT, + LatentsDomainGroupT, +) + + +def translation( + gw_module: "GWModuleBase", + selection_mod: SelectionBase, + x: LatentsDomainGroupT, + to: str, +) -> torch.Tensor: + """ + Translate from multiple domains to one domain. + + Args: + gw_module (`"GWModuleBase"`): GWModule to perform the translation over + selection_mod (`SelectionBase`): selection module + x (`LatentsDomainGroupT`): the group of latent representations + to (`str`): the domain name to encode to + + Returns: + `torch.Tensor`: the translated unimodal representation + of the provided domain. + """ + return gw_module.decode(gw_module.encode_and_fuse(x, selection_mod), domains={to})[ + to + ] + + +def cycle( + gw_module: "GWModuleBase", + selection_mod: SelectionBase, + x: LatentsDomainGroupT, + through: str, +) -> LatentsDomainGroupDT: + """ + Do a full cycle from a group of representation through one domain. 
+ + [Original domains] -> [GW] -> [through] -> [GW] -> [Original domains] + + Args: + gw_module (`"GWModuleBase"`): GWModule to perform the translation over + selection_mod (`SelectionBase`): selection module + x (`LatentsDomainGroupT`): group of unimodal latent representation + through (`str`): domain name to cycle through + Returns: + `LatentsDomainGroupDT`: group of unimodal latent representation after + cycling. + """ + return { + domain: translation( + gw_module, + selection_mod, + {through: translation(gw_module, selection_mod, x, through)}, + domain, + ) + for domain in x + } + + +def broadcast( + gw_mod: "GWModuleBase", + selection_mod: SelectionBase, + latents: LatentsDomainGroupT, +) -> dict[str, torch.Tensor]: + """ + Broadcast a group to every domain of the GW module + + Args: + gw_mod (`"GWModuleBase"`): GWModule to perform the broadcast over + selection_mod (`SelectionBase`): selection module + latents (`LatentsDomainGroupT`): the group of latent representations + + Returns: + `dict[str, torch.Tensor]`: the broadcast prediction for each domain + """ + predictions: dict[str, torch.Tensor] = {} + state = gw_mod.encode_and_fuse(latents, selection_mod) + all_domains = list(gw_mod.domain_mods.keys()) + for domain in all_domains: + predictions[domain] = gw_mod.decode(state, domains=[domain])[domain] + return predictions + + +def broadcast_cycles( + gw_mod: "GWModuleBase", + selection_mod: SelectionBase, + latents: LatentsDomainGroupT, +) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: + """ + Broadcast a group, then broadcast again from the predicted missing domains + + Args: + gw_mod (`"GWModuleBase"`): GWModule to perform the broadcast over + selection_mod (`SelectionBase`): selection module + latents (`LatentsDomainGroupT`): the group of latent representations + + Returns: + `tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]`: broadcasts, cycles + """ + all_domains = list(latents.keys()) + predictions = broadcast(gw_mod, selection_mod, latents) + inverse = { + name: latent for name, latent in predictions.items() if name not in all_domains + } + cycles: dict[str, torch.Tensor] 
= {} + if len(inverse): + cycles = broadcast(gw_mod, selection_mod, inverse) + return predictions, cycles def get_n_layers(n_layers: int, hidden_dim: int) -> list[nn.Module]: @@ -107,6 +216,27 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: return torch.tanh(super().forward(input)) +class GWModulePrediction(TypedDict): + """TypedDict of the output given when calling `GWModuleBase.forward`""" + + states: torch.Tensor + """ + GW state representation of the group of domains, obtained by encoding + and fusing its latent representations. + """ + + broadcasts: dict[str, torch.Tensor] + """ + broadcasts predictions of the model for each domain. It contains demi-cycles, + translations, and fused. + """ + + cycles: dict[str, torch.Tensor] + """ + Cycle predictions of the model from one domain through another one. + """ + + class GWModuleBase(nn.Module, ABC): """ Base class for GWModule. @@ -205,6 +335,29 @@ def decode( """ ... + def forward( + self, + latent_domains: LatentsDomainGroupT, + selection_module: SelectionBase, + ) -> GWModulePrediction: + """ + Computes the GW state, broadcasts, and cycles of a group of domains. + + Args: + latent_domains (`LatentsDomainGroupT`): Group of domains + selection_module (`SelectionBase`): selection module + + Returns: + `GWModulePrediction`: the predictions on the group. + """ + broadcasts, cycles = broadcast_cycles(self, selection_module, latent_domains) + + return GWModulePrediction( + states=self.encode_and_fuse(latent_domains, selection_module), + broadcasts=broadcasts, + cycles=cycles, + ) + class GWModule(GWModuleBase): """GW nn.Module. 
Implements `GWModuleBase`.""" diff --git a/shimmer/modules/utils.py b/shimmer/modules/utils.py deleted file mode 100644 index 8bdb5727..00000000 --- a/shimmer/modules/utils.py +++ /dev/null @@ -1,163 +0,0 @@ -from collections.abc import Iterable - -import torch - -from shimmer.modules.gw_module import GWModuleBase -from shimmer.modules.selection import SelectionBase -from shimmer.types import ( - LatentsDomainGroupDT, - LatentsDomainGroupsT, - LatentsDomainGroupT, -) - - -def translation( - gw_module: GWModuleBase, - selection_mod: SelectionBase, - x: LatentsDomainGroupT, - to: str, -) -> torch.Tensor: - """ - Translate from multiple domains to one domain. - - Args: - gw_module (`GWModuleBase`): GWModule to perform the translation over - selection_mod (`SelectionBase`): selection module - x (`LatentsDomainGroupT`): the group of latent representations - to (`str`): the domain name to encode to - - Returns: - `torch.Tensor`: the translated unimodal representation - of the provided domain. - """ - return gw_module.decode(gw_module.encode_and_fuse(x, selection_mod), domains={to})[ - to - ] - - -def cycle( - gw_module: GWModuleBase, - selection_mod: SelectionBase, - x: LatentsDomainGroupT, - through: str, -) -> LatentsDomainGroupDT: - """ - Do a full cycle from a group of representation through one domain. - - [Original domains] -> [GW] -> [through] -> [GW] -> [Original domains] - - Args: - gw_module (`GWModuleBase`): GWModule to perform the translation over - selection_mod (`SelectionBase`): selection module - x (`LatentsDomainGroupT`): group of unimodal latent representation - through (`str`): domain name to cycle through - Returns: - `LatentsDomainGroupDT`: group of unimodal latent representation after - cycling. 
- """ - return { - domain: translation( - gw_module, - selection_mod, - {through: translation(gw_module, selection_mod, x, through)}, - domain, - ) - for domain in x - } - - -def batch_demi_cycles( - gw_mod: GWModuleBase, - selection_mod: SelectionBase, - latent_domains: LatentsDomainGroupsT, -) -> dict[str, torch.Tensor]: - """ - Computes demi-cycles of a batch of groups of domains. - - Args: - gw_mod (`GWModuleBase`): the GWModuleBase - selection_mod (`SelectionBase`): selection module - latent_domains (`LatentsT`): the batch of groups of domains - - Returns: - `dict[str, torch.Tensor]`: demi-cycles predictions for each domain. - """ - predictions: dict[str, torch.Tensor] = {} - for domains, latents in latent_domains.items(): - if len(domains) > 1: - continue - domain_name = list(domains)[0] - z = translation(gw_mod, selection_mod, latents, to=domain_name) - predictions[domain_name] = z - return predictions - - -def batch_cycles( - gw_mod: GWModuleBase, - selection_mod: SelectionBase, - latent_domains: LatentsDomainGroupsT, - through_domains: Iterable[str], -) -> dict[tuple[str, str], torch.Tensor]: - """ - Computes cycles of a batch of groups of domains. - - Args: - gw_mod (`GWModuleBase`): GWModule to use for the cycle - selection_mod (`SelectionBase`): selection module - latent_domains (`LatentsT`): the batch of groups of domains - out_domains (`Iterable[str]`): iterable of domain names to do the cycle through. - Each domain will be done separetely. - - Returns: - `dict[tuple[str, str], torch.Tensor]`: cycles predictions for each - couple of (start domain, intermediary domain). 
- """ - predictions: dict[tuple[str, str], torch.Tensor] = {} - for domains_source, latents_source in latent_domains.items(): - if len(domains_source) > 1: - continue - domain_name_source = next(iter(domains_source)) - for domain_name_through in through_domains: - if domain_name_source == domain_name_through: - continue - z = cycle( - gw_mod, selection_mod, latents_source, through=domain_name_through - ) - domains = (domain_name_source, domain_name_through) - predictions[domains] = z[domain_name_source] - return predictions - - -def batch_translations( - gw_mod: GWModuleBase, - selection_mod: SelectionBase, - latent_domains: LatentsDomainGroupsT, -) -> dict[tuple[str, str], torch.Tensor]: - """ - Computes translations of a batch of groups of domains. - - Args: - gw_mod (`GWModuleBase`): GWModule to do the translation - selection_mod (`SelectionBase`): selection module - latent_domains (`LatentsT`): the batch of groups of domains - - Returns: - `dict[tuple[str, str], torch.Tensor]`: translation predictions for each - couple of (start domain, target domain). - """ - predictions: dict[tuple[str, str], torch.Tensor] = {} - for domains, latents in latent_domains.items(): - if len(domains) < 2: - continue - for domain_name_source in domains: - for domain_name_target in domains: - if domain_name_source == domain_name_target: - continue - prediction = translation( - gw_mod, - selection_mod, - {domain_name_source: latents[domain_name_source]}, - to=domain_name_target, - ) - predictions[(domain_name_source, domain_name_target)] = prediction - return predictions