Attention module for query attention (#75)
* Created class for lightning module

* changes to attention module

* create a test file for attention training

* fixed import error

* added argument to init attention_module

* added prints for testing

* debug corruption

* remove prints

* add ignore in attention init

* added other ignores

* first change frozenset into list

* add assert for gw size

* print domain size

* add domain names as input to criterion

* changed dim for dynamic attention

* rechange

* fixed bug in fusing weighted encodings

* debugging

* change initial layer not dependent on batch size

* remove print statements

* removed unnecessary variables

* docstrings added

* ruff fixes

* ruff

* mypy fix

* change loss to string

* requested changes from pr

* more changes for pr

* shape classifier to sequential

* put layers in init
larascipio authored May 7, 2024
1 parent 55c3d29 commit 5a516f9
Showing 5 changed files with 315 additions and 29 deletions.
4 changes: 4 additions & 0 deletions .vscode/settings.json
@@ -0,0 +1,4 @@
{
"python.analysis.typeCheckingMode": "basic",
"python.analysis.autoImportCompletions": true
}
205 changes: 205 additions & 0 deletions shimmer/modules/attention_module.py
@@ -0,0 +1,205 @@
import random
from collections.abc import Callable, Mapping, Sequence
from typing import Any

import torch
from lightning.pytorch import LightningModule
from lightning.pytorch.utilities.types import OptimizerLRSchedulerConfig
from torch import Tensor, nn
from torch.optim.lr_scheduler import OneCycleLR

from shimmer.modules.global_workspace import GlobalWorkspaceBase, SchedulerArgs
from shimmer.modules.gw import GWModuleBase
from shimmer.modules.losses import GWLossesBase
from shimmer.modules.selection import DynamicQueryAttention, SelectionBase
from shimmer.types import (
LatentsDomainGroupsDT,
LatentsDomainGroupsT,
RawDomainGroupsT,
RawDomainGroupT,
)


class ShapesClassifier(nn.Sequential):
def __init__(self, input_dim, output_dim):
layers = [
nn.Linear(input_dim, 256),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.Dropout(p=0.5),
nn.Linear(256, 128),
nn.BatchNorm1d(128),
nn.ReLU(),
nn.Dropout(p=0.5),
nn.Linear(128, 64),
nn.BatchNorm1d(64),
nn.ReLU(),
nn.Dropout(p=0.5),
nn.Linear(64, 32),
nn.BatchNorm1d(32),
nn.ReLU(),
nn.Dropout(p=0.5),
nn.Linear(32, output_dim),
]
super().__init__(*layers)


class DynamicAttention(LightningModule):
"""
Attention Lightning Module.
This is a wrapper around the DynamicQueryAttention module.
It is used to train the Dynamic Query Attention mechanism.
"""

def __init__(
self,
gw: GlobalWorkspaceBase[GWModuleBase, SelectionBase, GWLossesBase],
domain_dim: int,
head_size: int,
domain_names: Sequence[str],
criterion: Callable[[torch.Tensor, RawDomainGroupT], torch.Tensor],
optim_lr: float = 1e-3,
optim_weight_decay: float = 0.0,
scheduler_args: SchedulerArgs | None = None,
):
super().__init__()
self.save_hyperparameters(
ignore=[
"gw",
"criterion",
]
)

self.gw = gw
self.attention = DynamicQueryAttention(head_size, domain_dim, domain_names)
self.domain_names = domain_names
self.criterion = criterion
self.optim_lr = optim_lr
self.optim_weight_decay = optim_weight_decay
self.scheduler_args = SchedulerArgs(max_lr=optim_lr, total_steps=1)
if scheduler_args is not None:
self.scheduler_args.update(scheduler_args)

def configure_optimizers(self) -> OptimizerLRSchedulerConfig:
"""
Configure models optimizers.
Here we use `AdamW` for the optimizer and `OneCycleLR` for the learning-rate
scheduler.
"""

optimizer = torch.optim.AdamW(
self.parameters(),
lr=self.optim_lr,
weight_decay=self.optim_weight_decay,
)

lr_scheduler = OneCycleLR(optimizer, **self.scheduler_args)

return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": lr_scheduler,
"interval": "step",
},
}

def forward(
self,
single_domain_input: LatentsDomainGroupsT,
prefusion_encodings: LatentsDomainGroupsT,
) -> LatentsDomainGroupsDT:
return {
domains: self.attention(latents, prefusion_encodings[domains])
for domains, latents in single_domain_input.items()
}

def apply_corruption(
self,
batch: LatentsDomainGroupsT,
corruption_vector: torch.Tensor | None = None,
corrupted_domain: str | None = None,
) -> LatentsDomainGroupsDT:
"""
Apply corruption to the batch.
Args:
batch: A batch of latent domains.
corruption_vector: A vector to be added to the corrupted domain.
corrupted_domain: The domain to be corrupted.
Returns:
A batch where one of the latent domains is corrupted.
"""
if corrupted_domain is None:
# Specify which domain will be corrupted
corrupted_domain = random.choice(list(self.domain_names))

matched_data_dict: LatentsDomainGroupsDT = {}
for domain_names, domains in batch.items():
for domain_name, domain in domains.items():
if domain_names != self.domain_names or domain_name != corrupted_domain:
matched_data_dict.setdefault(domain_names, {})[domain_name] = domain
continue

# If corruption vector is not fixed outside the loop
if corruption_vector is None:
corruption_vector = torch.randn_like(domain)

# Apply element-wise addition to one of the domains
matched_data_dict.setdefault(domain_names, {})[domain_name] = (
domain + corruption_vector
)

return matched_data_dict

def generic_step(self, batch: RawDomainGroupsT, mode: str) -> Tensor:
latent_domains = self.gw.encode_domains(batch)
corrupted_batch = self.apply_corruption(latent_domains)
prefusion_encodings = self.gw.encode(corrupted_batch)
attention_scores = self.forward(corrupted_batch, prefusion_encodings)
merged_gw_representation = self.gw.fuse(prefusion_encodings, attention_scores)
losses = []
for domain_names, domains in merged_gw_representation.items():
losses.append(self.criterion(domains, batch[domain_names]))
domain_names_str = ",".join(domain_names)
self.log(
f"{mode}/{domain_names_str}_loss",
losses[-1],
batch_size=domains.size(0),
)
loss = torch.stack(losses).mean()
print(f"loss: {loss}")
self.log(f"{mode}/loss", loss, on_step=True, on_epoch=True)

return loss

def training_step(
self, batch: RawDomainGroupsT, batch_idx: int
) -> Tensor | Mapping[str, Any] | None: # type: ignore
return self.generic_step(batch, "train")

def validation_step( # type: ignore
self, data: RawDomainGroupT, batch_idx: int, dataloader_idx: int = 0
) -> torch.Tensor:
"""Validation step used by lightning"""

batch = {frozenset(data.keys()): data}
for domain in data:
batch[frozenset([domain])] = {domain: data[domain]}
if dataloader_idx == 0:
return self.generic_step(batch, mode="val")
return self.generic_step(batch, mode="val/ood")

def test_step( # type: ignore
self, data: RawDomainGroupT, batch_idx: int, dataloader_idx: int = 0
) -> torch.Tensor:
"""Test step used by lightning"""

batch = {frozenset(data.keys()): data}
for domain in data:
batch[frozenset([domain])] = {domain: data[domain]}
if dataloader_idx == 0:
return self.generic_step(batch, mode="test")
return self.generic_step(batch, mode="test/ood")
64 changes: 41 additions & 23 deletions shimmer/modules/selection.py
@@ -201,7 +201,6 @@ def forward(

# Apply softmax across domains with temperature scaling
softmax_scores = torch.softmax(uniform_scores / self.temperature, dim=1)

# Create attention dictionary for each domain
attention_dict = {
domain: softmax_scores[:, i] for i, domain in enumerate(domains)
@@ -213,32 +212,41 @@
class DynamicQueryAttention(SelectionBase):
"""
Key-Query attention with a dynamic gw vector.
The query is updated based on the scaled gw vector.
"""

def __init__(
self, batch_size: int, domain_dim: int, head_size: int, domains: Iterable[str]
):
def __init__(self, head_size: int, domain_dim: int, domain_names: Iterable[str]):
"""
Args:
batch_size (`int`) : size of the batch
head_size (`int`) : dimension of the key and query vectors.
domain_dim (`int`) : dimension of the input dims (assumed to be the same
for now)
head_size (`int`) : dimension of the key and query vectors
domains (`Iterable[str]`) : list of input domains
domain_names (`Iterable[str]`) : list of input domains
"""
super().__init__()
self.batch_size = batch_size
self.head_size = head_size
self.query_layer = nn.Linear(domain_dim, head_size)
self.key_layers = nn.ModuleDict(
{domain: nn.Linear(domain_dim, head_size) for domain in domains}
{domain: nn.Linear(domain_dim, head_size) for domain in domain_names}
)
# Start with a random gw state
self.gw_state = torch.rand(batch_size, domain_dim)
self.register_buffer("initial_gw_state", torch.rand(domain_dim))

def calculate_attention_dict(
self, keys: dict, query: torch.Tensor
self,
domains: LatentsDomainGroupT,
keys: dict[str, torch.Tensor],
query: torch.Tensor,
) -> dict[str, torch.Tensor]:
"""
Args:
domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
keys (`dict[str, torch.Tensor]`): The keys for each domain.
query (`torch.Tensor`): The query tensor.
Returns:
`dict[str, torch.Tensor]`: The attention scores for each domain.
"""
dot_products = {
domain: torch.bmm(key.unsqueeze(1), query.unsqueeze(2)).squeeze()
for domain, key in keys.items()
@@ -249,26 +257,38 @@ def calculate_attention_dict(
attention_scores = torch.softmax(dot_products_tensor, dim=1)

attention_dict = {
domain: attention_scores[:, i : i + 1] for i, domain in enumerate(keys)
domain: attention_scores[:, i] for i, domain in enumerate(domains)
}
return attention_dict

def fuse_weighted_encodings(
self, encodings: LatentsDomainGroupT, attention_dict: dict[str, torch.Tensor]
) -> torch.Tensor:
"""
Fuse the weighted encodings using the attention scores.
Args:
encodings (`LatentsDomainGroupT`): Unimodal latent representation
attention_dict (`dict[str, torch.Tensor]`): The attention scores for each
domain in the group.
Returns:
`torch.Tensor`: The fused tensor.
"""
# Apply attention scores to the encodings
weighted_encodings = {}
for key in attention_dict:
if key in encodings:
# Perform element-wise multiplication and store the result
weighted_encodings[key] = attention_dict[key] * encodings[key]
# Perform element-wise multiplication
weighted_encodings[key] = (
attention_dict[key].unsqueeze(1) * encodings[key]
)

# Stack the tensors along a new dimension (dimension 0)
stacked_tensors = torch.stack(list(weighted_encodings.values()))

# Apply fusion by summing along the newly created dimension
summed_tensor = torch.sum(stacked_tensors, dim=0)

return summed_tensor

def forward(
@@ -287,30 +307,28 @@ def forward(
group.
"""

# Encoding with pytorch
keys = {
domain: self.key_layers[domain](encoding)
for domain, encoding in domains.items()
}

# This for training (cpu or gpu)
device = group_device(domains)
batch_size = group_batch_size(domains)

# Retrieve query
query = self.query_layer(self.gw_state.to(device))
# Retrieve random query
query = self.query_layer(self.initial_gw_state.expand(batch_size, -1))

# Calculate the attention scores
static_attention_dict = self.calculate_attention_dict(keys, query)
static_attention_dict = self.calculate_attention_dict(domains, keys, query)

# Apply the attention scores to the encodings
summed_tensor = self.fuse_weighted_encodings(
encodings_pre_fusion, static_attention_dict
)

# Retrieve query (now it is dependent on the new gw state)
query = self.query_layer(summed_tensor.to(device))
query = self.query_layer(summed_tensor)

# Calculate the attention scores again
dynamic_attention_dict = self.calculate_attention_dict(keys, query)
dynamic_attention_dict = self.calculate_attention_dict(domains, keys, query)

return dynamic_attention_dict
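
As a quick sanity check of the reworked selection module, the sketch below runs DynamicQueryAttention on toy tensors; the domain names and sizes are arbitrary illustration values, not taken from this commit. Each returned score is a per-sample weight, and the weights across domains sum to one, so they can be consumed directly by the global workspace fusion step.

import torch

from shimmer.modules.selection import DynamicQueryAttention

domain_dim, head_size, batch_size = 12, 6, 4
domain_names = ["v_latents", "attr"]  # arbitrary example domains

attention = DynamicQueryAttention(head_size, domain_dim, domain_names)

# Toy unimodal latents and their pre-fusion GW encodings for a single group.
latents = {name: torch.randn(batch_size, domain_dim) for name in domain_names}
prefusion = {name: torch.randn(batch_size, domain_dim) for name in domain_names}

scores = attention(latents, prefusion)
print({name: s.shape for name, s in scores.items()})  # (batch_size,) per domain
print(sum(scores.values()))                           # ≈ torch.ones(batch_size)
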
16 changes: 16 additions & 0 deletions shimmer/utils.py
@@ -37,6 +37,22 @@ def groups_batch_size(domain_latents: LatentsDomainGroupsT) -> int:
raise ValueError("Empty batch.")


def groups_device(domain_latents: LatentsDomainGroupsT) -> torch.device:
    """
    Get the device of the batch.
    Args:
        domain_latents (`LatentsDomainGroupsT`): the batch of groups.
    Returns:
        torch.device: the device of the batch.
    """
    for data in domain_latents.values():
        for tensor in data.values():
            return tensor.device
    raise ValueError("Empty batch.")


def group_device(x: LatentsDomainGroupT) -> torch.device:
for val in x.values():
return val.device