Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove Bayesian models #164

Merged
merged 1 commit into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/q_and_a.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ To get insipiration, you can look at the source code of
## How can I change the loss function?
If you are using pre-made GW architecture
([`GlobalWorkspace`](https://ruflab.github.io/shimmer/latest/shimmer/modules/global_workspace.html#GlobalWorkspace),
[`GlobalWorkspaceBayesian`](https://ruflab.github.io/shimmer/latest/shimmer/modules/global_workspace.html#GlobalWorkspaceBayesian),
[`GlobalWorkspaceFusion`](https://ruflab.github.io/shimmer/latest/shimmer/modules/global_workspace.html#GlobalWorkspaceFusion)) and want to update the loss
used for demi-cycles, cycles, translations or broadcast, you can do so directly from
your definition of the
Expand Down
6 changes: 0 additions & 6 deletions shimmer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from shimmer.modules.global_workspace import (
GlobalWorkspace2Domains,
GlobalWorkspaceBase,
GlobalWorkspaceBayesian,
SchedulerArgs,
batch_broadcasts,
batch_cycles,
Expand All @@ -28,7 +27,6 @@
GWEncoderLinear,
GWModule,
GWModuleBase,
GWModuleBayesian,
GWModulePrediction,
broadcast,
broadcast_cycles,
Expand All @@ -39,7 +37,6 @@
BroadcastLossCoefs,
GWLosses2Domains,
GWLossesBase,
GWLossesBayesian,
LossCoefs,
)
from shimmer.modules.selection import (
Expand Down Expand Up @@ -75,7 +72,6 @@
"SchedulerArgs",
"GlobalWorkspaceBase",
"GlobalWorkspace2Domains",
"GlobalWorkspaceBayesian",
"pretrained_global_workspace",
"LossOutput",
"DomainModule",
Expand All @@ -84,7 +80,6 @@
"GWEncoderLinear",
"GWModuleBase",
"GWModule",
"GWModuleBayesian",
"GWModulePrediction",
"ContrastiveLossType",
"contrastive_loss",
Expand All @@ -93,7 +88,6 @@
"BroadcastLossCoefs",
"GWLossesBase",
"GWLosses2Domains",
"GWLossesBayesian",
"RepeatedDataset",
"batch_cycles",
"batch_demi_cycles",
Expand Down
8 changes: 0 additions & 8 deletions shimmer/modules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
from shimmer.data.dataset import RepeatedDataset
from shimmer.modules.contrastive_loss import (
ContrastiveLoss,
ContrastiveLossBayesianType,
ContrastiveLossType,
contrastive_loss,
)
from shimmer.modules.domain import DomainModule, LossOutput
from shimmer.modules.global_workspace import (
GlobalWorkspace2Domains,
GlobalWorkspaceBase,
GlobalWorkspaceBayesian,
SchedulerArgs,
batch_broadcasts,
batch_cycles,
Expand All @@ -23,7 +21,6 @@
GWEncoderLinear,
GWModule,
GWModuleBase,
GWModuleBayesian,
GWModulePrediction,
broadcast,
broadcast_cycles,
Expand All @@ -34,7 +31,6 @@
BroadcastLossCoefs,
GWLosses2Domains,
GWLossesBase,
GWLossesBayesian,
LossCoefs,
)
from shimmer.modules.selection import (
Expand All @@ -55,7 +51,6 @@
"SchedulerArgs",
"GlobalWorkspaceBase",
"GlobalWorkspace2Domains",
"GlobalWorkspaceBayesian",
"pretrained_global_workspace",
"LossOutput",
"DomainModule",
Expand All @@ -64,17 +59,14 @@
"GWEncoderLinear",
"GWModuleBase",
"GWModule",
"GWModuleBayesian",
"GWModulePrediction",
"ContrastiveLossType",
"ContrastiveLossBayesianType",
"contrastive_loss",
"ContrastiveLoss",
"LossCoefs",
"BroadcastLossCoefs",
"GWLossesBase",
"GWLosses2Domains",
"GWLossesBayesian",
"RepeatedDataset",
"reparameterize",
"kl_divergence_loss",
Expand Down
10 changes: 0 additions & 10 deletions shimmer/modules/contrastive_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,6 @@
A function taking the prediction and targets and returning a LossOutput.
"""

ContrastiveLossBayesianType = Callable[
[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], LossOutput
]
"""
Contrastive loss function type for GlobalWorkspaceBayesian.
A function taking the prediction mean, prediction std, target mean and target std and
returns a LossOutput.
"""


def info_nce(
x: torch.Tensor,
Expand Down
104 changes: 0 additions & 104 deletions shimmer/modules/global_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from shimmer.modules.gw_module import (
GWModule,
GWModuleBase,
GWModuleBayesian,
GWModulePrediction,
broadcast_cycles,
cycle,
Expand All @@ -26,11 +25,9 @@
GWLosses,
GWLosses2Domains,
GWLossesBase,
GWLossesBayesian,
LossCoefs,
)
from shimmer.modules.selection import (
FixedSharedSelection,
RandomSelection,
SelectionBase,
SingleDomainSelection,
Expand Down Expand Up @@ -793,107 +790,6 @@ def __init__(
)


class GlobalWorkspaceBayesian(
GlobalWorkspaceBase[GWModuleBayesian, FixedSharedSelection, GWLossesBayesian]
):
"""
A simple 2-domains max GlobalWorkspaceBase with a Bayesian base uncertainty
prediction.

This is used to simplify a Global Workspace instanciation and only overrides the
`__init__` method.
"""

def __init__(
self,
domain_mods: Mapping[str, DomainModule],
gw_encoders: Mapping[str, Module],
gw_decoders: Mapping[str, Module],
workspace_dim: int,
loss_coefs: BroadcastLossCoefs,
sensitivity_selection: float = 1,
sensitivity_precision: float = 1,
optim_lr: float = 1e-3,
optim_weight_decay: float = 0.0,
scheduler_args: SchedulerArgs | None = None,
learn_logit_scale: bool = False,
use_normalized_constrastive: bool = True,
contrastive_loss: ContrastiveLossType | None = None,
precision_softmax_temp: float = 0.01,
scheduler: LRScheduler
| None
| OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT,
) -> None:
"""
Initializes a Global Workspace

Args:
domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
connected to the GW. Keys are domain names, values are the
`DomainModule`.
gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
name to a `torch.nn.Module` class which role is to encode a
unimodal latent representations into a GW representation (pre fusion).
gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
name to a `torch.nn.Module` class which role is to decode a
GW representation into a unimodal latent representations.
workspace_dim (`int`): dimension of the GW.
loss_coefs (`LossCoefs`): loss coefficients
sensitivity_selection (`float`): sensivity coef $c'_1$
sensitivity_precision (`float`): sensitivity coef $c'_2$
optim_lr (`float`): learning rate
optim_weight_decay (`float`): weight decay
scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
learn_logit_scale (`bool`): whether to learn the contrastive learning
contrastive loss when using the default contrastive loss.
use_normalized_constrastive (`bool`): whether to use the normalized cont
loss by the precision coefs
contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
function used for alignment. `learn_logit_scale` will not affect custom
contrastive losses.
precision_softmax_temp (`float`): temperature to use in softmax of
precision
scheduler: The scheduler to use for traning. If None is explicitely given,
no scheduler will be used. Defaults to use OneCycleScheduler
"""
domain_mods = freeze_domain_modules(domain_mods)

gw_mod = GWModuleBayesian(
domain_mods,
workspace_dim,
gw_encoders,
gw_decoders,
sensitivity_selection,
sensitivity_precision,
precision_softmax_temp,
)

selection_mod = FixedSharedSelection()

contrastive_loss = ContrastiveLoss(
torch.tensor([1]).log(), "mean", learn_logit_scale
)

loss_mod = GWLossesBayesian(
gw_mod,
selection_mod,
domain_mods,
loss_coefs,
contrastive_loss,
use_normalized_constrastive,
)

super().__init__(
gw_mod,
selection_mod,
loss_mod,
optim_lr,
optim_weight_decay,
scheduler_args,
scheduler,
)


def pretrained_global_workspace(
checkpoint_path: str | Path,
domain_mods: Mapping[str, DomainModule],
Expand Down
Loading
Loading