Skip to content

Commit

Permalink
Add broadcasts and cycled broadcast to GW forward (#104)
Browse files Browse the repository at this point in the history
* Add batch_broadcasts utils

* GlobalWorkspace forwards now return broadcasts

* Add forward for gw_mod

* fix: circular import

* Do broadcast cycles only if inverse is non empty
  • Loading branch information
bdvllrs authored Jun 28, 2024
1 parent 01f988c commit cd98e21
Show file tree
Hide file tree
Showing 6 changed files with 338 additions and 300 deletions.
1 change: 0 additions & 1 deletion docs/make.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
"shimmer.modules.contrastive_loss",
"shimmer.dataset",
"shimmer.modules.vae",
"shimmer.modules.utils",
"shimmer.utils",
"shimmer.cli.ckpt_migration",
]
Expand Down
22 changes: 13 additions & 9 deletions shimmer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@
GlobalWorkspace2Domains,
GlobalWorkspaceBase,
GlobalWorkspaceBayesian,
GWPredictions,
SchedulerArgs,
batch_broadcasts,
batch_cycles,
batch_demi_cycles,
batch_translations,
pretrained_global_workspace,
)
from shimmer.modules.gw_module import (
Expand All @@ -26,6 +29,11 @@
GWModule,
GWModuleBase,
GWModuleBayesian,
GWModulePrediction,
broadcast,
broadcast_cycles,
cycle,
translation,
)
from shimmer.modules.losses import (
BroadcastLossCoefs,
Expand All @@ -39,13 +47,6 @@
SelectionBase,
SingleDomainSelection,
)
from shimmer.modules.utils import (
batch_cycles,
batch_demi_cycles,
batch_translations,
cycle,
translation,
)
from shimmer.types import (
LatentsDomainGroupDT,
LatentsDomainGroupsDT,
Expand All @@ -72,7 +73,6 @@
"RawDomainGroupT",
"ModelModeT",
"SchedulerArgs",
"GWPredictions",
"GlobalWorkspaceBase",
"GlobalWorkspace2Domains",
"GlobalWorkspaceBayesian",
Expand All @@ -85,6 +85,7 @@
"GWModuleBase",
"GWModule",
"GWModuleBayesian",
"GWModulePrediction",
"ContrastiveLossType",
"contrastive_loss",
"ContrastiveLoss",
Expand All @@ -97,6 +98,9 @@
"batch_cycles",
"batch_demi_cycles",
"batch_translations",
"batch_broadcasts",
"broadcast",
"broadcast_cycles",
"cycle",
"translation",
"MIGRATION_DIR",
Expand Down
22 changes: 13 additions & 9 deletions shimmer/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@
GlobalWorkspace2Domains,
GlobalWorkspaceBase,
GlobalWorkspaceBayesian,
GWPredictions,
SchedulerArgs,
batch_broadcasts,
batch_cycles,
batch_demi_cycles,
batch_translations,
pretrained_global_workspace,
)
from shimmer.modules.gw_module import (
Expand All @@ -21,6 +24,11 @@
GWModule,
GWModuleBase,
GWModuleBayesian,
GWModulePrediction,
broadcast,
broadcast_cycles,
cycle,
translation,
)
from shimmer.modules.losses import (
BroadcastLossCoefs,
Expand All @@ -34,13 +42,6 @@
SelectionBase,
SingleDomainSelection,
)
from shimmer.modules.utils import (
batch_cycles,
batch_demi_cycles,
batch_translations,
cycle,
translation,
)
from shimmer.modules.vae import (
VAE,
VAEDecoder,
Expand All @@ -52,7 +53,6 @@

__all__ = [
"SchedulerArgs",
"GWPredictions",
"GlobalWorkspaceBase",
"GlobalWorkspace2Domains",
"GlobalWorkspaceBayesian",
Expand All @@ -65,6 +65,7 @@
"GWModuleBase",
"GWModule",
"GWModuleBayesian",
"GWModulePrediction",
"ContrastiveLossType",
"ContrastiveLossBayesianType",
"contrastive_loss",
Expand All @@ -84,6 +85,9 @@
"batch_cycles",
"batch_demi_cycles",
"batch_translations",
"batch_broadcasts",
"broadcast",
"broadcast_cycles",
"cycle",
"translation",
"RandomSelection",
Expand Down
Loading

0 comments on commit cd98e21

Please sign in to comment.