diff --git a/.gitignore b/.gitignore index 10622465..5740cb08 100644 --- a/.gitignore +++ b/.gitignore @@ -152,6 +152,9 @@ dmypy.json # Cython debug symbols cython_debug/ +# vim +*.swp + # PyCharm # JetBrains specific template is maintained in a separate JetBrains.gitignore that can # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore diff --git a/shimmer/__init__.py b/shimmer/__init__.py index cd67bb43..be590d54 100644 --- a/shimmer/__init__.py +++ b/shimmer/__init__.py @@ -22,6 +22,7 @@ GWModuleWithUncertainty, ) from shimmer.modules.losses import ( + BroadcastLossCoefs, GWLosses, GWLossesBase, GWLossesWithUncertainty, @@ -82,6 +83,7 @@ "contrastive_loss", "ContrastiveLoss", "LossCoefs", + "BroadcastLossCoefs", "GWLossesBase", "GWLosses", "GWLossesWithUncertainty", diff --git a/shimmer/modules/__init__.py b/shimmer/modules/__init__.py index 472c3e06..b7c6c435 100644 --- a/shimmer/modules/__init__.py +++ b/shimmer/modules/__init__.py @@ -25,6 +25,7 @@ GWModuleWithUncertainty, ) from shimmer.modules.losses import ( + BroadcastLossCoefs, GWLosses, GWLossesBase, GWLossesWithUncertainty, @@ -73,6 +74,7 @@ "contrastive_loss_with_uncertainty", "ContrastiveLossWithUncertainty", "LossCoefs", + "BroadcastLossCoefs", "GWLossesBase", "GWLosses", "GWLossesWithUncertainty", diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py index a7ce2709..45174511 100644 --- a/shimmer/modules/global_workspace.py +++ b/shimmer/modules/global_workspace.py @@ -13,17 +13,21 @@ from shimmer.modules.gw_module import ( GWModule, GWModuleBase, - GWModuleFusion, GWModuleWithUncertainty, ) from shimmer.modules.losses import ( + BroadcastLossCoefs, GWLosses, GWLossesBase, GWLossesFusion, GWLossesWithUncertainty, LossCoefs, ) -from shimmer.modules.selection import SelectionBase, SingleDomainSelection +from shimmer.modules.selection import ( + RandomSelection, + SelectionBase, + SingleDomainSelection, +) from shimmer.modules.utils import batch_cycles, batch_demi_cycles, batch_translations from shimmer.types import ( LatentsDomainGroupsDT, @@ -651,6 +655,8 @@ def __init__( gw_encoders: Mapping[str, Module], gw_decoders: Mapping[str, Module], workspace_dim: int, + loss_coefs: BroadcastLossCoefs, + selection_temperature: float = 0.2, optim_lr: float = 1e-3, optim_weight_decay: float = 0.0, scheduler_args: SchedulerArgs | None = None, @@ -671,6 +677,9 @@ def __init__( name to a `torch.nn.Module` class which role is to decode a GW representation into a unimodal latent representations. workspace_dim (`int`): dimension of the GW. + loss_coefs (`BroadcastLossCoefs`): loss coefs for the losses. + selection_temperature (`float`): temperature value for the RandomSelection + module. optim_lr (`float`): learning rate optim_weight_decay (`float`): weight decay scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments @@ -681,16 +690,17 @@ def __init__( contrastive losses. """ domain_mods = freeze_domain_modules(domain_mods) - gw_mod = GWModuleFusion(domain_mods, workspace_dim, gw_encoders, gw_decoders) + gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders) if contrastive_loss is None: contrastive_loss = ContrastiveLoss( torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale ) - # TODO: use the correction selection module - selection_mod = SingleDomainSelection() - loss_mod = GWLossesFusion(gw_mod, selection_mod, domain_mods, contrastive_loss) + selection_mod = RandomSelection(selection_temperature) + loss_mod = GWLossesFusion( + gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_loss + ) super().__init__( gw_mod, diff --git a/shimmer/modules/gw_module.py b/shimmer/modules/gw_module.py index 2e043f91..96b02d5c 100644 --- a/shimmer/modules/gw_module.py +++ b/shimmer/modules/gw_module.py @@ -97,7 +97,7 @@ def __init__( super().__init__(in_dim, hidden_dim, out_dim, n_layers) def forward(self, input: torch.Tensor) -> torch.Tensor: - return torch.tanh(super().forward(input)) + return super().forward(input) class GWEncoderLinear(nn.Linear): @@ -252,14 +252,16 @@ def fuse( Returns: `torch.Tensor`: The merged representation. """ - return torch.sum( - torch.stack( - [ - selection_scores[domain].unsqueeze(1) * x[domain] - for domain in selection_scores - ] - ), - dim=0, + return torch.tanh( + torch.sum( + torch.stack( + [ + selection_scores[domain].unsqueeze(1) * x[domain] + for domain in selection_scores + ] + ), + dim=0, + ) ) def encode(self, x: LatentsDomainGroupT) -> LatentsDomainGroupT: @@ -364,7 +366,9 @@ def _fuse_and_scores( coef = final_scores.sum(dim=0) final_scores = final_scores / coef - return torch.sum(final_scores * torch.stack(domains), dim=0), final_scores + return torch.tanh( + torch.sum(final_scores * torch.stack(domains), dim=0) + ), final_scores def fuse( self, @@ -406,96 +410,3 @@ def fuse( `torch.Tensor`: The merged representation. """ return self._fuse_and_scores(x, selection_scores)[0] - - -class GWModuleFusion(GWModuleBase): - """ - GWModule used for fusion. - """ - - def __init__( - self, - domain_modules: Mapping[str, DomainModule], - workspace_dim: int, - gw_encoders: Mapping[str, nn.Module], - gw_decoders: Mapping[str, nn.Module], - ) -> None: - """ - Initializes the GWModule Fusion. - - Args: - domain_modules (`Mapping[str, DomainModule]`): the domain modules. - workspace_dim (`int`): dimension of the GW. - gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain - name to a an torch.nn.Module class that encodes a - unimodal latent representations into a GW representation (pre fusion). - gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain - name to a an torch.nn.Module class that decodes a - GW representation to a unimodal latent representation. - """ - super().__init__(domain_modules, workspace_dim) - - self.gw_encoders = nn.ModuleDict(gw_encoders) - """The module's encoders""" - - self.gw_decoders = nn.ModuleDict(gw_decoders) - """The module's decoders""" - - def fuse( - self, - x: LatentsDomainGroupT, - selection_scores: Mapping[str, torch.Tensor], - ) -> torch.Tensor: - """ - Merge function used to combine domains. - - Args: - x (`LatentsDomainGroupT`): the group of latent representation. - selection_score (`Mapping[str, torch.Tensor]`): attention scores to - use to encode the reprensetation. - Returns: - `torch.Tensor`: The merged representation. - """ - return torch.sum( - torch.stack( - [ - selection_scores[domain].unsqueeze(1) * x[domain] - for domain in selection_scores - ] - ), - dim=0, - ) - - def encode(self, x: LatentsDomainGroupT) -> LatentsDomainGroupT: - """ - Encode the unimodal latent representation `x` into the pre-fusion GW - representations. - - Args: - x (`LatentsDomainGroupT`): the group of latent representation. - - Returns: - `torch.Tensor`: encoded and fused GW representation. - """ - return { - domain_name: self.gw_encoders[domain_name](domain) - for domain_name, domain in x.items() - } - - def decode( - self, z: torch.Tensor, domains: Iterable[str] | None = None - ) -> LatentsDomainGroupDT: - """ - Decodes a GW representation to multiple domains. - - Args: - z (`torch.Tensor`): the GW representation - domains (`Iterable[str] | None`): the domains to decode to. Defaults to - use keys in `gw_interfaces` (all domains). - Returns: - `LatentsDomainGroupDT`: decoded unimodal representation - """ - return { - domain: self.gw_decoders[domain](z) - for domain in domains or self.gw_decoders.keys() - } diff --git a/shimmer/modules/losses.py b/shimmer/modules/losses.py index 3487f19c..b860a849 100644 --- a/shimmer/modules/losses.py +++ b/shimmer/modules/losses.py @@ -1,20 +1,13 @@ from abc import ABC, abstractmethod -from collections.abc import Mapping -from typing import TypedDict +from collections.abc import Generator, Mapping +from itertools import product +from typing import Any, TypedDict import torch -import torch.nn.functional as F -from shimmer.modules.contrastive_loss import ( - ContrastiveLossType, -) +from shimmer.modules.contrastive_loss import ContrastiveLossType from shimmer.modules.domain import DomainModule, LossOutput -from shimmer.modules.gw_module import ( - GWModule, - GWModuleBase, - GWModuleFusion, - GWModuleWithUncertainty, -) +from shimmer.modules.gw_module import GWModule, GWModuleBase, GWModuleWithUncertainty from shimmer.modules.selection import SelectionBase from shimmer.types import LatentsDomainGroupsT, ModelModeT @@ -642,194 +635,233 @@ def step( return LossOutput(loss, metrics) -def sample_scaling_factors( - binary_scaling_prob: float, - batch_size: int, - temperature: float, - device: torch.device, -): +def generate_partitions(n: int) -> Generator[tuple[int, ...], None, None]: """ - Args: - binary_scaling_prob (`float`): Should be between 0 and 1. - batch_size (`int`): - temperature (`float`): Should be greater than 0. - device (`torch.device`): - """ - assert 0 <= binary_scaling_prob <= 1 - - # TODO: make selection deterministic - binary_mask = torch.rand(batch_size) < binary_scaling_prob + Generates all possible partitions of zeros and ones for `n` elements, + excluding the all-zeros partition. - binary_factors = torch.randint(0, 2, (batch_size,)).float() - binary_softmax = torch.stack([binary_factors, 1 - binary_factors], dim=1) + Args: + n (`int`): The number of modalities to generate partitions for. - uniform_samples = torch.rand(batch_size) - uniform_for_softmax = torch.stack([uniform_samples, 1 - uniform_samples], dim=1) + Yields: + `tuple[int, ...]`: A partition of zeros and ones, excluding the + all-zeros partition. + """ + for perm in product([0, 1], repeat=n): + if any(perm): + yield perm - uniform_softmax = F.softmax(uniform_for_softmax * temperature, dim=1) - scaling_factors = torch.where( - binary_mask.unsqueeze(-1), binary_softmax, uniform_softmax - ).to(device) +class BroadcastLossCoefs(TypedDict, total=False): + """ + Dict of loss coefficients used in the GWLossesFusion. - binary_indices = torch.where(binary_mask)[0] - softmax_indices = torch.where(~binary_mask)[0] + If one is not provided, the coefficient is assumed to be 0 and will not be logged. + If the loss is excplicitely set to 0, it will be logged, but not take part in + the total loss. + """ - binary_scaling_factors = scaling_factors[binary_indices] - softmax_scaling_factors = scaling_factors[softmax_indices] + contrastives: float + """Contrastive loss coefficient.""" - return { - "binary": ( - binary_scaling_factors[:, 0], - binary_scaling_factors[:, 1], - binary_indices, - ), - "softmax": ( - softmax_scaling_factors[:, 0], - softmax_scaling_factors[:, 1], - softmax_indices, - ), - } + broadcast: float + """Broadcast loss coefficient.""" class GWLossesFusion(GWLossesBase): + """ + Implementation of `GWLossesBase` for fusion-based models. + """ + def __init__( self, - gw_mod: GWModuleFusion, + gw_mod: GWModule, selection_mod: SelectionBase, domain_mods: dict[str, DomainModule], + loss_coefs: BroadcastLossCoefs, contrastive_fn: ContrastiveLossType, ): + """ + Initializes the loss computation module for a Global Workspace Fusion model. + + Args: + gw_mod: The GWModule for the global workspace. + selection_mod: The selection mechanism for the model. + domain_mods: A mapping of domain names to their respective DomainModule. + loss_coefs (`BroadcastLossCoefs`): coefs for the losses + contrastive_fn: The function used for computing contrastive loss. + """ super().__init__() self.gw_mod = gw_mod self.selection_mod = selection_mod self.domain_mods = domain_mods + self.loss_coefs = loss_coefs self.contrastive_fn = contrastive_fn - def demi_cycle_loss( + def contrastive_loss( self, latent_domains: LatentsDomainGroupsT ) -> dict[str, torch.Tensor]: - return demi_cycle_loss( - self.gw_mod, self.selection_mod, self.domain_mods, latent_domains - ) + """ + Computes the contrastive loss for the given latent domains. - def cycle_loss( - self, latent_domains: LatentsDomainGroupsT - ) -> dict[str, torch.Tensor]: - return cycle_loss( - self.gw_mod, self.selection_mod, self.domain_mods, latent_domains - ) + Args: + latent_domains: The latent domain representations. - def translation_loss( - self, latent_domains: LatentsDomainGroupsT - ) -> dict[str, torch.Tensor]: - return translation_loss( - self.gw_mod, self.selection_mod, self.domain_mods, latent_domains - ) + Returns: + A dictionary of contrastive loss metrics. + """ - def contrastive_loss( - self, latent_domains: LatentsDomainGroupsT - ) -> dict[str, torch.Tensor]: return contrastive_loss(self.gw_mod, latent_domains, self.contrastive_fn) def broadcast_loss( self, latent_domains: LatentsDomainGroupsT, mode: ModelModeT ) -> dict[str, torch.Tensor]: - losses: dict[str, torch.Tensor] = {} - metrics: dict[str, torch.Tensor] = {} + """ + Computes broadcast loss including demi-cycle, cycle, and translation losses. - for latents in latent_domains.values(): - if len(latents) < 2: - continue - batch_size = latents[next(iter(latents))].size(0) - device = latents[next(iter(latents))].device - - # TODO: don't hardcode the proportions - # (first param of the sample_scaling_factors function) - - if mode == "val": - scaling_factors = sample_scaling_factors(0.5, batch_size, 5.0, device) - else: - scaling_factors = sample_scaling_factors(0.9, batch_size, 5.0, device) - - for scale_type, ( - scaling_factor_1, - scaling_factor_2, - indices, - ) in scaling_factors.items(): - scaled_latents = {} - - for i, (domain_name, latent) in enumerate(latents.items()): - scaling_factor = scaling_factor_1 if i == 0 else scaling_factor_2 - scaled_latents_subset = latent[indices] * scaling_factor.unsqueeze( - -1 - ) - scaled_latents_subset = scaled_latents_subset.to(latent) + Args: + latent_domains: The latent domain representations. + mode: The mode of the model (e.g., 'train', 'eval'). - scaled_latents[domain_name] = scaled_latents_subset + Returns: + A dictionary with the total loss and additional metrics. + """ + losses: dict[str, torch.Tensor] = {} + metrics: dict[str, Any] = {} + + demi_cycle_losses: list[str] = [] + cycle_losses: list[str] = [] + translation_losses: list[str] = [] + + for group_domains, latents in latent_domains.items(): + encoded_latents = self.gw_mod.encode(latents) + partitions = generate_partitions(len(group_domains)) + domain_names = list(latents) + + for partition in partitions: + selected_latents = { + domain: latents[domain] + for domain, present in zip(domain_names, partition, strict=True) + if present + } + selected_encoded_latents = { + domain: encoded_latents[domain] for domain in selected_latents + } + selected_group_label = "{" + ", ".join(sorted(selected_latents)) + "}" - encoded_latents_for_subset = self.gw_mod.encode_and_fuse( - scaled_latents, self.selection_mod + selection_scores = self.selection_mod( + selected_latents, selected_encoded_latents ) - encoded_latents_for_subset = torch.tanh(encoded_latents_for_subset) - decoded_latents_for_subset = self.gw_mod.decode( - encoded_latents_for_subset + fused_latents = self.gw_mod.fuse( + selected_encoded_latents, selection_scores ) + decoded_latents = self.gw_mod.decode(fused_latents) - for domain_name in latents: - domain_mod = self.domain_mods[domain_name] - decoded_latent_for_domain_subset = decoded_latents_for_subset[ - domain_name - ] - original_latent_for_domain_subset = latents[domain_name][indices] - loss_output = domain_mod.compute_broadcast_loss( - decoded_latent_for_domain_subset, - original_latent_for_domain_subset, - ) - loss_key = f"{domain_name}_loss_{scale_type}" + num_active_domains = sum(partition) + num_total_domains = len(partition) + for domain, pred in decoded_latents.items(): + if domain not in group_domains: # if we don't have ground truth + continue + ground_truth = latents[domain] + loss_output = self.domain_mods[domain].compute_loss( + pred, ground_truth + ) + loss_label = f"from_{selected_group_label}_to_{domain}" + losses[loss_label + "_loss"] = loss_output.loss metrics.update( - { - f"broadcast_{loss_key}_{k}": v - for k, v in loss_output.metrics.items() - } + {f"{loss_label}_{k}": v for k, v in loss_output.metrics.items()} ) - losses[loss_key] = loss_output.loss.mean() - binary_count = scaling_factors["binary"][2].size(0) - softmax_count = scaling_factors["softmax"][2].size(0) - total_count = binary_count + softmax_count + if num_active_domains == 1 and domain in selected_latents: + demi_cycle_losses.append(loss_label + "_loss") + if num_active_domains == 1 and domain not in selected_latents: + translation_losses.append(loss_label + "_loss") - for domain_name in latents: - full_loss_key = f"{domain_name}_full_loss" + if num_active_domains < num_total_domains: + inverse_selected_latents = { + domain: decoded_latents[domain] + for domain in decoded_latents + if domain not in selected_latents + } - binary_loss_key = f"{domain_name}_loss_binary" - softmax_loss_key = f"{domain_name}_loss_softmax" + inverse_selected_group_label = ( + "{" + ", ".join(sorted(inverse_selected_latents)) + "}" + ) - binary_loss = losses[binary_loss_key] * (binary_count / total_count) - softmax_loss = losses[softmax_loss_key] * (softmax_count / total_count) + re_encoded_latents = self.gw_mod.encode(inverse_selected_latents) + re_selection_scores = self.selection_mod( + inverse_selected_latents, re_encoded_latents + ) + re_fused_latents = self.gw_mod.fuse( + re_encoded_latents, re_selection_scores + ) + re_decoded_latents = self.gw_mod.decode( + re_fused_latents, domains=selected_latents.keys() + ) - losses[full_loss_key] = binary_loss + softmax_loss + for domain in selected_latents: + re_ground_truth = latents[domain] + re_loss_output = self.domain_mods[domain].compute_loss( + re_decoded_latents[domain], re_ground_truth + ) + loss_label = ( + f"from_{selected_group_label}_" + f"through_{inverse_selected_group_label}_to_{domain}" + ) + losses[loss_label + "_loss"] = re_loss_output.loss + metrics.update( + { + f"{loss_label}_{k}": v + for k, v in re_loss_output.metrics.items() + } + ) + cycle_losses.append(loss_label + "_loss") + + if demi_cycle_losses: + metrics["demi_cycles"] = torch.mean( + torch.stack([losses[loss_name] for loss_name in demi_cycle_losses]) + ) + if cycle_losses: + metrics["cycles"] = torch.mean( + torch.stack([losses[loss_name] for loss_name in cycle_losses]) + ) + if translation_losses: + metrics["translations"] = torch.mean( + torch.stack([losses[loss_name] for loss_name in translation_losses]) + ) - losses["broadcast"] = torch.stack( - [loss for name, loss in losses.items() if "full_loss" in name], dim=0 - ).mean() - losses.update(metrics) - return losses + total_loss = torch.mean(torch.stack(list(losses.values()))) + return {"broadcast": total_loss, **metrics} def step( self, domain_latents: LatentsDomainGroupsT, mode: ModelModeT, ) -> LossOutput: + """ + Performs a step of loss computation. + + Args: + domain_latents: Latent representations for all domains. + mode: The mode in which the model is currently operating. + + Returns: + A LossOutput object containing the loss and metrics for this step. + """ + metrics: dict[str, torch.Tensor] = {} - metrics.update(self.demi_cycle_loss(domain_latents)) - metrics.update(self.cycle_loss(domain_latents)) - metrics.update(self.translation_loss(domain_latents)) metrics.update(self.contrastive_loss(domain_latents)) metrics.update(self.broadcast_loss(domain_latents, mode)) - loss = metrics["broadcast"] + metrics["contrastives"] + loss = torch.stack( + [ + metrics[name] * coef + for name, coef in self.loss_coefs.items() + if isinstance(coef, float) and coef > 0 + ], + dim=0, + ).mean() return LossOutput(loss, metrics) diff --git a/shimmer/modules/selection.py b/shimmer/modules/selection.py index a1088ade..a0b3cfa5 100644 --- a/shimmer/modules/selection.py +++ b/shimmer/modules/selection.py @@ -164,69 +164,51 @@ def forward( class RandomSelection(SelectionBase): """ - random attention, not learned, with a proportion of binary scaling factors, - and a proportion of uniform-then-softmaxed-across-modalities scores. - this class serves to train broadcast with robustness on linear scaling on - prefusion representations. + Modified random attention to only utilize uniform-softmax scores across modalities. + This version omits the binary scaling factors and focuses on generating attention + coefficients using a uniform distribution followed by a domain-wise softmax. """ - def __init__(self, binary_proportion: float, temperature: float): + def __init__(self, temperature: float): """ Args: - binary_proportion (`float`) : proportion of binary scaling factors - returned by forward(). between 0 and 1. - temperature (`float`) : temperature of the softmax applied to uniform + temperature (`float`): Temperature of the softmax applied to uniform scaling factors. """ super().__init__() - self.binary_proportion = binary_proportion self.temperature = temperature def forward( self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT ) -> dict[str, torch.Tensor]: """ - randomly draw binary and uniform-then-domain-wise-softmaxed samples according - to self.binary_proportion. + Generate uniform-then-domain-wise-softmaxed samples for each domain. Args: domains (`LatentsDomainGroupT`): Group of unimodal latent representations. - This is not used in the function. + This is not used in the function directly but determines the structure + of the returned attention coefficients. Returns: - `dict[str, torch.Tensor]`: for each domain in the group, the fusion - coefficient for each item in the batch. + `dict[str, torch.Tensor]`: For each domain in the group, the fusion + coefficient for each item in the batch, based solely on + uniform-softmax scores. """ num_domains = len(domains) batch_size = group_batch_size(domains) - # have to add extra binaries when the division's not integer - total_binary_scores = int(batch_size * self.binary_proportion) - num_binary_per_domain, extra_binary_scores = divmod( - total_binary_scores, num_domains - ) - - # Calculate number of uniform scores taking into account extra binary scores - num_uniform = batch_size - total_binary_scores + # Generate uniform scores + uniform_scores = torch.rand(batch_size, num_domains) - uniform_scores = torch.rand(num_uniform, num_domains) + # Apply softmax across domains with temperature scaling softmax_scores = torch.softmax(uniform_scores / self.temperature, dim=1) - # Generate binary scores, adjusting for any extra binary scores - binary_scores = [] - for i in range(num_domains): - binary_score = torch.zeros( - num_binary_per_domain + (1 if i < extra_binary_scores else 0), - num_domains, - ) - binary_score[:, i] = 1 - binary_scores.append(binary_score) - binary_scores_concat = torch.cat(binary_scores, dim=0) - - all_scores = torch.cat([softmax_scores, binary_scores_concat], dim=0) + # Create attention dictionary for each domain + device = group_device(domains) attention_dict = { - domain: all_scores[:, i : i + 1] for i, domain in enumerate(domains) + domain: softmax_scores[:, i].to(device) for i, domain in enumerate(domains) } + return attention_dict @@ -292,7 +274,7 @@ def fuse_weighted_encodings( return summed_tensor def forward( - self, domains: LatentsDomainGroupT, encodings: LatentsDomainGroupT + self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT ) -> dict[str, torch.Tensor]: """ Compute keys and queries, match them with dot product and softmax. @@ -323,7 +305,9 @@ def forward( static_attention_dict = self.calculate_attention_dict(keys, query) # Apply the attention scores to the encodings - summed_tensor = self.fuse_weighted_encodings(encodings, static_attention_dict) + summed_tensor = self.fuse_weighted_encodings( + encodings_pre_fusion, static_attention_dict + ) # Retrieve query (now it is dependent on the new gw state) query = self.query_layer(summed_tensor.to(device)) diff --git a/tests/test_broadcast.py b/tests/test_broadcast.py new file mode 100644 index 00000000..20a58193 --- /dev/null +++ b/tests/test_broadcast.py @@ -0,0 +1,110 @@ +import torch +from torch import nn +from torch.nn.functional import cross_entropy, normalize + +from shimmer.modules.contrastive_loss import ContrastiveLossType +from shimmer.modules.domain import DomainModule, LossOutput +from shimmer.modules.global_workspace import GlobalWorkspaceFusion +from shimmer.modules.losses import BroadcastLossCoefs + + +def contrastive_loss(x: torch.Tensor, y: torch.Tensor) -> LossOutput: + """ + Simplified CLIP-like contrastive loss that matches the expected signature. + + Args: + x (torch.Tensor): Predictions. + y (torch.Tensor): Targets. + + Returns: + LossOutput: A dataclass containing the computed loss and optionally + additional metrics. + """ + # Assuming logit_scale is a pre-defined tensor if needed for the calculation + # For the sake of matching the function signature, + # we'll remove it from the parameters + # Similarly, we assume a fixed reduction mode for simplicity + logit_scale = torch.tensor( + 1.0 + ) # Placeholder for an actual logit scale if necessary + reduction = "mean" # Fixed reduction mode + + xn = normalize(x, dim=-1) + yn = normalize(y, dim=-1) + logits = torch.matmul(xn, yn.t()) + labels = torch.arange(xn.size(0), device=xn.device) + ce_loss = 0.5 * ( + cross_entropy(logits * logit_scale.exp(), labels, reduction=reduction) + + cross_entropy(logits.t() * logit_scale.exp(), labels, reduction=reduction) + ) + + return LossOutput(loss=ce_loss) + + +class DummyDomainModule(DomainModule): + def __init__(self, latent_dim: int): + super().__init__(latent_dim) + self.encoder = nn.Linear(latent_dim, latent_dim) # Simplified encoder + self.decoder = nn.Linear(latent_dim, latent_dim) # Simplified decoder + + def encode(self, x: torch.Tensor) -> torch.Tensor: + return self.encoder(x) # Simple forward pass through encoder + + def decode(self, z: torch.Tensor) -> torch.Tensor: + return self.decoder(z) # Simple forward pass through decoder + + def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput: + loss = torch.mean((pred - target) ** 2) # Simple MSE loss + return LossOutput(loss=loss) # Constructing LossOutput with the loss + + +def setup_global_workspace_fusion() -> GlobalWorkspaceFusion: + """ + Setting up the test environment for GlobalWorkspaceFusion + """ + domain_mods: dict[str, DomainModule] = { + "domain1": DummyDomainModule(latent_dim=10), + "domain2": DummyDomainModule(latent_dim=10), + } + gw_encoders = {"domain1": nn.Linear(10, 10), "domain2": nn.Linear(10, 10)} + gw_decoders = {"domain1": nn.Linear(10, 10), "domain2": nn.Linear(10, 10)} + workspace_dim = 10 + contrastive_fn: ContrastiveLossType = contrastive_loss + loss_coefs: BroadcastLossCoefs = {"broadcast": 1.0, "contrastives": 0.1} + + gw_fusion = GlobalWorkspaceFusion( + domain_mods, + gw_encoders, + gw_decoders, + workspace_dim, + loss_coefs, + selection_temperature=0.2, + optim_lr=1e-3, + optim_weight_decay=0.0, + scheduler_args=None, # Simplified for testing + learn_logit_scale=False, + contrastive_loss=contrastive_fn, + ) + + return gw_fusion + + +def test_broadcast_loss(): + gw_fusion = setup_global_workspace_fusion() + + # Adjusting the dummy data to fit the expected input structure for broadcast_loss + # Now using a frozenset for the keys to match LatentsDomainGroupsT + latent_domains = { + frozenset(["domain1", "domain2"]): { + "domain1": torch.rand(5, 10), # Batch size of 5, feature dimension of 10 + "domain2": torch.rand(5, 10), + } + } + + # Test broadcast_loss with the corrected structure + output = gw_fusion.loss_mod.broadcast_loss(latent_domains, "train") + print(output) + + +# Call the test function to execute the test +test_broadcast_loss() diff --git a/tests/test_random_attention.py b/tests/test_random_attention.py index c887898e..5422fb73 100644 --- a/tests/test_random_attention.py +++ b/tests/test_random_attention.py @@ -1,16 +1,14 @@ -import numpy as np import torch from shimmer.modules.selection import RandomSelection def test_multiple_domains(): - binary_proportion = 0.5 temperature = 1.0 domain_dim = 12 batch_size = 2056 - selection = RandomSelection(binary_proportion, temperature) + selection = RandomSelection(temperature) multiple_domain_input = { "v_latents": torch.rand(batch_size, domain_dim), "attr": torch.rand(batch_size, domain_dim), @@ -37,12 +35,11 @@ def test_multiple_domains(): def test_three_domains(): - binary_proportion = 0.5 temperature = 1.0 domain_dim = 12 batch_size = 2056 - selection = RandomSelection(binary_proportion, temperature) + selection = RandomSelection(temperature) three_domain_input = { "v_latents": torch.rand(batch_size, domain_dim), "attr": torch.rand(batch_size, domain_dim), @@ -61,14 +58,9 @@ def test_three_domains(): for domain in three_domain_input: assert selection_scores[domain].shape == ( batch_size, - 1, ), f"Scores shape mismatch for {domain}" - # Check if the binary scores are as expected - # This part might need adjustments based on how binary scores are distributed - # and combined with uniform scores in your actual implementation - - # Check if the sum of selection scores across domains equals 1 + # Ensure the sum of attention scores across domains equals 1 scores_sum = sum( selection_scores[domain].squeeze() for domain in three_domain_input ) @@ -79,42 +71,3 @@ def test_three_domains(): assert torch.allclose( scores_sum, expected_sum ), "Sum of selection scores across three domains should be 1" - - -def test_binary_scores_xor_check_for_multiple_proportions(): - temperature = 1.0 - domain_dim = 12 - batch_size = 2056 - num_tests = 10 # Number of random proportions to test - - for _ in range(num_tests): - binary_proportion = np.random.rand() # Random proportion between 0 and 1 - - selection = RandomSelection(binary_proportion, temperature) - domains_input = { - "v_latents": torch.rand(batch_size, domain_dim), - "attr": torch.rand(batch_size, domain_dim), - "audio": torch.rand(batch_size, domain_dim), - } - - prefusion_encodings = { - "v_latents": torch.rand(batch_size, domain_dim), - "attr": torch.rand(batch_size, domain_dim), - "audio": torch.rand(batch_size, domain_dim), - } - - selection_scores = selection(domains_input, prefusion_encodings) - - scores_matrix = torch.cat( - [selection_scores[domain] for domain in domains_input], dim=1 - ) - binary_scores_mask = scores_matrix == 1 - xor_binary_check = binary_scores_mask.sum(dim=1) == 1 - num_binary_rows = xor_binary_check.sum().item() - expected_num_binary_rows = int(batch_size * binary_proportion) - - assert num_binary_rows == expected_num_binary_rows, ( - "Incorrect number of binary score rows for proportion" - f"{binary_proportion:.2f}: expected {expected_num_binary_rows}, " - "got {num_binary_rows}" - )