Skip to content

Commit

Permalink
Docstring WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
bdvllrs committed Mar 1, 2024
1 parent c0e55a8 commit 85ba6fc
Show file tree
Hide file tree
Showing 10 changed files with 732 additions and 201 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,4 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
docs/
2 changes: 0 additions & 2 deletions docs/.gitignore

This file was deleted.

2 changes: 2 additions & 0 deletions shimmer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from shimmer.modules.global_workspace import (
GlobalWorkspace,
GlobalWorkspaceBase,
GWPredictions,
SchedulerArgs,
VariationalGlobalWorkspace,
pretrained_global_workspace,
Expand Down Expand Up @@ -70,6 +71,7 @@
"GlobalWorkspaceBase",
"VariationalGlobalWorkspace",
"SchedulerArgs",
"GWPredictions",
"pretrained_global_workspace",
"RepeatedDataset",
]
6 changes: 3 additions & 3 deletions shimmer/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class RepeatedDataset(Dataset):
"""
Dataset that cycles through its items to have a size of at least min_size.
If drop_last is True, the size will be exactly min_size. If drop_last is False,
the min_size <= size < min_size + len(dataset).
the min_size <= size < min_size + len(dataset).
"""

def __init__(
Expand All @@ -25,7 +25,7 @@ def __init__(
"""
Args:
dataset (SizedDataset): dataset to repeat. The dataset should have a size
(__len__ defined).
(where `__len__` is defined).
min_size (int): minimum size of the final dataset
drop_last (bool): whether to remove overflow when repeating the
dataset.
Expand All @@ -43,7 +43,7 @@ def __init__(
def __len__(self) -> int:
"""
Size of the dataset. Will be min_size if drop_last is True.
Otherwise, min_size <= size < min_size + len(dataset).
Otherwise, min_size <= size < min_size + len(dataset).
"""
return self.total_size

Expand Down
2 changes: 2 additions & 0 deletions shimmer/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from shimmer.modules.global_workspace import (
GlobalWorkspace,
GlobalWorkspaceBase,
GWPredictions,
SchedulerArgs,
VariationalGlobalWorkspace,
pretrained_global_workspace,
Expand Down Expand Up @@ -65,5 +66,6 @@
"GlobalWorkspaceBase",
"VariationalGlobalWorkspace",
"SchedulerArgs",
"GWPredictions",
"pretrained_global_workspace",
]
62 changes: 34 additions & 28 deletions shimmer/modules/contrastive_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@

ContrastiveLossType = Callable[[torch.Tensor, torch.Tensor], LossOutput]
"""Contrastive loss function type.
A function taking the prediction and targets and returning a LossOutput.
"""

VarContrastiveLossType = Callable[
[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], LossOutput
]
"""Contrastive loss function type for variational GlobalWorkspace.
A function taking the prediction mean, prediction std, target mean and target std and
returning a LossOutput.
"""
Expand All @@ -31,10 +33,10 @@ def info_nce(
"""InfoNCE loss
Args:
x: prediction
y: target
logit_scale: logit scale
reduction: reduction to apply
x (`torch.Tensor`): prediction
y (`torch.Tensor`): target
logit_scale (`torch.Tensor`): logit scale
reduction (`Literal["mean", "sum", "none"]`): reduction to apply
Returns: the InfoNCE loss
"""
Expand All @@ -54,10 +56,10 @@ def contrastive_loss(
"""CLIP-like contrastive loss
Args:
x: prediction
y: target
logit_scale: logit scale
reduction: reduction to apply
x (`torch.Tensor`): prediction
y (`torch.Tensor`): target
logit_scale (`torch.Tensor`): logit scale
reduction (`Literal["mean", "sum", "none"]`): reduction to apply
Returns: the contrastive loss
"""
Expand All @@ -82,12 +84,12 @@ def contrastive_loss_with_uncertainty(
This is used in Variational Global Workspaces.
Args:
x: prediction
x_log_uncertainty: logvar of the prediction
y: target
y_log_uncertainty: logvar of the target
logit_scale: logit scale
reduction: reduction to apply
x (`torch.Tensor`): prediction
x_log_uncertainty (`torch.Tensor`): logvar of the prediction
y (`torch.Tensor`): target
y_log_uncertainty (`torch.Tensor`): logvar of the target
logit_scale (`torch.Tensor`): logit scale
reduction (`Literal["mean", "sum", "none"]`): reduction to apply
Returns: the contrastive loss with uncertainty.
"""
Expand All @@ -104,7 +106,7 @@ def contrastive_loss_with_uncertainty(


class ContrastiveLoss(torch.nn.Module):
"""CLIP-like ContrastiveLoss torch module"""
"""CLIP-like ContrastiveLoss torch module."""

def __init__(
self,
Expand All @@ -115,10 +117,11 @@ def __init__(
"""Initializes a contrastive loss.
Args:
logit_scale: logit_scale tensor.
reduction: reduction to apply to the loss. Defaults to "mean"
learn_logit_scale: whether to learn the logit_scale parameter. Defaults to
False.
logit_scale (`torch.Tensor`): logit_scale tensor.
reduction (`Literal["mean", "sum", "none"]`): reduction to apply to the
loss. Defaults to `"mean"`.
learn_logit_scale (`torch.Tensor`): whether to learn the `logit_scale`
parameter. Defaults to `False`.
"""
super().__init__()

Expand All @@ -130,10 +133,11 @@ def __init__(
self.reduction: Literal["mean", "sum", "none"] = reduction

def forward(self, x: torch.Tensor, y: torch.Tensor) -> LossOutput:
"""Computes the loss
"""Computes the loss.
Args:
x: prediction
y: target
x (`torch.Tensor`): prediction
y (`torch.Tensor`): target
Returns:
LossOutput of the loss. Contains a `logit_scale` metric.
Expand All @@ -146,6 +150,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> LossOutput:

class ContrastiveLossWithUncertainty(torch.nn.Module):
"""CLIP-like contrastive loss with uncertainty module.
This is used in Variational Global Workspaces.
"""

Expand All @@ -159,10 +164,11 @@ def __init__(
ContrastiveLoss used for VariationalGlobalWorkspace
Args:
logit_scale: logit_scale tensor.
reduction: reduction to apply to the loss. Defaults to "mean"
learn_logit_scale: whether to learn the logit_scale parameter. Defaults to
False.
logit_scale (`torch.Tensor`): logit_scale tensor.
reduction (`Literal["mean", "sum", "none"]`): reduction to apply to
the loss. Defaults to `"mean"`.
learn_logit_scale (`bool`): whether to learn the logit_scale parameter.
Defaults to `False`.
"""
super().__init__()

Expand All @@ -188,8 +194,8 @@ def forward(
y_log_uncertainty: target logvar
Returns:
LossOutput of the loss. Contains a `logit_scale` metric and a
`no_uncertainty` metric with the classic contrastive loss computed without
LossOutput of the loss. Contains a `"logit_scale"` metric and a
`"no_uncertainty"` metric with the classic contrastive loss computed without
the logvar information.
"""
return LossOutput(
Expand Down
Loading

0 comments on commit 85ba6fc

Please sign in to comment.