Initial release
Fix missing individual metrics for translation loss.
Fix wrong module used to compute the cycle losses. Don't compute cycles where the source and target domains are the same.
Add the `on_before_gw_encode` callback and individual `compute_losses` methods for each loss type. Fix bugs.
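The callback entry describes a common hook pattern: the model calls an overridable method at a fixed point (here, just before GW encoding), and each loss type gets its own compute method. A minimal pure-Python sketch of that pattern, assuming toy scalar "latents"; apart from the `on_before_gw_encode` name, every class and method here is hypothetical, and shimmer's real signatures (which work on tensors) may differ.

```python
from typing import Dict, List

# Hypothetical toy types and names for illustration; not shimmer's actual API.
Batch = Dict[str, List[float]]


class ToyWorkspace:
    def on_before_gw_encode(self, batch: Batch) -> Batch:
        # Hook that runs right before GW encoding; the default leaves the batch unchanged.
        return batch

    def encode(self, batch: Batch) -> Dict[str, float]:
        batch = self.on_before_gw_encode(batch)
        # Toy "encoder": one scalar latent per domain (the mean of its values).
        return {domain: sum(x) / len(x) for domain, x in batch.items()}

    def compute_translation_loss(self, latents: Dict[str, float]) -> float:
        # Toy loss over exactly two domains: squared difference of their latents.
        a, b = latents.values()
        return (a - b) ** 2

    def compute_losses(self, batch: Batch) -> Dict[str, float]:
        # One compute method per loss type; this collects their results.
        latents = self.encode(batch)
        return {"translation": self.compute_translation_loss(latents)}


class ClippedWorkspace(ToyWorkspace):
    def on_before_gw_encode(self, batch: Batch) -> Batch:
        # Override only the hook: clip every value into [-1, 1] before encoding.
        return {d: [max(-1.0, min(1.0, v)) for v in x] for d, x in batch.items()}
```

`ClippedWorkspace` changes the resulting loss without touching `encode` or `compute_losses`, which is the point of exposing a hook rather than asking users to override the whole encoding path.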
- Breaking change: remove `DeterministicGlobalWorkspace` and `VariationalGlobalWorkspace` in favor of the functions `global_workspace` and `variational_global_workspace`.
- Allow setting custom GW encoders and decoders.
- Breaking change: remove `self.input_dim`, `self.encoder_hidden_dim`, `self.encoder_n_layers`, `self.decoder_hidden_dim`, and `self.decoder_n_layers` in `GWModule`s.
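Dropping the dimension and layer-count attributes goes hand in hand with letting users supply their own GW encoders and decoders: the module receives ready-made components instead of building them from hyperparameters. A minimal sketch of that dependency-injection style, assuming plain callables stand in for the neural-network modules shimmer would actually use; `ToyGWModule` and its methods are hypothetical.

```python
from typing import Callable, Dict, List

Vector = List[float]


class ToyGWModule:
    """Hypothetical module that receives its per-domain GW encoders and decoders
    instead of building them from dimension/layer-count attributes."""

    def __init__(
        self,
        gw_encoders: Dict[str, Callable[[Vector], Vector]],
        gw_decoders: Dict[str, Callable[[Vector], Vector]],
    ) -> None:
        self.gw_encoders = gw_encoders
        self.gw_decoders = gw_decoders

    def encode(self, domain: str, x: Vector) -> Vector:
        # Delegate to whatever encoder the caller supplied for this domain.
        return self.gw_encoders[domain](x)

    def decode(self, domain: str, z: Vector) -> Vector:
        # Delegate to whatever decoder the caller supplied for this domain.
        return self.gw_decoders[domain](z)


# Any callables work: here, a scaling encoder and its inverse decoder for domain "v".
module = ToyGWModule(
    gw_encoders={"v": lambda x: [2 * xi for xi in x]},
    gw_decoders={"v": lambda z: [zi / 2 for zi in z]},
)
```

Because the module never inspects the internals of its encoders and decoders, their architecture is entirely the caller's choice.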
Fix bugs related to imports and `default_decoders`.
- Revert to using classes for GWs (this makes loading from checkpoints easier).
- `GlobalWorkspace` is renamed to `GlobalWorkspaceBase`, and `GlobalWorkspace` now refers to `DeterministicGlobalWorkspace`.
- Use ABC for abstract methods.
- Replace `DomainDescription` with `GWInterface`.
- Add a `contrastive_fn` attribute in `DeterministicGWLosses` to compute the contrastive loss, so that it can be customized.
- Rename every abstract class following the `ClassNameBase` pattern. Remove "Deterministic" from the names of all "Deterministic" classes.
- Remove all config-related functions; configuration is not the role of this repo.
- Replace loss coef buffers with a `LossCoef` TypedDict.
- Add `RepeatedDataset` to shimmer.
- Add docs in `docs/`, API documentation at https://bdvllrs.github.io/shimmer/, and some code examples.
- Replace Black, isort, and flake8 with Ruff (see #8).
- Remove `GWInterfaces` entirely and pass encoders and decoders directly to the `GWModule` instead. See the updated example `examples/main_example/train_gw.py` for the changes to make (see #9).
- Remove `GWModuleBase.translate` and `GWModuleBase.cycle`. Translations and cycles can now be done with the utils functions `translation` and `cycle`.
- Remove `GlobalWorkspaceBase.batch_demi_cycles`, `GlobalWorkspaceBase.batch_cycles`, and `GlobalWorkspaceBase.batch_translations`. This can be done with the utils functions of the same name.
- Rename `GWModuleBase.fusion_mechanism` to `GWModuleBase.fuse`, `GWModuleBase.encode` to `GWModuleBase.encode_and_fuse`, and `GWModuleBase.encode_pre_fusion` to `GWModuleBase.encode`. The same renaming applies to the associated methods in `GlobalWorkspaceBase`.
- Remove the `on_before_gw_encode_{loss}` callbacks to allow sharing computation between loss functions.
- Remove many `_with_uncertainty` functions. `GWModuleWithUncertainty` now behaves like the other `GWModule`s.
- Rename all "with_uncertainty" methods to "bayesian". Note that BayesianGlobalWorkspaces are still a work in progress.
- Add selection mechanisms (inheriting from `SelectionBase`, see docs) to fuse representations according to different mechanisms (e.g. attention). `GlobalWorkspace` (and the associated `GWModule`, `GWLosses`, ...) now uses the `RandomSelection` mechanism. For the old behavior, use `GlobalWorkspace2Domains`.
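The selection-mechanism entry says representations are fused according to per-domain selection weights, but it does not spell out how `RandomSelection` computes them. The sketch below assumes one plausible scheme: draw a uniform random score per domain, turn the scores into weights with a temperature-scaled softmax, and fuse by weighted sum. The functions `random_selection_weights` and `fuse` are hypothetical and operate on plain lists rather than tensors.

```python
import math
import random
from typing import Dict, List, Optional

Vector = List[float]


def random_selection_weights(
    domains: List[str], temperature: float = 1.0, rng: Optional[random.Random] = None
) -> Dict[str, float]:
    # Assumed scheme: one uniform random score per domain, then a
    # temperature-scaled softmax so the weights are positive and sum to 1.
    if rng is None:
        rng = random.Random()
    scores = {d: rng.random() / temperature for d in domains}
    top = max(scores.values())  # subtract the max score for numerical stability
    exps = {d: math.exp(s - top) for d, s in scores.items()}
    total = sum(exps.values())
    return {d: e / total for d, e in exps.items()}


def fuse(latents: Dict[str, Vector], weights: Dict[str, float]) -> Vector:
    # Weighted sum of per-domain representations that share one dimension.
    dim = len(next(iter(latents.values())))
    return [sum(weights[d] * z[i] for d, z in latents.items()) for i in range(dim)]


latents = {"v": [1.0, 0.0], "t": [0.0, 1.0]}
weights = random_selection_weights(list(latents), temperature=0.5, rng=random.Random(0))
fused = fuse(latents, weights)
```

Lowering `temperature` pushes the weights toward a single randomly chosen domain, while raising it pushes them toward a uniform average; other `SelectionBase` subclasses (e.g. attention) would replace only the weight computation and keep the same weighted-sum fusion.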
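The `LossCoef` entry swaps registered buffers for a plain typed dictionary of loss coefficients. A minimal sketch of the general idea follows, using a hypothetical `ToyLossCoef` whose key names are invented for illustration (shimmer's actual keys may differ): a `TypedDict` documents the expected keys for type checkers while remaining an ordinary `dict` at runtime, so combining named losses into one scalar is a simple weighted sum.

```python
from typing import Dict, TypedDict


class ToyLossCoef(TypedDict, total=False):
    # Hypothetical keys: one coefficient per named loss; omitted keys count as 0.
    demi_cycles: float
    cycles: float
    translations: float
    contrastives: float


def weighted_total(losses: Dict[str, float], coefs: ToyLossCoef) -> float:
    # A TypedDict is a plain dict at runtime, so .get() works as usual.
    return sum(coefs.get(name, 0.0) * value for name, value in losses.items())


coefs: ToyLossCoef = {"translations": 1.0, "contrastives": 0.1}
```

With `total=False`, any loss without a coefficient is simply ignored, which keeps disabling a loss as easy as leaving its key out.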