Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove GWInterfaceBase, GWInterface, and co. #9

Merged
merged 8 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,8 @@ refers to `DeterministicGlobalWorkspace`.
* Rename every abstract class with ClassNameBase. Rename every "Deterministic" classes
to remove "Deterministic".
* Remove all config related functions. This is not the role of this repo.

# 0.4.1
* Remove `GWInterfaces` entirely and favor giving encoders and decoders directly to the
`GWModule`. See the updated example `examples/main_example/train_gw.py` to see what
changes to make.
Binary file modified docs/assets/shimmer_architecture.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
65 changes: 36 additions & 29 deletions docs/shimmer_basics.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@ to make a GW in shimmer:
Let's detail:
- [`DomainModule`](https://bdvllrs.github.io/shimmer/shimmer.html#DomainModule)s
are the individual domain modules which encode domain data into a latent vector;
- `GWInterface`s are links to encode one domain in a GW representation;
- the `GWModule` has access to all `GWInterface`s and defines how to encode, decode and
merge representations of the domains into a unique GW representation.
- the `GWModule` has access to the domain modules, and defines how to encode, decode and merge representations of the domains into a unique GW representation.
- finally `GlobalWorkspaceBase` takes all building blocks to make a [Pytorch Lightning](https://lightning.ai/docs/pytorch/stable/) module

The last building block (not in the diagram) is the `GWLosses` class which
Expand Down Expand Up @@ -418,8 +416,9 @@ from dataset import GWDataModule, get_domain_data, make_datasets
from domains import GenericDomain
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from torch import nn

from shimmer import GlobalWorkspace, GWInterface, LossCoefs
from shimmer import GlobalWorkspace, GWDecoder, GWEncoder, LossCoefs


def train_gw():
Expand Down Expand Up @@ -451,20 +450,24 @@ def train_gw():

workspace_dim = 16

# Now we define interfaces that will encode and decode the domain representations
# to and from the global workspace
# We will use the already defined GWInterface class
gw_interfaces: dict[str, GWInterface] = {}
# Now we define modality encoders and decoders that will encode and decode
# the domain representations to and from the global workspace
gw_encoders: dict[str, nn.Module] = {}
gw_decoders: dict[str, nn.Module] = {}
for name, mod in domain_mods.items():
gw_interfaces[name] = GWInterface(
mod,
workspace_dim,
encoder_hidden_dim=64,
gw_encoders[name] = GWEncoder(
mod.latent_dim,
hidden_dim=64,
out_dim=workspace_dim,
# total number of Linear layers is this value + 2 (one before, one after)
encoder_n_layers=1,
decoder_hidden_dim=64,
n_layers=1,
)
gw_decoders[name] = GWDecoder(
in_dim=workspace_dim,
hidden_dim=64,
out_dim=mod.latent_dim,
# total number of Linear layers is this value + 2 (one before, one after)
decoder_n_layers=1,
n_layers=1,
)

loss_coefs: LossCoefs = {
Expand All @@ -475,7 +478,7 @@ def train_gw():
}

global_workspace = GlobalWorkspace(
domain_mods, gw_interfaces, workspace_dim, loss_coefs
domain_mods, gw_encoders, gw_decoders, workspace_dim, loss_coefs
)

trainer = Trainer(
Expand Down Expand Up @@ -535,24 +538,28 @@ This should be the same as what was used for the data.
}
```

We create the `GWInterfaces` to link the domain modules with the GlobalWorkspace
We define encoders and decoders to link the domain modules with the GlobalWorkspace
```python
workspace_dim = 16

# Now we define interfaces that will encode and decode the domain representations
# to and from the global workspace
# We will use the already defined GWInterface class
gw_interfaces: dict[str, GWInterface] = {}
# Now we define modality encoders and decoders that will encode and decode
# the domain representations to and from the global workspace
gw_encoders: dict[str, nn.Module] = {}
gw_decoders: dict[str, nn.Module] = {}
for name, mod in domain_mods.items():
gw_interfaces[name] = GWInterface(
mod,
workspace_dim,
encoder_hidden_dim=64,
gw_encoders[name] = GWEncoder(
mod.latent_dim,
hidden_dim=64,
out_dim=workspace_dim,
# total number of Linear layers is this value + 2 (one before, one after)
encoder_n_layers=1,
decoder_hidden_dim=64,
n_layers=1,
)
gw_decoders[name] = GWDecoder(
in_dim=workspace_dim,
hidden_dim=64,
out_dim=mod.latent_dim,
# total number of Linear layers is this value + 2 (one before, one after)
decoder_n_layers=1,
n_layers=1,
)
```

Expand All @@ -570,7 +577,7 @@ We define loss coefficients for the different losses. Note that `LossCoefs` is a
Finally we make the GlobalWorkspace and train it.
```python
global_workspace = GlobalWorkspace(
domain_mods, gw_interfaces, workspace_dim, loss_coefs
domain_mods, gw_encoders, gw_decoders, workspace_dim, loss_coefs
)

trainer = Trainer(
Expand Down
31 changes: 18 additions & 13 deletions examples/main_example/train_gw.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
from domains import GenericDomain
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from torch import nn

from shimmer import GlobalWorkspace, GWInterface, LossCoefs
from shimmer import GlobalWorkspace, GWDecoder, GWEncoder, LossCoefs


def train_gw():
Expand Down Expand Up @@ -35,20 +36,24 @@ def train_gw():

workspace_dim = 16

# Now we define interfaces that will encode and decode the domain representations
# to and from the global workspace
# We will use the already defined GWInterface class
gw_interfaces: dict[str, GWInterface] = {}
# Now we define modality encoders and decoders that will encode and decode
# the domain representations to and from the global workspace
gw_encoders: dict[str, nn.Module] = {}
gw_decoders: dict[str, nn.Module] = {}
for name, mod in domain_mods.items():
gw_interfaces[name] = GWInterface(
mod,
workspace_dim,
encoder_hidden_dim=64,
gw_encoders[name] = GWEncoder(
mod.latent_dim,
hidden_dim=64,
out_dim=workspace_dim,
# total number of Linear layers is this value + 2 (one before, one after)
encoder_n_layers=1,
decoder_hidden_dim=64,
n_layers=1,
)
gw_decoders[name] = GWDecoder(
in_dim=workspace_dim,
hidden_dim=64,
out_dim=mod.latent_dim,
# total number of Linear layers is this value + 2 (one before, one after)
decoder_n_layers=1,
n_layers=1,
)

loss_coefs: LossCoefs = {
Expand All @@ -59,7 +64,7 @@ def train_gw():
}

global_workspace = GlobalWorkspace(
domain_mods, gw_interfaces, workspace_dim, loss_coefs
domain_mods, gw_encoders, gw_decoders, workspace_dim, loss_coefs
)

trainer = Trainer(
Expand Down
8 changes: 2 additions & 6 deletions shimmer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,10 @@
from shimmer.modules.gw_module import (
GWDecoder,
GWEncoder,
GWInterface,
GWInterfaceBase,
GWEncoderLinear,
GWModule,
GWModuleBase,
VariationalGWEncoder,
VariationalGWInterface,
VariationalGWModule,
)
from shimmer.modules.losses import (
Expand Down Expand Up @@ -68,10 +66,8 @@
"DomainModule",
"GWDecoder",
"GWEncoder",
"GWEncoderLinear",
"VariationalGWEncoder",
"GWInterfaceBase",
"GWInterface",
"VariationalGWInterface",
"GWModuleBase",
"GWModule",
"VariationalGWModule",
Expand Down
8 changes: 2 additions & 6 deletions shimmer/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,10 @@
from shimmer.modules.gw_module import (
GWDecoder,
GWEncoder,
GWInterface,
GWInterfaceBase,
GWEncoderLinear,
GWModule,
GWModuleBase,
VariationalGWEncoder,
VariationalGWInterface,
VariationalGWModule,
)
from shimmer.modules.losses import (
Expand Down Expand Up @@ -54,10 +52,8 @@
"DomainModule",
"GWDecoder",
"GWEncoder",
"GWEncoderLinear",
"VariationalGWEncoder",
"GWInterfaceBase",
"GWInterface",
"VariationalGWInterface",
"GWModuleBase",
"GWModule",
"VariationalGWModule",
Expand Down
Loading
Loading