Skip to content

Commit

Permalink
Update changelog
Browse files Browse the repository at this point in the history
  • Loading branch information
bdvllrs committed Jan 22, 2024
1 parent e799073 commit 08169d6
Showing 4 changed files with 26 additions and 22 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -30,3 +30,9 @@ refers to `DeterministicGlobalWorkspace`.

# 0.4.0
* Use ABC for abstract methods.
* Replace `DomainDescription` with `GWInterface`.
* Add `contrastive_fn` attribute in `DeterministicGWLosses` to compute the contrastive loss.
It can then be customized.
* Rename every abstract class with ClassNameBase. Rename every "Deterministic" classes
to remove "Deterministic".

15 changes: 7 additions & 8 deletions shimmer/__init__.py
Original file line number Diff line number Diff line change
@@ -6,13 +6,12 @@
SchedulerArgs,
VariationalGlobalWorkspace,
pretrained_global_workspace)
from shimmer.modules.gw_module import (DeterministicGWModule, GWDecoder,
GWEncoder, GWInterface, GWInterfaceBase,
GWModule, VariationalGWEncoder,
from shimmer.modules.gw_module import (GWDecoder, GWEncoder, GWInterface,
GWInterfaceBase, GWModule, GWModuleBase,
VariationalGWEncoder,
VariationalGWInterface,
VariationalGWModule)
from shimmer.modules.losses import (DeterministicGWLosses, GWLosses,
VariationalGWLosses)
from shimmer.modules.losses import GWLosses, GWLossesBase, VariationalGWLosses
from shimmer.version import __version__

__all__ = [
@@ -22,16 +21,16 @@
"ShimmerInfoConfig",
"DomainModule",
"GWInterfaceBase",
"DeterministicGWModule",
"GWModule",
"GWDecoder",
"GWEncoder",
"GWInterface",
"GWModule",
"GWModuleBase",
"VariationalGWEncoder",
"VariationalGWInterface",
"VariationalGWModule",
"DeterministicGWLosses",
"GWLosses",
"GWLossesBase",
"VariationalGWLosses",
"GlobalWorkspace",
"GlobalWorkspaceBase",
15 changes: 7 additions & 8 deletions shimmer/modules/__init__.py
Original file line number Diff line number Diff line change
@@ -4,27 +4,26 @@
SchedulerArgs,
VariationalGlobalWorkspace,
pretrained_global_workspace)
from shimmer.modules.gw_module import (DeterministicGWModule, GWDecoder,
GWEncoder, GWInterface, GWInterfaceBase,
GWModule, VariationalGWEncoder,
from shimmer.modules.gw_module import (GWDecoder, GWEncoder, GWInterface,
GWInterfaceBase, GWModule, GWModuleBase,
VariationalGWEncoder,
VariationalGWInterface,
VariationalGWModule)
from shimmer.modules.losses import (DeterministicGWLosses, GWLosses,
VariationalGWLosses)
from shimmer.modules.losses import GWLosses, GWLossesBase, VariationalGWLosses

__all__ = [
"DomainModule",
"GWInterfaceBase",
"DeterministicGWModule",
"GWModule",
"GWDecoder",
"GWEncoder",
"GWInterface",
"GWModule",
"GWModuleBase",
"VariationalGWEncoder",
"VariationalGWInterface",
"VariationalGWModule",
"DeterministicGWLosses",
"GWLosses",
"GWLossesBase",
"VariationalGWLosses",
"GlobalWorkspace",
"GlobalWorkspaceBase",
12 changes: 6 additions & 6 deletions shimmer/modules/losses.py
Original file line number Diff line number Diff line change
@@ -7,7 +7,7 @@

from shimmer.modules.dict_buffer import DictBuffer
from shimmer.modules.domain import DomainModule
from shimmer.modules.gw_module import (DeterministicGWModule, GWModule,
from shimmer.modules.gw_module import (GWModule, GWModuleBase,
VariationalGWModule)
from shimmer.modules.vae import kl_divergence_loss

@@ -134,7 +134,7 @@ def step(


def _demi_cycle_loss(
gw_mod: GWModule,
gw_mod: GWModuleBase,
domain_mods: dict[str, DomainModule],
latent_domains: LatentsT,
) -> dict[str, torch.Tensor]:
@@ -164,7 +164,7 @@ def _demi_cycle_loss(


def _cycle_loss(
gw_mod: GWModule,
gw_mod: GWModuleBase,
domain_mods: dict[str, DomainModule],
latent_domains: LatentsT,
) -> dict[str, torch.Tensor]:
@@ -205,7 +205,7 @@ def _cycle_loss(


def _translation_loss(
gw_mod: GWModule,
gw_mod: GWModuleBase,
domain_mods: dict[str, DomainModule],
latent_domains: LatentsT,
) -> dict[str, torch.Tensor]:
@@ -252,7 +252,7 @@ def _translation_loss(


def _contrastive_loss(
gw_mod: GWModule,
gw_mod: GWModuleBase,
latent_domains: LatentsT,
contrastive_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
) -> dict[str, torch.Tensor]:
@@ -330,7 +330,7 @@ def _contrastive_loss_with_uncertainty(
class GWLosses(GWLossesBase):
def __init__(
self,
gw_mod: DeterministicGWModule,
gw_mod: GWModule,
domain_mods: dict[str, DomainModule],
coef_buffers: DictBuffer,
):

0 comments on commit 08169d6

Please sign in to comment.