Update formatting
bdvllrs committed Mar 8, 2024
1 parent c0f5170 commit 31de4cc
Showing 1 changed file with 17 additions and 88 deletions.
shimmer/modules/global_workspace.py: 17 additions & 88 deletions
@@ -300,10 +300,7 @@ def encode_domain(self, domain: Any, name: str) -> torch.Tensor:
         """
         return self.domain_mods[name].encode(domain)
 
-    def encode_domains(
-        self,
-        batch: RawDomainGroupsT,
-    ) -> LatentsDomainGroupsDT:
+    def encode_domains(self, batch: RawDomainGroupsT) -> LatentsDomainGroupsDT:
         """Encode all domains in the batch.
 
         Args:
@@ -340,10 +337,7 @@ def decode_domain(self, domain: torch.Tensor, name: str) -> Any:
         """
         return self.domain_mods[name].decode(domain)
 
-    def decode_domains(
-        self,
-        latents_domain: LatentsDomainGroupsT,
-    ) -> RawDomainGroupsDT:
+    def decode_domains(self, latents_domain: LatentsDomainGroupsT) -> RawDomainGroupsDT:
         """Decodes all domains in the batch.
 
         Args:
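For orientation: the signatures compacted in these two hunks traverse nested mappings, where a RawDomainGroupsT maps a frozenset of domain names to that group's raw data, and encode_domains/decode_domains apply each domain module's encode/decode across every group. A minimal self-contained sketch of that traversal (the "v"/"t" domain names and the lambda stand-ins are hypothetical, not taken from this file):

import torch

# Hypothetical stand-ins for domain modules' encode/decode.
encoders = {"v": lambda x: x * 2, "t": lambda x: x + 1}
decoders = {"v": lambda z: z / 2, "t": lambda z: z - 1}

batch = {frozenset({"v", "t"}): {"v": torch.ones(4, 3), "t": torch.zeros(4, 3)}}

# Mirrors encode_domains: same group keys, values become latents.
latents = {
    group: {name: encoders[name](x) for name, x in domains.items()}
    for group, domains in batch.items()
}
# Mirrors decode_domains: round-trips latents back to raw domain data.
raw = {
    group: {name: decoders[name](z) for name, z in domains.items()}
    for group, domains in latents.items()
}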
@@ -362,10 +356,7 @@ def decode_domains(
             for domains, latents in latents_domain.items()
         }
 
-    def _get_batch_size(
-        self,
-        domain_latents: LatentsDomainGroupsT,
-    ) -> int:
+    def _get_batch_size(self, domain_latents: LatentsDomainGroupsT) -> int:
         """Get the batch size of the batch.
 
         Args:
@@ -379,11 +370,7 @@ def _get_batch_size(
                 return tensor.size(0)
         raise ValueError("Empty batch.")
 
-    def generic_step(
-        self,
-        batch: RawDomainGroupsT,
-        mode: ModelModeT,
-    ) -> torch.Tensor:
+    def generic_step(self, batch: RawDomainGroupsT, mode: ModelModeT) -> torch.Tensor:
         """The generic step used in `training_step`, `validation_step` and
         `test_step`.
 
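The _get_batch_size hunk shows only the tail of the method body (the return of tensor.size(0) and the empty-batch raise); the surrounding scan over latent groups is not visible here. A standalone reconstruction, hedged as such:

import torch

def get_batch_size(domain_latents) -> int:
    # Reconstructed loop: report the leading dimension of the first latent
    # tensor found in any domain group; an empty batch raises, as in the diff.
    for latents in domain_latents.values():
        for tensor in latents.values():
            return tensor.size(0)
    raise ValueError("Empty batch.")

assert get_batch_size({frozenset({"v"}): {"v": torch.zeros(32, 16)}}) == 32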
@@ -441,9 +428,7 @@ def training_step(  # type: ignore
         return self.generic_step(batch, mode="train")
 
     def predict_step(  # type: ignore
-        self,
-        data: Mapping[str, Any],
-        batch_idx: int,
+        self, data: Mapping[str, Any], batch_idx: int
     ) -> GWPredictions:
         """Predict step used by lightning"""
 
@@ -545,30 +530,14 @@ def __init__(
         """
         domain_mods = freeze_domain_modules(domain_mods)
 
-        gw_mod = GWModule(
-            domain_mods,
-            workspace_dim,
-            gw_encoders,
-            gw_decoders,
-        )
+        gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders)
         if contrastive_loss is None:
             contrastive_loss = ContrastiveLoss(
                 torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale
             )
-        loss_mod = GWLosses(
-            gw_mod,
-            domain_mods,
-            loss_coefs,
-            contrastive_loss,
-        )
+        loss_mod = GWLosses(gw_mod, domain_mods, loss_coefs, contrastive_loss)
 
-        super().__init__(
-            gw_mod,
-            loss_mod,
-            optim_lr,
-            optim_weight_decay,
-            scheduler_args,
-        )
+        super().__init__(gw_mod, loss_mod, optim_lr, optim_weight_decay, scheduler_args)
 
 
 class GlobalWorkspaceWithUncertainty(GlobalWorkspaceBase):
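The default ContrastiveLoss built in this __init__ seeds its learnable logit scale with torch.tensor([1 / 0.07]).log(), the CLIP-style temperature of 0.07. A quick check of what that initializer evaluates to:

import torch

# The parameter stores log(1/temperature); exp() recovers the similarity
# scaling of 1/0.07, roughly 14.29.
logit_scale = torch.tensor([1 / 0.07]).log()
print(logit_scale)        # tensor([2.6593])
print(logit_scale.exp())  # tensor([14.2857])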
@@ -624,10 +593,7 @@ def __init__(
         domain_mods = freeze_domain_modules(domain_mods)
 
         gw_mod = GWModuleWithUncertainty(
-            domain_mods,
-            workspace_dim,
-            gw_encoders,
-            gw_decoders,
+            domain_mods, workspace_dim, gw_encoders, gw_decoders
         )
 
         if use_var_contrastive_loss:
@@ -636,30 +602,18 @@
                 torch.tensor([1]).log(), "mean", learn_logit_scale
             )
             loss_mod = GWLossesWithUncertainty(
-                gw_mod,
-                domain_mods,
-                loss_coefs,
-                var_contrastive_fn=var_contrastive_loss,
+                gw_mod, domain_mods, loss_coefs, var_contrastive_fn=var_contrastive_loss
             )
         else:
             if contrastive_loss is None:
                 contrastive_loss = ContrastiveLoss(
                     torch.tensor([1]).log(), "mean", learn_logit_scale
                 )
             loss_mod = GWLossesWithUncertainty(
-                gw_mod,
-                domain_mods,
-                loss_coefs,
-                contrastive_fn=contrastive_loss,
+                gw_mod, domain_mods, loss_coefs, contrastive_fn=contrastive_loss
             )
 
-        super().__init__(
-            gw_mod,
-            loss_mod,
-            optim_lr,
-            optim_weight_decay,
-            scheduler_args,
-        )
+        super().__init__(gw_mod, loss_mod, optim_lr, optim_weight_decay, scheduler_args)
 
 
 class GlobalWorkspaceFusion(GlobalWorkspaceBase):
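Note that the uncertainty variant seeds its logit scale differently: torch.tensor([1]).log() is zero, i.e. an initial similarity scale of 1 rather than the 1/0.07 used above. (PyTorch promotes the integer tensor to float for .log().)

import torch

print(torch.tensor([1]).log())        # tensor([0.])
print(torch.tensor([1]).log().exp())  # tensor([1.])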
@@ -704,30 +658,15 @@ def __init__(
             contrastive losses.
         """
         domain_mods = freeze_domain_modules(domain_mods)
-        gw_mod = GWModuleFusion(
-            domain_mods,
-            workspace_dim,
-            gw_encoders,
-            gw_decoders,
-        )
+        gw_mod = GWModuleFusion(domain_mods, workspace_dim, gw_encoders, gw_decoders)
 
         if contrastive_loss is None:
             contrastive_loss = ContrastiveLoss(
                 torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale
             )
-        loss_mod = GWLossesFusion(
-            gw_mod,
-            domain_mods,
-            contrastive_loss,
-        )
+        loss_mod = GWLossesFusion(gw_mod, domain_mods, contrastive_loss)
 
-        super().__init__(
-            gw_mod,
-            loss_mod,
-            optim_lr,
-            optim_weight_decay,
-            scheduler_args,
-        )
+        super().__init__(gw_mod, loss_mod, optim_lr, optim_weight_decay, scheduler_args)
 
 
 def pretrained_global_workspace(
@@ -769,18 +708,8 @@ def pretrained_global_workspace(
             `TypeError`: if loaded type is not `GlobalWorkspace`.
     """
     domain_mods = freeze_domain_modules(domain_mods)
-    gw_mod = GWModule(
-        domain_mods,
-        workspace_dim,
-        gw_encoders,
-        gw_decoders,
-    )
-    loss_mod = GWLosses(
-        gw_mod,
-        domain_mods,
-        loss_coefs,
-        contrastive_fn,
-    )
+    gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders)
+    loss_mod = GWLosses(gw_mod, domain_mods, loss_coefs, contrastive_fn)
 
     gw = GlobalWorkspace.load_from_checkpoint(
         checkpoint_path,
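The docstring of pretrained_global_workspace promises a TypeError when the checkpoint does not load as a GlobalWorkspace. A hedged sketch of that guard pattern around Lightning's load_from_checkpoint (the load_checked helper and its kwargs are illustrative, not shimmer API):

from typing import TypeVar

import lightning.pytorch as pl

T = TypeVar("T", bound=pl.LightningModule)

def load_checked(cls: type[T], checkpoint_path: str, **kwargs) -> T:
    # load_from_checkpoint is typed to return cls, but a checkpoint written by
    # another class can still slip through, hence the explicit isinstance check.
    model = cls.load_from_checkpoint(checkpoint_path, **kwargs)
    if not isinstance(model, cls):
        raise TypeError(f"model should be of type {cls.__name__}")
    return model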
