From 0ee7a40b1156d751ae56e3dc8455b48dd3cf45cb Mon Sep 17 00:00:00 2001
From: bdvllrs
Date: Fri, 4 Oct 2024 16:42:45 +0200
Subject: [PATCH] Remove Bayesian models (#164)

---
 docs/q_and_a.md                       |   1 -
 shimmer/__init__.py                   |   6 -
 shimmer/modules/__init__.py           |   8 --
 shimmer/modules/contrastive_loss.py   |  10 --
 shimmer/modules/global_workspace.py   | 104 --------------
 shimmer/modules/gw_module.py          | 159 +---------------------
 shimmer/modules/losses.py             | 189 +-------------------------
 tests/test_with_confidence_modules.py |  64 ---------
 8 files changed, 2 insertions(+), 539 deletions(-)
 delete mode 100644 tests/test_with_confidence_modules.py

diff --git a/docs/q_and_a.md b/docs/q_and_a.md
index a396f793..543e88e1 100644
--- a/docs/q_and_a.md
+++ b/docs/q_and_a.md
@@ -19,7 +19,6 @@ To get inspiration, you can look at the source code of
 ## How can I change the loss function?
 
 If you are using a pre-made GW architecture ([`GlobalWorkspace`](https://ruflab.github.io/shimmer/latest/shimmer/modules/global_workspace.html#GlobalWorkspace),
-[`GlobalWorkspaceBayesian`](https://ruflab.github.io/shimmer/latest/shimmer/modules/global_workspace.html#GlobalWorkspaceBayesian),
 [`GlobalWorkspaceFusion`](https://ruflab.github.io/shimmer/latest/shimmer/modules/global_workspace.html#GlobalWorkspaceFusion))
 and want to update the loss used for demi-cycles, cycles, translations or broadcast,
 you can do so directly from your definition of the
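The Q&A hunk above is cut off mid-sentence by the diff context, but the mechanism it points to is the per-objective loss hooks on `DomainModule`. A minimal sketch, assuming the hook name `compute_dcy_loss` from shimmer's `DomainModule` API; the subclass and the L1 loss choice are hypothetical:

    import torch
    from shimmer import DomainModule, LossOutput

    class MyDomain(DomainModule):
        def compute_dcy_loss(
            self, pred: torch.Tensor, target: torch.Tensor
        ) -> LossOutput:
            # Hypothetical override: use L1 for demi-cycles instead of the default.
            return LossOutput(torch.nn.functional.l1_loss(pred, target))
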
"GWModuleBase", "GWModule", - "GWModuleBayesian", "GWModulePrediction", "ContrastiveLossType", - "ContrastiveLossBayesianType", "contrastive_loss", "ContrastiveLoss", "LossCoefs", "BroadcastLossCoefs", "GWLossesBase", "GWLosses2Domains", - "GWLossesBayesian", "RepeatedDataset", "reparameterize", "kl_divergence_loss", diff --git a/shimmer/modules/contrastive_loss.py b/shimmer/modules/contrastive_loss.py index ea520c01..d3194d26 100644 --- a/shimmer/modules/contrastive_loss.py +++ b/shimmer/modules/contrastive_loss.py @@ -15,16 +15,6 @@ A function taking the prediction and targets and returning a LossOutput. """ -ContrastiveLossBayesianType = Callable[ - [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], LossOutput -] -""" -Contrastive loss function type for GlobalWorkspaceBayesian. - -A function taking the prediction mean, prediction std, target mean and target std and - returns a LossOutput. -""" - def info_nce( x: torch.Tensor, diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py index ee03690a..1cec23ac 100644 --- a/shimmer/modules/global_workspace.py +++ b/shimmer/modules/global_workspace.py @@ -15,7 +15,6 @@ from shimmer.modules.gw_module import ( GWModule, GWModuleBase, - GWModuleBayesian, GWModulePrediction, broadcast_cycles, cycle, @@ -26,11 +25,9 @@ GWLosses, GWLosses2Domains, GWLossesBase, - GWLossesBayesian, LossCoefs, ) from shimmer.modules.selection import ( - FixedSharedSelection, RandomSelection, SelectionBase, SingleDomainSelection, @@ -793,107 +790,6 @@ def __init__( ) -class GlobalWorkspaceBayesian( - GlobalWorkspaceBase[GWModuleBayesian, FixedSharedSelection, GWLossesBayesian] -): - """ - A simple 2-domains max GlobalWorkspaceBase with a Bayesian base uncertainty - prediction. - - This is used to simplify a Global Workspace instanciation and only overrides the - `__init__` method. - """ - - def __init__( - self, - domain_mods: Mapping[str, DomainModule], - gw_encoders: Mapping[str, Module], - gw_decoders: Mapping[str, Module], - workspace_dim: int, - loss_coefs: BroadcastLossCoefs, - sensitivity_selection: float = 1, - sensitivity_precision: float = 1, - optim_lr: float = 1e-3, - optim_weight_decay: float = 0.0, - scheduler_args: SchedulerArgs | None = None, - learn_logit_scale: bool = False, - use_normalized_constrastive: bool = True, - contrastive_loss: ContrastiveLossType | None = None, - precision_softmax_temp: float = 0.01, - scheduler: LRScheduler - | None - | OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT, - ) -> None: - """ - Initializes a Global Workspace - - Args: - domain_mods (`Mapping[str, DomainModule]`): mapping of the domains - connected to the GW. Keys are domain names, values are the - `DomainModule`. - gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain - name to a `torch.nn.Module` class which role is to encode a - unimodal latent representations into a GW representation (pre fusion). - gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain - name to a `torch.nn.Module` class which role is to decode a - GW representation into a unimodal latent representations. - workspace_dim (`int`): dimension of the GW. 
diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py
index ee03690a..1cec23ac 100644
--- a/shimmer/modules/global_workspace.py
+++ b/shimmer/modules/global_workspace.py
@@ -15,7 +15,6 @@
 from shimmer.modules.gw_module import (
     GWModule,
     GWModuleBase,
-    GWModuleBayesian,
     GWModulePrediction,
     broadcast_cycles,
     cycle,
@@ -26,11 +25,9 @@
     GWLosses,
     GWLosses2Domains,
     GWLossesBase,
-    GWLossesBayesian,
     LossCoefs,
 )
 from shimmer.modules.selection import (
-    FixedSharedSelection,
     RandomSelection,
     SelectionBase,
     SingleDomainSelection,
@@ -793,107 +790,6 @@ def __init__(
         )
 
 
-class GlobalWorkspaceBayesian(
-    GlobalWorkspaceBase[GWModuleBayesian, FixedSharedSelection, GWLossesBayesian]
-):
-    """
-    A simple 2-domain GlobalWorkspaceBase with Bayesian-based uncertainty
-    prediction.
-
-    This is used to simplify a Global Workspace instantiation and only overrides
-    the `__init__` method.
-    """
-
-    def __init__(
-        self,
-        domain_mods: Mapping[str, DomainModule],
-        gw_encoders: Mapping[str, Module],
-        gw_decoders: Mapping[str, Module],
-        workspace_dim: int,
-        loss_coefs: BroadcastLossCoefs,
-        sensitivity_selection: float = 1,
-        sensitivity_precision: float = 1,
-        optim_lr: float = 1e-3,
-        optim_weight_decay: float = 0.0,
-        scheduler_args: SchedulerArgs | None = None,
-        learn_logit_scale: bool = False,
-        use_normalized_constrastive: bool = True,
-        contrastive_loss: ContrastiveLossType | None = None,
-        precision_softmax_temp: float = 0.01,
-        scheduler: LRScheduler
-        | None
-        | OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT,
-    ) -> None:
-        """
-        Initializes a Global Workspace.
-
-        Args:
-            domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
-                connected to the GW. Keys are domain names, values are the
-                `DomainModule`.
-            gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
-                name to a `torch.nn.Module` class whose role is to encode a
-                unimodal latent representation into a GW representation (pre fusion).
-            gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
-                name to a `torch.nn.Module` class whose role is to decode a
-                GW representation into a unimodal latent representation.
-            workspace_dim (`int`): dimension of the GW.
-            loss_coefs (`BroadcastLossCoefs`): loss coefficients
-            sensitivity_selection (`float`): sensitivity coef $c'_1$
-            sensitivity_precision (`float`): sensitivity coef $c'_2$
-            optim_lr (`float`): learning rate
-            optim_weight_decay (`float`): weight decay
-            scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
-            learn_logit_scale (`bool`): whether to learn the logit scale when using
-                the default contrastive loss.
-            use_normalized_constrastive (`bool`): whether to normalize the
-                contrastive loss by the precision coefficients
-            contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
-                function used for alignment. `learn_logit_scale` will not affect custom
-                contrastive losses.
-            precision_softmax_temp (`float`): temperature to use in the softmax of
-                the precision
-            scheduler: The scheduler to use for training. If None is explicitly given,
-                no scheduler will be used. Defaults to OneCycleScheduler.
-        """
-        domain_mods = freeze_domain_modules(domain_mods)
-
-        gw_mod = GWModuleBayesian(
-            domain_mods,
-            workspace_dim,
-            gw_encoders,
-            gw_decoders,
-            sensitivity_selection,
-            sensitivity_precision,
-            precision_softmax_temp,
-        )
-
-        selection_mod = FixedSharedSelection()
-
-        contrastive_loss = ContrastiveLoss(
-            torch.tensor([1]).log(), "mean", learn_logit_scale
-        )
-
-        loss_mod = GWLossesBayesian(
-            gw_mod,
-            selection_mod,
-            domain_mods,
-            loss_coefs,
-            contrastive_loss,
-            use_normalized_constrastive,
-        )
-
-        super().__init__(
-            gw_mod,
-            selection_mod,
-            loss_mod,
-            optim_lr,
-            optim_weight_decay,
-            scheduler_args,
-            scheduler,
-        )
-
-
 def pretrained_global_workspace(
     checkpoint_path: str | Path,
     domain_mods: Mapping[str, DomainModule],
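On releases that predate this patch, a `GlobalWorkspaceBayesian` was assembled as sketched below, mirroring the removed `__init__` signature above and the deleted test at the end of this patch. `DummyDomainModule` is the test helper from the repository's test utilities, and the `loss_coefs` keys are assumptions based on `BroadcastLossCoefs`:

    from utils import DummyDomainModule  # test helper, as in the deleted test

    from shimmer import GlobalWorkspaceBayesian, GWDecoder, GWEncoder

    workspace_dim = 16
    domains = {
        "v": DummyDomainModule(latent_dim=2),
        "t": DummyDomainModule(latent_dim=4),
    }
    gw_encoders = {
        name: GWEncoder(mod.latent_dim, hidden_dim=64, out_dim=workspace_dim, n_layers=1)
        for name, mod in domains.items()
    }
    gw_decoders = {
        name: GWDecoder(workspace_dim, hidden_dim=64, out_dim=mod.latent_dim, n_layers=1)
        for name, mod in domains.items()
    }
    gw = GlobalWorkspaceBayesian(
        domains,
        gw_encoders,
        gw_decoders,
        workspace_dim,
        loss_coefs={"contrastives": 0.1, "fused": 1.0},  # assumed keys
        sensitivity_selection=1.0,
        sensitivity_precision=1.0,
    )
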
diff --git a/shimmer/modules/gw_module.py b/shimmer/modules/gw_module.py
index 489b8125..f7dedb5f 100644
--- a/shimmer/modules/gw_module.py
+++ b/shimmer/modules/gw_module.py
@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from collections.abc import Iterable, Mapping
-from typing import TypedDict, cast
+from typing import TypedDict
 
 import torch
 from torch import nn
@@ -449,160 +449,3 @@ def decode(
             domain: self.gw_decoders[domain](z)
             for domain in domains or self.gw_decoders.keys()
         }
-
-
-def compute_fusion_scores(
-    score_1: torch.Tensor,
-    score_2: torch.Tensor,
-    sensitivity_1: float = 1.0,
-    sensitivity_2: float = 1.0,
-    eps: float = 1e-6,
-) -> torch.Tensor:
-    """
-    Combine precision scores using std summation in quadrature.
-
-    The two scores should have the same dimension.
-
-    Args:
-        score_1 (`torch.Tensor`): First scores.
-        score_2 (`torch.Tensor`): Second scores.
-        sensitivity_1 (`float`): sensitivity for the first score
-        sensitivity_2 (`float`): sensitivity for the second score
-        eps (`float`): a value added to avoid numerical instability.
-
-    Returns:
-        `torch.Tensor`: the combined scores
-    """
-    total_uncertainty = sensitivity_1 / (eps + score_1) + sensitivity_2 / (
-        eps + score_2
-    )
-    final_scores = 1 / (eps + total_uncertainty)
-    return final_scores / final_scores.sum(dim=0, keepdim=True)
-
-
-class GWModuleBayesian(GWModule):
-    """`GWModule` with a Bayesian-based uncertainty prediction."""
-
-    def __init__(
-        self,
-        domain_modules: Mapping[str, DomainModule],
-        workspace_dim: int,
-        gw_encoders: Mapping[str, nn.Module],
-        gw_decoders: Mapping[str, nn.Module],
-        sensitivity_selection: float = 1,
-        sensitivity_precision: float = 1,
-        precision_softmax_temp: float = 0.01,
-    ) -> None:
-        """
-        Initializes the GWModuleBayesian.
-
-        Args:
-            domain_modules (`Mapping[str, DomainModule]`): the domain modules.
-            workspace_dim (`int`): dimension of the GW.
-            gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
-                name to a torch.nn.Module class that encodes a
-                unimodal latent representation into a GW representation (pre fusion).
-            gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
-                name to a torch.nn.Module class that decodes a
-                GW representation to a unimodal latent representation.
-            sensitivity_selection (`float`): sensitivity coef $c'_1$
-            sensitivity_precision (`float`): sensitivity coef $c'_2$
-            precision_softmax_temp (`float`): temperature to use in the softmax of
-                the precision
-        """
-        super().__init__(domain_modules, workspace_dim, gw_encoders, gw_decoders)
-
-        self.precisions = cast(
-            dict[str, torch.Tensor],
-            nn.ParameterDict(
-                {domain: torch.randn(workspace_dim) for domain in gw_encoders}
-            ),
-        )
-        """Precision at the neuron level for every domain."""
-
-        self.sensitivity_selection = sensitivity_selection
-        self.sensitivity_precision = sensitivity_precision
-        self.precision_softmax_temp = precision_softmax_temp
-
-    def get_precision(self, domain: str, x: torch.Tensor) -> torch.Tensor:
-        """
-        Get the precision vector of the given domain and batch.
-
-        Args:
-            domain (`str`): the domain name
-            x (`torch.Tensor`): batch of inputs
-
-        Returns:
-            `torch.Tensor`: batch of precisions
-        """
-        return self.precisions[domain].unsqueeze(0).expand(x.size(0), -1)
-
-    def fuse(
-        self,
-        x: LatentsDomainGroupT,
-        selection_scores: Mapping[str, torch.Tensor],
-    ) -> torch.Tensor:
-        """
-        Merge function used to combine domains.
-
-        In the following, $D$ is the number of domains, $N$ the batch size, and
-        $d$ the dimension of the Global Workspace.
-
-        This function needs to merge two kinds of scores:
-        * the selection scores $a \\in [0,1]^{D \\times N}$;
-        * the precision scores $b \\in [0,1]^{D \\times N \\times d}$.
-
-        .. note::
-            The precision score is obtained by predicting logits and using a softmax.
-
-        We can associate an uncertainty with each score by introducing a std
-        variable and using Bayesian integration:
-
-        $$a_k = \\frac{M_1}{\\sigma_k^2}$$
-        where $M_1 = \\frac{1}{\\sum_{i=1}^D \\frac{1}{\\sigma_i^2}}$.
-
-        Similarly,
-        $$b_k = \\frac{M_2}{\\mu_k^2}$$
-        where $M_2 = \\frac{1}{\\sum_{i=1}^D \\frac{1}{\\mu_i^2}}$.
-
-        Then we can sum the variances to obtain the final squared uncertainty $\\xi$:
-        $$\\xi_k^2 = c_1 \\sigma_k^2 + c_2 \\mu_k^2$$
-
-        which, in terms of $a_k$ and $b_k$, yields:
-        $$\\xi_k^2 = \\frac{c'_1}{a_k} + \\frac{c'_2}{b_k}$$
-        where $c'_1 = c_1 \\cdot M_1$ and $c'_2 = c_2 \\cdot M_2$.
-
-        Finally, the final combined coefficient is
-        $$\\lambda_k = \\frac{M_3}{\\frac{c'_1}{a_k} + \\frac{c'_2}{b_k}}$$
-        where
-        $$M_3 = \\frac{1}{\\sum_{i=1}^D
-            \\frac{1}{\\frac{c'_1}{a_i} + \\frac{c'_2}{b_i}}}$$
-
-        Args:
-            x (`LatentsDomainGroupT`): the group of latent representations.
-            selection_scores (`Mapping[str, torch.Tensor]`): attention scores to
-                use to encode the representation.
-        Returns:
-            `torch.Tensor`: The merged representation.
-        """
-        scores: list[torch.Tensor] = []
-        precisions: list[torch.Tensor] = []
-        domains: list[torch.Tensor] = []
-        for domain, score in selection_scores.items():
-            scores.append(score)
-            precisions.append(self.get_precision(domain, x[domain]))
-            domains.append(x[domain])
-        combined_scores = compute_fusion_scores(
-            torch.stack(scores).unsqueeze(-1),
-            torch.softmax(
-                torch.tanh(torch.stack(precisions)) * self.precision_softmax_temp, dim=0
-            ),
-            self.sensitivity_selection,
-            self.sensitivity_precision,
-        )
-        return torch.tanh(
-            torch.sum(
-                combined_scores * torch.stack(domains),
-                dim=0,
-            )
-        )
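A worked example of the removed `compute_fusion_scores`, reproducing its arithmetic directly: with equal selection scores $a$ but unequal precision scores $b$, the combined coefficients $\lambda$ favor the higher-precision domain and sum to 1 over the domain axis:

    import torch

    # Shapes are (D, N) with D=2 domains and N=1 item, as in the removed code.
    score_sel = torch.tensor([[0.5], [0.5]])   # selection scores a
    score_prec = torch.tensor([[0.8], [0.2]])  # precision scores b
    c1, c2, eps = 1.0, 1.0, 1e-6
    # xi^2 = c'_1 / a + c'_2 / b, the combined squared uncertainty
    xi_sq = c1 / (eps + score_sel) + c2 / (eps + score_prec)
    lam = 1 / (eps + xi_sq)
    lam = lam / lam.sum(dim=0, keepdim=True)   # normalize over domains
    print(lam)  # ~[[0.683], [0.317]]: the higher-precision domain dominates
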
- """ - scores: list[torch.Tensor] = [] - precisions: list[torch.Tensor] = [] - domains: list[torch.Tensor] = [] - for domain, score in selection_scores.items(): - scores.append(score) - precisions.append(self.get_precision(domain, x[domain])) - domains.append(x[domain]) - combined_scores = compute_fusion_scores( - torch.stack(scores).unsqueeze(-1), - torch.softmax( - torch.tanh(torch.stack(precisions)) * self.precision_softmax_temp, dim=0 - ), - self.sensitivity_selection, - self.sensitivity_precision, - ) - return torch.tanh( - torch.sum( - combined_scores * torch.stack(domains), - dim=0, - ) - ) diff --git a/shimmer/modules/losses.py b/shimmer/modules/losses.py index 6ff9c116..d9c95499 100644 --- a/shimmer/modules/losses.py +++ b/shimmer/modules/losses.py @@ -7,11 +7,7 @@ from shimmer.modules.contrastive_loss import ContrastiveLossType from shimmer.modules.domain import DomainModule, LossOutput -from shimmer.modules.gw_module import ( - GWModule, - GWModuleBase, - GWModuleBayesian, -) +from shimmer.modules.gw_module import GWModule, GWModuleBase from shimmer.modules.selection import SelectionBase from shimmer.types import LatentsDomainGroupsT, ModelModeT, RawDomainGroupsT @@ -286,71 +282,6 @@ def contrastive_loss( return losses -def contrastive_loss_bayesian( - gw_mod: GWModuleBayesian, - latent_domains: LatentsDomainGroupsT, - contrastive_fn: ContrastiveLossType, -) -> dict[str, torch.Tensor]: - """ - Computes the contrastive loss with a Bayesian based uncertainty prediction. - - This return multiple metrics: - * `contrastive_{domain_1}_and_{domain_2}` with the contrastive - between 2 domains; - * `contrastive_{domain_1}_and_{domain_2}_{metric}` with - additional metrics provided by the domain_mod's - `compute_cont_loss` output; - * `contrastives` with the average value of all - `contrastive_{domain_1}_and_{domain_2}` values. - - Args: - gw_mod (`GWModuleBayesian`): The GWModule to use - latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups - contrastive_fn (`ContrastiveLossBayesianType`): the contrastive function - to apply - - Returns: - `dict[str, torch.Tensor]`: a dict of metrics. - """ - losses: dict[str, torch.Tensor] = {} - metrics: dict[str, torch.Tensor] = {} - keys: list[set[str]] = [] - - for latents in latent_domains.values(): - if len(latents) < 2: - continue - for domain1_name, domain1 in latents.items(): - z1 = gw_mod.encode({domain1_name: domain1})[domain1_name] - z1_precision = gw_mod.get_precision(domain1_name, domain1) - for domain2_name, domain2 in latents.items(): - selected_domains = {domain1_name, domain2_name} - if domain1_name == domain2_name or selected_domains in keys: - continue - - keys.append(selected_domains) - - loss_name = f"contrastive_{domain1_name}_and_{domain2_name}" - z2 = gw_mod.encode({domain2_name: domain2})[domain2_name] - z2_precision = gw_mod.get_precision(domain2_name, domain2) - coef = torch.softmax( - gw_mod.precision_softmax_temp - * torch.stack([z1_precision, z2_precision]), - dim=0, - ) - norm = torch.sqrt(coef[0] * coef[1]) - loss_output = contrastive_fn(z1 * norm, z2 * norm) - loss_output_no_norm = contrastive_fn(z1, z2) - losses[loss_name] = loss_output.loss - metrics.update( - {f"{loss_name}_{k}": v for k, v in loss_output.metrics.items()} - ) - metrics[f"unnorm_{loss_name}"] = loss_output_no_norm.loss - - losses["contrastives"] = torch.stack(list(losses.values()), dim=0).mean() - losses.update(metrics) - return losses - - class LossCoefs(TypedDict, total=False): """ Dict of loss coefficients used in the GWLosses. 
@@ -804,121 +735,3 @@ def step(
         ).mean()
 
         return LossOutput(loss, metrics)
-
-
-class GWLossesBayesian(GWLossesBase):
-    """
-    Implementation of `GWLossesBase` used for `GWModuleBayesian`.
-    """
-
-    def __init__(
-        self,
-        gw_mod: GWModuleBayesian,
-        selection_mod: SelectionBase,
-        domain_mods: dict[str, DomainModule],
-        loss_coefs: BroadcastLossCoefs,
-        contrastive_fn: ContrastiveLossType,
-        use_normalized_constrastive: bool = True,
-    ):
-        """
-        Loss module with uncertainty prediction to use with GlobalWorkspaceBayesian.
-
-        Args:
-            gw_mod (`GWModuleBayesian`): the GWModule
-            selection_mod (`SelectionBase`): selection module
-            domain_mods (`dict[str, DomainModule]`): a dict where the key is the
-                domain name and the value is the DomainModule
-            loss_coefs (`BroadcastLossCoefs`): loss coefficients
-            contrastive_fn (`ContrastiveLossType`): the contrastive function
-                to use in the contrastive loss
-            use_normalized_constrastive (`bool`): whether to normalize the
-                contrastive loss by the precision coefficients
-        """
-        super().__init__()
-
-        self.gw_mod = gw_mod
-        """The GWModule."""
-
-        self.selection_mod = selection_mod
-        """Selection module"""
-
-        self.domain_mods = domain_mods
-        """Domain modules linked to the GW."""
-
-        self.loss_coefs = loss_coefs
-        """The loss coefficients."""
-
-        self.contrastive_fn = contrastive_fn
-        """Contrastive loss to use."""
-
-        self.use_normalized_constrastive = use_normalized_constrastive
-
-    def contrastive_loss(
-        self, latent_domains: LatentsDomainGroupsT
-    ) -> dict[str, torch.Tensor]:
-        """
-        Contrastive loss.
-
-        Args:
-            latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
-
-        Returns:
-            `dict[str, torch.Tensor]`: a dict of metrics.
-        """
-        if self.use_normalized_constrastive:
-            return contrastive_loss_bayesian(
-                self.gw_mod, latent_domains, self.contrastive_fn
-            )
-        return contrastive_loss(self.gw_mod, latent_domains, self.contrastive_fn)
-
-    def broadcast_loss(
-        self, latent_domains: LatentsDomainGroupsT, raw_data: RawDomainGroupsT
-    ) -> dict[str, torch.Tensor]:
-        return broadcast_loss(
-            self.gw_mod, self.selection_mod, self.domain_mods, latent_domains, raw_data
-        )
-
-    def step(
-        self,
-        raw_data: RawDomainGroupsT,
-        domain_latents: LatentsDomainGroupsT,
-        mode: ModelModeT,
-    ) -> LossOutput:
-        """
-        Performs a step of loss computation.
-
-        Args:
-            raw_data (`RawDomainGroupsT`): raw input data
-            domain_latents: Latent representations for all domains.
-            mode: The mode in which the model is currently operating.
-
-        Returns:
-            A LossOutput object containing the loss and metrics for this step.
-        """
-        metrics: dict[str, torch.Tensor] = {}
-
-        metrics.update(self.contrastive_loss(domain_latents))
-        metrics.update(self.broadcast_loss(domain_latents, raw_data))
-
-        loss = torch.stack(
-            [
-                metrics[name] * coef
-                for name, coef in self.loss_coefs.items()
-                if isinstance(coef, float) and coef > 0
-            ],
-            dim=0,
-        ).mean()
-
-        metrics["broadcast_loss"] = torch.stack(
-            [
-                metrics[name]
-                for name, coef in self.loss_coefs.items()
-                if isinstance(coef, float) and coef > 0 and name != "contrastives"
-            ],
-            dim=0,
-        ).mean()
-
-        return LossOutput(loss, metrics)
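The core of the removed precision normalization in `contrastive_loss_bayesian`, extracted as a standalone sketch; the tensors are random stand-ins for the pre-fusion GW encodings and per-dimension precisions:

    import torch

    temp = 0.01
    z1, z2 = torch.randn(32, 16), torch.randn(32, 16)  # pre-fusion GW encodings
    p1, p2 = torch.randn(32, 16), torch.randn(32, 16)  # per-dimension precisions
    # Softmax over the two domains gives per-dimension weights that sum to 1.
    coef = torch.softmax(temp * torch.stack([p1, p2]), dim=0)  # shape (2, 32, 16)
    norm = torch.sqrt(coef[0] * coef[1])  # geometric mean of the two weights
    # Dimensions where either domain is imprecise get shrunk before the loss:
    z1_normed, z2_normed = z1 * norm, z2 * norm  # then passed to contrastive_fn
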
- """ - - metrics: dict[str, torch.Tensor] = {} - - metrics.update(self.contrastive_loss(domain_latents)) - metrics.update(self.broadcast_loss(domain_latents, raw_data)) - - loss = torch.stack( - [ - metrics[name] * coef - for name, coef in self.loss_coefs.items() - if isinstance(coef, float) and coef > 0 - ], - dim=0, - ).mean() - - metrics["broadcast_loss"] = torch.stack( - [ - metrics[name] - for name, coef in self.loss_coefs.items() - if isinstance(coef, float) and coef > 0 and name != "contrastives" - ], - dim=0, - ).mean() - - return LossOutput(loss, metrics) diff --git a/tests/test_with_confidence_modules.py b/tests/test_with_confidence_modules.py deleted file mode 100644 index afed19ca..00000000 --- a/tests/test_with_confidence_modules.py +++ /dev/null @@ -1,64 +0,0 @@ -import torch -from utils import DummyDomainModule - -from shimmer import GWDecoder, GWEncoder, GWModuleBayesian -from shimmer.modules.gw_module import compute_fusion_scores - - -def test_bayesian_fusion(): - domains = { - "v": DummyDomainModule(latent_dim=2), - "t": DummyDomainModule(latent_dim=4), - "a": DummyDomainModule(latent_dim=8), - } - - workspace_dim = 16 - - gw_encoders = { - domain_name: GWEncoder( - domain.latent_dim, - hidden_dim=64, - out_dim=workspace_dim, - n_layers=1, - ) - for domain_name, domain in domains.items() - } - - gw_decoders = { - domain_name: GWDecoder( - workspace_dim, - hidden_dim=64, - out_dim=domain.latent_dim, - n_layers=1, - ) - for domain_name, domain in domains.items() - } - - gw_module = GWModuleBayesian(domains, workspace_dim, gw_encoders, gw_decoders) - - batch_size = 32 - batch = { - domain_name: torch.randn(batch_size, domain.latent_dim) - for domain_name, domain in domains.items() - } - - pre_fusion_reps = gw_module.encode(batch) - selection_scores = { - domain: torch.full((batch_size,), 1.0 / 3.0) for domain in gw_encoders - } - scores: list[torch.Tensor] = [] - precisions: list[torch.Tensor] = [] - domains_: list[torch.Tensor] = [] - for domain, score in selection_scores.items(): - scores.append(score) - precisions.append(gw_module.get_precision(domain, pre_fusion_reps[domain])) - domains_.append(pre_fusion_reps[domain]) - combined_scores = compute_fusion_scores( - torch.stack(scores).unsqueeze(-1), - torch.softmax(torch.stack(precisions), dim=0), - 1, - 1, - ) - assert torch.allclose( - combined_scores.sum(dim=0), torch.ones_like(combined_scores.sum(dim=0)) - )