
Remove Bayesian models (#164)
bdvllrs authored Oct 4, 2024
1 parent 4e011e7 commit 0ee7a40
Showing 8 changed files with 2 additions and 539 deletions.
1 change: 0 additions & 1 deletion docs/q_and_a.md
@@ -19,7 +19,6 @@ To get inspiration, you can look at the source code of
 ## How can I change the loss function?
 If you are using a pre-made GW architecture
 ([`GlobalWorkspace`](https://ruflab.github.io/shimmer/latest/shimmer/modules/global_workspace.html#GlobalWorkspace),
-[`GlobalWorkspaceBayesian`](https://ruflab.github.io/shimmer/latest/shimmer/modules/global_workspace.html#GlobalWorkspaceBayesian),
 [`GlobalWorkspaceFusion`](https://ruflab.github.io/shimmer/latest/shimmer/modules/global_workspace.html#GlobalWorkspaceFusion)) and want to update the loss
 used for demi-cycles, cycles, translations or broadcast, you can do so directly from
 your definition of the
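
For reference, a minimal sketch of what a custom loss can look like. Per the docstring in shimmer/modules/contrastive_loss.py (shown further down), `ContrastiveLossType` is a callable taking prediction and target tensors and returning a `LossOutput`, and the pre-made classes accept one through their `contrastive_loss` constructor argument. The function body here is an illustrative assumption, not code from the repository:

```python
import torch
import torch.nn.functional as F

from shimmer import LossOutput


def my_contrastive_loss(x: torch.Tensor, y: torch.Tensor) -> LossOutput:
    # Toy alignment objective: pull paired GW encodings together.
    loss = 1 - F.cosine_similarity(x, y, dim=-1).mean()
    return LossOutput(loss=loss)


# Passed at construction time, e.g.:
# gw = GlobalWorkspace2Domains(..., contrastive_loss=my_contrastive_loss)
```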
6 changes: 0 additions & 6 deletions shimmer/__init__.py
@@ -14,7 +14,6 @@
 from shimmer.modules.global_workspace import (
     GlobalWorkspace2Domains,
     GlobalWorkspaceBase,
-    GlobalWorkspaceBayesian,
     SchedulerArgs,
     batch_broadcasts,
     batch_cycles,
@@ -28,7 +27,6 @@
     GWEncoderLinear,
     GWModule,
     GWModuleBase,
-    GWModuleBayesian,
     GWModulePrediction,
     broadcast,
     broadcast_cycles,
@@ -39,7 +37,6 @@
     BroadcastLossCoefs,
     GWLosses2Domains,
     GWLossesBase,
-    GWLossesBayesian,
     LossCoefs,
 )
 from shimmer.modules.selection import (
@@ -75,7 +72,6 @@
"SchedulerArgs",
"GlobalWorkspaceBase",
"GlobalWorkspace2Domains",
"GlobalWorkspaceBayesian",
"pretrained_global_workspace",
"LossOutput",
"DomainModule",
@@ -84,7 +80,6 @@
"GWEncoderLinear",
"GWModuleBase",
"GWModule",
"GWModuleBayesian",
"GWModulePrediction",
"ContrastiveLossType",
"contrastive_loss",
@@ -93,7 +88,6 @@
"BroadcastLossCoefs",
"GWLossesBase",
"GWLosses2Domains",
"GWLossesBayesian",
"RepeatedDataset",
"batch_cycles",
"batch_demi_cycles",
8 changes: 0 additions & 8 deletions shimmer/modules/__init__.py
@@ -1,15 +1,13 @@
 from shimmer.data.dataset import RepeatedDataset
 from shimmer.modules.contrastive_loss import (
     ContrastiveLoss,
-    ContrastiveLossBayesianType,
     ContrastiveLossType,
     contrastive_loss,
 )
 from shimmer.modules.domain import DomainModule, LossOutput
 from shimmer.modules.global_workspace import (
     GlobalWorkspace2Domains,
     GlobalWorkspaceBase,
-    GlobalWorkspaceBayesian,
     SchedulerArgs,
     batch_broadcasts,
     batch_cycles,
@@ -23,7 +21,6 @@
     GWEncoderLinear,
     GWModule,
     GWModuleBase,
-    GWModuleBayesian,
     GWModulePrediction,
     broadcast,
     broadcast_cycles,
@@ -34,7 +31,6 @@
     BroadcastLossCoefs,
     GWLosses2Domains,
     GWLossesBase,
-    GWLossesBayesian,
     LossCoefs,
 )
 from shimmer.modules.selection import (
@@ -55,7 +51,6 @@
"SchedulerArgs",
"GlobalWorkspaceBase",
"GlobalWorkspace2Domains",
"GlobalWorkspaceBayesian",
"pretrained_global_workspace",
"LossOutput",
"DomainModule",
@@ -64,17 +59,14 @@
"GWEncoderLinear",
"GWModuleBase",
"GWModule",
"GWModuleBayesian",
"GWModulePrediction",
"ContrastiveLossType",
"ContrastiveLossBayesianType",
"contrastive_loss",
"ContrastiveLoss",
"LossCoefs",
"BroadcastLossCoefs",
"GWLossesBase",
"GWLosses2Domains",
"GWLossesBayesian",
"RepeatedDataset",
"reparameterize",
"kl_divergence_loss",
10 changes: 0 additions & 10 deletions shimmer/modules/contrastive_loss.py
@@ -15,16 +15,6 @@
 A function taking the prediction and targets and returning a LossOutput.
 """
-
-ContrastiveLossBayesianType = Callable[
-    [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], LossOutput
-]
-"""
-Contrastive loss function type for GlobalWorkspaceBayesian.
-
-A function taking the prediction mean, prediction std, target mean and target std
-and returning a LossOutput.
-"""


 def info_nce(
     x: torch.Tensor,
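
The removed `ContrastiveLossBayesianType` took four tensors: prediction mean and std, and target mean and std. A hypothetical sketch of a callable matching that signature; the precision weighting below is illustrative, not the library's removed implementation:

```python
import torch
import torch.nn.functional as F

from shimmer import LossOutput


def bayesian_contrastive_loss(
    pred_mean: torch.Tensor,
    pred_std: torch.Tensor,
    target_mean: torch.Tensor,
    target_std: torch.Tensor,
) -> LossOutput:
    # Down-weight pairs whose combined predicted variance is high.
    precision = 1.0 / (pred_std.pow(2) + target_std.pow(2) + 1e-8)
    weights = precision.mean(dim=-1)
    weights = weights / weights.sum()
    per_pair = 1 - F.cosine_similarity(pred_mean, target_mean, dim=-1)
    return LossOutput(loss=(weights * per_pair).sum())
```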
104 changes: 0 additions & 104 deletions shimmer/modules/global_workspace.py
@@ -15,7 +15,6 @@
 from shimmer.modules.gw_module import (
     GWModule,
     GWModuleBase,
-    GWModuleBayesian,
     GWModulePrediction,
     broadcast_cycles,
     cycle,
@@ -26,11 +25,9 @@
     GWLosses,
     GWLosses2Domains,
     GWLossesBase,
-    GWLossesBayesian,
     LossCoefs,
 )
 from shimmer.modules.selection import (
-    FixedSharedSelection,
     RandomSelection,
     SelectionBase,
     SingleDomainSelection,
@@ -793,107 +790,6 @@ def __init__(
         )
-
-
-class GlobalWorkspaceBayesian(
-    GlobalWorkspaceBase[GWModuleBayesian, FixedSharedSelection, GWLossesBayesian]
-):
-    """
-    A simple GlobalWorkspaceBase for up to 2 domains, with Bayesian uncertainty
-    prediction.
-
-    This is used to simplify Global Workspace instantiation and only overrides
-    the `__init__` method.
-    """
-
-    def __init__(
-        self,
-        domain_mods: Mapping[str, DomainModule],
-        gw_encoders: Mapping[str, Module],
-        gw_decoders: Mapping[str, Module],
-        workspace_dim: int,
-        loss_coefs: BroadcastLossCoefs,
-        sensitivity_selection: float = 1,
-        sensitivity_precision: float = 1,
-        optim_lr: float = 1e-3,
-        optim_weight_decay: float = 0.0,
-        scheduler_args: SchedulerArgs | None = None,
-        learn_logit_scale: bool = False,
-        use_normalized_constrastive: bool = True,
-        contrastive_loss: ContrastiveLossType | None = None,
-        precision_softmax_temp: float = 0.01,
-        scheduler: LRScheduler
-        | None
-        | OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT,
-    ) -> None:
-        """
-        Initializes a Global Workspace.
-
-        Args:
-            domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
-                connected to the GW. Keys are domain names, values are the
-                `DomainModule`.
-            gw_encoders (`Mapping[str, torch.nn.Module]`): mapping from each domain
-                name to a `torch.nn.Module` whose role is to encode a unimodal
-                latent representation into a GW representation (pre-fusion).
-            gw_decoders (`Mapping[str, torch.nn.Module]`): mapping from each domain
-                name to a `torch.nn.Module` whose role is to decode a GW
-                representation into a unimodal latent representation.
-            workspace_dim (`int`): dimension of the GW.
-            loss_coefs (`BroadcastLossCoefs`): loss coefficients
-            sensitivity_selection (`float`): sensitivity coef $c'_1$
-            sensitivity_precision (`float`): sensitivity coef $c'_2$
-            optim_lr (`float`): learning rate
-            optim_weight_decay (`float`): weight decay
-            scheduler_args (`SchedulerArgs | None`): optimization scheduler's
-                arguments
-            learn_logit_scale (`bool`): whether to learn the logit scale of the
-                contrastive loss when using the default contrastive loss.
-            use_normalized_constrastive (`bool`): whether to normalize the
-                contrastive loss by the precision coefs
-            contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
-                function used for alignment. `learn_logit_scale` will not affect
-                custom contrastive losses.
-            precision_softmax_temp (`float`): temperature to use in the softmax
-                of the precision
-            scheduler: the scheduler to use for training. If None is explicitly
-                given, no scheduler will be used. Defaults to OneCycleScheduler.
-        """
-        domain_mods = freeze_domain_modules(domain_mods)
-
-        gw_mod = GWModuleBayesian(
-            domain_mods,
-            workspace_dim,
-            gw_encoders,
-            gw_decoders,
-            sensitivity_selection,
-            sensitivity_precision,
-            precision_softmax_temp,
-        )
-
-        selection_mod = FixedSharedSelection()
-
-        contrastive_loss = ContrastiveLoss(
-            torch.tensor([1]).log(), "mean", learn_logit_scale
-        )
-
-        loss_mod = GWLossesBayesian(
-            gw_mod,
-            selection_mod,
-            domain_mods,
-            loss_coefs,
-            contrastive_loss,
-            use_normalized_constrastive,
-        )
-
-        super().__init__(
-            gw_mod,
-            selection_mod,
-            loss_mod,
-            optim_lr,
-            optim_weight_decay,
-            scheduler_args,
-            scheduler,
-        )


 def pretrained_global_workspace(
     checkpoint_path: str | Path,
     domain_mods: Mapping[str, DomainModule],
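
For anyone migrating off the removed API: a sketch of how `GlobalWorkspaceBayesian` was typically constructed before this commit, based only on the signature above. The domain-module classes, layer sizes, and loss-coefficient keys are hypothetical placeholders, not values from the repository:

```python
from torch import nn

workspace_dim = 12

# VisualDomain and TextDomain stand in for concrete DomainModule subclasses.
domain_mods = {"v": VisualDomain(), "t": TextDomain()}

gw = GlobalWorkspaceBayesian(
    domain_mods=domain_mods,
    gw_encoders={"v": nn.Linear(64, workspace_dim), "t": nn.Linear(32, workspace_dim)},
    gw_decoders={"v": nn.Linear(workspace_dim, 64), "t": nn.Linear(workspace_dim, 32)},
    workspace_dim=workspace_dim,
    loss_coefs={"contrastives": 1.0, "fused": 1.0},  # assumed BroadcastLossCoefs keys
)
```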
(Diffs for the remaining 3 changed files were not loaded.)
