Skip to content

Commit

Permalink
Remove GWInterfaceBase, GWInterface, and co. (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdvllrs authored Mar 6, 2024
1 parent 0dc8d1f commit c739207
Show file tree
Hide file tree
Showing 9 changed files with 237 additions and 313 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,8 @@ refers to `DeterministicGlobalWorkspace`.
* Rename every abstract class with ClassNameBase. Rename every "Deterministic" classes
to remove "Deterministic".
* Remove all config related functions. This is not the role of this repo.

# 0.4.1
* Remove `GWInterfaces` entirely and favor giving encoders and decoders directly to the
`GWModule`. See the updated example `examples/main_example/train_gw.py` to see what
changes to make.
Binary file modified docs/assets/shimmer_architecture.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
65 changes: 36 additions & 29 deletions docs/shimmer_basics.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@ to make a GW in shimmer:
Let's detail:
- [`DomainModule`](https://bdvllrs.github.io/shimmer/shimmer.html#DomainModule)s
are the individual domain modules which encode domain data into a latent vector;
- `GWInterface`s are links to encode one domain in a GW representation;
- the `GWModule` has access to all `GWInterface`s and defines how to encode, decode and
merge representations of the domains into a unique GW representation.
- the `GWModule` has access to the domain modules, and defines how to encode, decode and merge representations of the domains into a unique GW representation.
- finally `GlobalWorkspaceBase` takes all building blocks to make a [Pytorch Lightning](https://lightning.ai/docs/pytorch/stable/) module

The last building block (not in the diagram) is the `GWLosses` class which
Expand Down Expand Up @@ -418,8 +416,9 @@ from dataset import GWDataModule, get_domain_data, make_datasets
from domains import GenericDomain
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from torch import nn

from shimmer import GlobalWorkspace, GWInterface, LossCoefs
from shimmer import GlobalWorkspace, GWDecoder, GWEncoder, LossCoefs


def train_gw():
Expand Down Expand Up @@ -451,20 +450,24 @@ def train_gw():

workspace_dim = 16

# Now we define interfaces that will encode and decode the domain representations
# to and from the global workspace
# We will use the already defined GWInterface class
gw_interfaces: dict[str, GWInterface] = {}
# Now we define modality encoders and decoders that will encode and decode
# the domain representations to and from the global workspace
gw_encoders: dict[str, nn.Module] = {}
gw_decoders: dict[str, nn.Module] = {}
for name, mod in domain_mods.items():
gw_interfaces[name] = GWInterface(
mod,
workspace_dim,
encoder_hidden_dim=64,
gw_encoders[name] = GWEncoder(
mod.latent_dim,
hidden_dim=64,
out_dim=workspace_dim,
# total number of Linear layers is this value + 2 (one before, one after)
encoder_n_layers=1,
decoder_hidden_dim=64,
n_layers=1,
)
gw_decoders[name] = GWDecoder(
in_dim=workspace_dim,
hidden_dim=64,
out_dim=mod.latent_dim,
# total number of Linear layers is this value + 2 (one before, one after)
decoder_n_layers=1,
n_layers=1,
)

loss_coefs: LossCoefs = {
Expand All @@ -475,7 +478,7 @@ def train_gw():
}

global_workspace = GlobalWorkspace(
domain_mods, gw_interfaces, workspace_dim, loss_coefs
domain_mods, gw_encoders, gw_decoders, workspace_dim, loss_coefs
)

trainer = Trainer(
Expand Down Expand Up @@ -535,24 +538,28 @@ This should be the same as what was used for the data.
}
```

We create the `GWInterfaces` to link the domain modules with the GlobalWorkspace
We define encoders and decoders to link the domain modules with the GlobalWorkspace
```python
workspace_dim = 16

# Now we define interfaces that will encode and decode the domain representations
# to and from the global workspace
# We will use the already defined GWInterface class
gw_interfaces: dict[str, GWInterface] = {}
# Now we define modality encoders and decoders that will encode and decode
# the domain representations to and from the global workspace
gw_encoders: dict[str, nn.Module] = {}
gw_decoders: dict[str, nn.Module] = {}
for name, mod in domain_mods.items():
gw_interfaces[name] = GWInterface(
mod,
workspace_dim,
encoder_hidden_dim=64,
gw_encoders[name] = GWEncoder(
mod.latent_dim,
hidden_dim=64,
out_dim=workspace_dim,
# total number of Linear layers is this value + 2 (one before, one after)
encoder_n_layers=1,
decoder_hidden_dim=64,
n_layers=1,
)
gw_decoders[name] = GWDecoder(
in_dim=workspace_dim,
hidden_dim=64,
out_dim=mod.latent_dim,
# total number of Linear layers is this value + 2 (one before, one after)
decoder_n_layers=1,
n_layers=1,
)
```

Expand All @@ -570,7 +577,7 @@ We define loss coefficients for the different losses. Note that `LossCoefs` is a
Finally we make the GlobalWorkspace and train it.
```python
global_workspace = GlobalWorkspace(
domain_mods, gw_interfaces, workspace_dim, loss_coefs
domain_mods, gw_encoders, gw_decoders, workspace_dim, loss_coefs
)

trainer = Trainer(
Expand Down
31 changes: 18 additions & 13 deletions examples/main_example/train_gw.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
from domains import GenericDomain
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from torch import nn

from shimmer import GlobalWorkspace, GWInterface, LossCoefs
from shimmer import GlobalWorkspace, GWDecoder, GWEncoder, LossCoefs


def train_gw():
Expand Down Expand Up @@ -35,20 +36,24 @@ def train_gw():

workspace_dim = 16

# Now we define interfaces that will encode and decode the domain representations
# to and from the global workspace
# We will use the already defined GWInterface class
gw_interfaces: dict[str, GWInterface] = {}
# Now we define modality encoders and decoders that will encode and decode
# the domain representations to and from the global workspace
gw_encoders: dict[str, nn.Module] = {}
gw_decoders: dict[str, nn.Module] = {}
for name, mod in domain_mods.items():
gw_interfaces[name] = GWInterface(
mod,
workspace_dim,
encoder_hidden_dim=64,
gw_encoders[name] = GWEncoder(
mod.latent_dim,
hidden_dim=64,
out_dim=workspace_dim,
# total number of Linear layers is this value + 2 (one before, one after)
encoder_n_layers=1,
decoder_hidden_dim=64,
n_layers=1,
)
gw_decoders[name] = GWDecoder(
in_dim=workspace_dim,
hidden_dim=64,
out_dim=mod.latent_dim,
# total number of Linear layers is this value + 2 (one before, one after)
decoder_n_layers=1,
n_layers=1,
)

loss_coefs: LossCoefs = {
Expand All @@ -59,7 +64,7 @@ def train_gw():
}

global_workspace = GlobalWorkspace(
domain_mods, gw_interfaces, workspace_dim, loss_coefs
domain_mods, gw_encoders, gw_decoders, workspace_dim, loss_coefs
)

trainer = Trainer(
Expand Down
8 changes: 2 additions & 6 deletions shimmer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,10 @@
from shimmer.modules.gw_module import (
GWDecoder,
GWEncoder,
GWInterface,
GWInterfaceBase,
GWEncoderLinear,
GWModule,
GWModuleBase,
VariationalGWEncoder,
VariationalGWInterface,
VariationalGWModule,
)
from shimmer.modules.losses import (
Expand Down Expand Up @@ -68,10 +66,8 @@
"DomainModule",
"GWDecoder",
"GWEncoder",
"GWEncoderLinear",
"VariationalGWEncoder",
"GWInterfaceBase",
"GWInterface",
"VariationalGWInterface",
"GWModuleBase",
"GWModule",
"VariationalGWModule",
Expand Down
8 changes: 2 additions & 6 deletions shimmer/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,10 @@
from shimmer.modules.gw_module import (
GWDecoder,
GWEncoder,
GWInterface,
GWInterfaceBase,
GWEncoderLinear,
GWModule,
GWModuleBase,
VariationalGWEncoder,
VariationalGWInterface,
VariationalGWModule,
)
from shimmer.modules.losses import (
Expand Down Expand Up @@ -54,10 +52,8 @@
"DomainModule",
"GWDecoder",
"GWEncoder",
"GWEncoderLinear",
"VariationalGWEncoder",
"GWInterfaceBase",
"GWInterface",
"VariationalGWInterface",
"GWModuleBase",
"GWModule",
"VariationalGWModule",
Expand Down
Loading

0 comments on commit c739207

Please sign in to comment.