Initial release
Fix missing individual metrics for translation loss.
Fix wrong module used to compute the cycle losses. Don't do cycle with the same domain as target and source.
Add an `on_before_gw_encode` callback and individual `compute_losses` for each loss type. Fix bugs.
- Breaking change: remove `DeterministicGlobalWorkspace` and `VariationalGlobalWorkspace` in favor of the functions `global_workspace` and `variational_global_workspace`.
- Allow setting custom GW encoders and decoders.
- Breaking change: remove `self.input_dim`, `self.encoder_hidden_dim`, `self.encoder_n_layers`, `self.decoder_hidden_dim`, and `self.decoder_n_layers` in `GWModule`s.

Fix bugs related to imports and `default_decoders`.
- Revert to using classes for GWs (it's easier when loading from checkpoints).
- `GlobalWorkspace` is renamed to `GlobalWorkspaceBase` and `GlobalWorkspace` now refers to `DeterministicGlobalWorkspace`.
- Use ABC for abstract methods.
- Replace `DomainDescription` with `GWInterface`.
- Add a `contrastive_fn` attribute in `DeterministicGWLosses` to compute the contrastive loss; it can then be customized (see the sketch after this list).
- Rename every abstract class to the `ClassNameBase` pattern. Rename every "Deterministic" class to remove "Deterministic".
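For illustration, a minimal sketch of plugging in a custom contrastive loss, assuming `contrastive_fn` is a plain callable taking two batches of latent vectors and returning a loss tensor (the exact expected signature and return type may differ; check the API docs):

```python
import torch
import torch.nn.functional as F


def info_nce(x: torch.Tensor, y: torch.Tensor, temperature: float = 0.1) -> torch.Tensor:
    """Symmetric InfoNCE between two batches of latent vectors (illustrative only)."""
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    logits = x @ y.t() / temperature
    targets = torch.arange(x.size(0), device=x.device)
    return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets))


# Hypothetical usage: override the attribute on the losses module.
# losses = DeterministicGWLosses(...)
# losses.contrastive_fn = info_nce
```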
- Remove all config-related functions. This is not the role of this repo.
- Replace loss coef buffers with a `LossCoef` TypedDict (a sketch of the new pattern follows this list).
- Add `RepeatedDataset` to shimmer.
- Add docs in `docs/`, API documentation at https://ruflab.github.io/shimmer/, and some code examples.
- Replace Black, isort, and flake8 with Ruff (see #8).
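As a sketch of this change, loss coefficients become a plain TypedDict instead of registered buffers. The key names below are assumptions made for illustration, not the library's exact schema:

```python
from typing import TypedDict


class LossCoef(TypedDict, total=False):
    # Illustrative keys only; the real LossCoef may use different names.
    demi_cycles: float
    cycles: float
    translations: float
    contrastives: float


loss_coefs: LossCoef = {
    "demi_cycles": 1.0,
    "cycles": 1.0,
    "translations": 1.0,
    "contrastives": 0.01,
}
```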
- Remove `GWInterface`s entirely and favor giving encoders and decoders directly to the `GWModule`. See the updated example `examples/main_example/train_gw.py` for what changes to make (see #9).
- Remove `GWModuleBase.translate` and `GWModuleBase.cycle`. Translation and cycles can now be done with the utils functions `translation` and `cycle`.
- Remove `GlobalWorkspaceBase.batch_demi_cycles`, `GlobalWorkspaceBase.batch_cycles`, and `GlobalWorkspaceBase.batch_translations`. This can be done with utils functions of the same name.
- Rename `GWModuleBase.fusion_mechanism` to `GWModuleBase.fuse`, `GWModuleBase.encode` to `GWModuleBase.encode_and_fuse`, and `GWModuleBase.encode_pre_fusion` to `GWModuleBase.encode`. Same for the associated methods in `GlobalWorkspaceBase`. A hedged before/after sketch follows this list.
- Remove `on_before_gw_encode_{loss}` callbacks to allow sharing computation between loss functions.
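A hedged before/after sketch of this migration. Only the function and method names come from the entries above; the import location, argument names, and keyword arguments are assumptions, and `examples/main_example/train_gw.py` remains the authoritative reference:

```python
from typing import Any

# Assumed import location; the exact module path may differ.
from shimmer import cycle, translation


def migration_sketch(gw_module: Any, latents: Any) -> None:
    """Illustrates the renames only; argument lists are assumptions."""
    # Before (removed):
    #   gw_module.translate(latents, to="attr")
    #   gw_module.cycle(latents, through="attr")
    #   gw_module.fusion_mechanism(gw_module.encode_pre_fusion(latents))
    # After:
    z = translation(gw_module, latents, to="attr")  # utils function replacing GWModuleBase.translate
    z = cycle(gw_module, latents, through="attr")   # utils function replacing GWModuleBase.cycle
    pre_fusion = gw_module.encode(latents)          # was GWModuleBase.encode_pre_fusion
    state = gw_module.fuse(pre_fusion)              # was GWModuleBase.fusion_mechanism
    state = gw_module.encode_and_fuse(latents)      # was GWModuleBase.encode
```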
- Remove many `_with_uncertainty` functions. The `GWModuleWithUncertainty` now behaves like the other `GWModule`s.
- Rename all "with_uncertainty" methods to "bayesian". Note: `BayesianGlobalWorkspace`s are still a work in progress.
- Add selection mechanisms (inheriting from `SelectionBase`, see docs) to fuse representations according to different mechanisms (e.g. attention). `GlobalWorkspace` (and the associated `GWModule`, `GWLosses`, ...) now uses the `RandomSelection` mechanism. For the old behavior, use `GlobalWorkspace2Domains`. A hedged sketch of a custom selection mechanism follows.
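For illustration, a hedged sketch of what a custom selection mechanism could look like, assuming subclasses of `SelectionBase` implement a `forward` that maps a group of domain latents to per-domain fusion weights. The method name, arguments, and return type here are assumptions for this sketch; see the docs for the actual interface:

```python
import torch
import torch.nn as nn


class UniformSelection(nn.Module):
    """Hypothetical selection mechanism giving every present domain equal weight.

    In practice this would inherit from shimmer's SelectionBase; the signature
    below (dict of domain latents in, dict of per-domain weights out) is an
    assumption made for this sketch.
    """

    def forward(self, domains: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        n_domains = len(domains)
        return {
            name: torch.full((latent.size(0),), 1.0 / n_domains, device=latent.device)
            for name, latent in domains.items()
        }
```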
- Allow some domain modules to be trained end-to-end with the global workspace. This brings some breaking changes (a hedged sketch follows):
  - `DomainModule.compute_loss` and `DomainModule.compute_*_loss` now require a 3rd parameter `raw_target: Any` that stores the raw domain input (before being encoded). This is useful for unimodal losses that require the actual inputs to compute the loss.
  - `GWLossesBase.step` requires a new first argument `raw_data: RawDomainGroupsT` to pass the `raw_targets` to the domain modules.
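A hedged sketch of what the updated signature could look like for a custom domain module. Only the `raw_target: Any` third parameter comes from the entry above; the class, the other parameter names, the decoder, and the return type are assumptions for illustration:

```python
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F


class MyDomainModule(nn.Module):  # would subclass shimmer's DomainModule in practice
    def __init__(self, latent_dim: int, raw_dim: int) -> None:
        super().__init__()
        self.decoder = nn.Linear(latent_dim, raw_dim)

    def compute_loss(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
        raw_target: Any,  # new 3rd parameter: the raw domain input, before encoding
    ) -> torch.Tensor:
        # A unimodal loss that needs the actual inputs: compare a reconstruction
        # of the raw data against raw_target, on top of a latent-space loss.
        reconstruction = self.decoder(pred)
        return F.mse_loss(reconstruction, raw_target) + F.mse_loss(pred, target)
```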