Attention with scaling (#115)
* add scaling as input to AttentionBase

* delete print

* added fixed corruption vector

* add corruption per batch

* add prints for debugging

* add more prints

* hard code domain names to check bug

* change frozenset to list

* added more prints

* add print

* debug

* remove prints after debugging

* add prints for debugging

* add shape print

* change corruption vector

* print

* change dim

* change to cuda

* remove print

* prints

* check if loop works

* checking

* add other prints

* remove unnecessary part of the loop

* add prints for debugging

* more prints

* change to 5

* remove prints

* change corruption from 5 to 25

* changed corruption to 5 again

* add prints for debugging

* more prints

* add more

* add more prints

* made another corruption function

* add prints

* change

* change

* remove prints in new corruption function

* add prints for debug

* change tensor size

* add more noise

* add prints

* changed corruption

* change corruption

* remove prints

* change corruption

* change to old corruption

* made a new corruption function on tensors

* debug

* add prints

* more prints

* change

* even more

* prints

* more prints

* prints

* add more prints

* changed function

* debug

* remove prints

* check how the attention scores change

* print attention dict

* prints

* change corruption to corrupt side as well

* add print to check

* change sides

* other change

* remove prints

* remove prints

* add attention scores to log

* add calculate mean attention

* print attention scores

* mean attention

* remove print attentions scores

* reverse

* add variable corruption

* made a new function for testing all corruption combinations

* found a bug in corruption scaling (default should be 0.0)

* remove prints

* add prints

* change corruption

* remove prints

* removed batch corruption function

* made a new function to train corruption on both sides

* remove prints for training

* bug

* added second step attention in query attention

* add corruption vector to device

* removed

* added step 2

* added yet another step

* deleted added steps

* added multiple steps as a parameter for query attention

* found bug in variable corruption

* removed static attention and changed into step 0

* changed corruption functions to one- and two-sided corruption

* spelling mistake

* another mistake

* remove prints

* Added some comments

* line was too long

* Fix mypy issues

---------

Co-authored-by: bdvllrs <[email protected]>
larascipio and bdvllrs authored Jul 30, 2024
1 parent c7d5554 commit 774ff2f
Showing 3 changed files with 217 additions and 91 deletions.
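
The heart of the change is masking-based corruption of latent vectors: for each sample, a Gaussian vector is normalized to zero mean and unit standard deviation per row, multiplied by 5 and by a scaling factor, and added to exactly one randomly chosen domain. A minimal self-contained sketch of that scheme for two domains (the function name and two-tensor setup here are illustrative, not part of the commit):

import torch

def corrupt_one_side(
    domain_a: torch.Tensor,  # latents of the first domain, shape (batch, dim)
    domain_b: torch.Tensor,  # latents of the second domain, shape (batch, dim)
    scale: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Add a normalized noise vector to exactly one domain per sample."""
    batch_size, dim = domain_a.shape
    noise = torch.randn(batch_size, dim, device=domain_a.device)
    # Normalize each row to zero mean and unit standard deviation
    noise = (noise - noise.mean(dim=1, keepdim=True)) / noise.std(dim=1, keepdim=True)
    # The module multiplies the normalized vector by 5 before applying the scaling
    noise = noise * 5 * scale
    # Boolean mask: True where the first domain is the corrupted side
    corrupt_a = torch.randint(0, 2, (batch_size,), device=domain_a.device).bool()
    domain_a, domain_b = domain_a.clone(), domain_b.clone()
    domain_a[corrupt_a] += noise[corrupt_a]
    domain_b[~corrupt_a] += noise[~corrupt_a]
    return domain_a, domain_b

A scale of 0.0 leaves the latents untouched, which is why one of the commits above notes that the default scaling should be 0.0.
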
215 changes: 191 additions & 24 deletions shimmer/modules/attention_module.py
@@ -23,6 +23,7 @@
RawDomainGroupsT,
RawDomainGroupT,
)
from shimmer.utils import group_device, groups_batch_size, groups_device


class AttentionBase(LightningModule):
@@ -41,6 +42,12 @@ def __init__(
criterion: Callable[
[torch.Tensor, RawDomainGroupT], tuple[torch.Tensor, torch.Tensor]
],
domain_dim: int,
fixed_corruption_vector: torch.Tensor | None = None,
corruption_scaling: list[float] | None = None,
corrupt_single_side: str | None = None,
corrupt_sides: bool = False,
two_sided_corruption: dict[str, float] | None = None,
optim_lr: float = 1e-3,
optim_weight_decay: float = 0.0,
scheduler_args: SchedulerArgs | None = None,
@@ -56,8 +63,15 @@

self.gw = gw
self.attention = attention
        self.domain_names = frozenset(domain_names)
self.list_domain_names = list(domain_names)
self.criterion = criterion
self.domain_dim = domain_dim
self.fixed_corruption_vector = fixed_corruption_vector
self.corruption_scaling = corruption_scaling
self.corrupt_single_side = corrupt_single_side
self.corrupt_sides = corrupt_sides
self.test_sides_corruption = two_sided_corruption
self.optim_lr = optim_lr
self.optim_weight_decay = optim_weight_decay
self.scheduler_args = SchedulerArgs(max_lr=optim_lr, total_steps=1)
@@ -90,58 +104,208 @@ def configure_optimizers(self) -> OptimizerLRSchedulerConfig:

def forward(
self,
        corrupted_batch: LatentsDomainGroupsT,
prefusion_encodings: LatentsDomainGroupsT,
) -> LatentsDomainGroupsDT:
"""
Forward pass of the model.
Args:
corrupted_batch: The input to the model.
prefusion_encodings: The pre-fusion encodings.
Returns:
The attention scores.
"""
return {
domains: self.attention(latents, prefusion_encodings[domains])
            for domains, latents in corrupted_batch.items()
}

    def apply_one_sided_corruption(
self,
batch: LatentsDomainGroupsT,
) -> LatentsDomainGroupsDT:
"""
        Apply corruption to each tensor of the matched data
        by use of masking. Only for two domains.
Args:
batch: A batch of latent domains.
Returns:
            A batch where one of the domains of each tensor is corrupted.
"""
matched_data_dict: LatentsDomainGroupsDT = {}

# Make a copy of the batch
for domain_names, domains in batch.items():
for domain_name, domain in domains.items():
matched_data_dict.setdefault(domain_names, {})[domain_name] = domain
continue
device = group_device(domains)
batch_size = groups_batch_size(batch)
n_domains = len(self.domain_names)

# Check if the side that should be corrupted is given
if self.corrupt_single_side is not None:
corrupted_domain_index = self.list_domain_names.index(
self.corrupt_single_side
)
masked_domains = torch.zeros(batch_size, n_domains, dtype=torch.bool)
masked_domains[:, corrupted_domain_index] = True
else:
selected_domains = torch.randint(0, n_domains, (batch_size,), device=device)
masked_domains = torch.nn.functional.one_hot(
selected_domains, n_domains
).to(device, torch.bool)

# Check if corruption is fixed or variable
if self.fixed_corruption_vector is not None:
corruption_vector = self.fixed_corruption_vector.expand(
batch_size, self.domain_dim
)
else:
corruption_vector = torch.randn(
(batch_size, self.domain_dim), device=device
)

# Normalize the corruption vector
corruption_vector = (
corruption_vector - corruption_vector.mean(dim=1, keepdim=True)
) / corruption_vector.std(dim=1, keepdim=True)

# Choose randomly from corruption scaling
amount_corruption = (
random.choice(self.corruption_scaling) if self.corruption_scaling else 1.0
)

# Scale the corruption vector based on the amount of corruption
scaled_corruption_vector = (corruption_vector * 5) * amount_corruption
        for domain_names, domains in matched_data_dict.items():
if domain_names == self.domain_names:
for domain_name, domain in domains.items():
if domain_name == self.list_domain_names[0]:
domain[masked_domains[:, 0]] += scaled_corruption_vector[
masked_domains[:, 0]
]
if domain_name == self.list_domain_names[1]:
domain[~masked_domains[:, 0]] += scaled_corruption_vector[
~masked_domains[:, 0]
]
return matched_data_dict

def apply_two_sided_corruption(
self,
batch: LatentsDomainGroupsT,
) -> LatentsDomainGroupsDT:
"""
Apply corruption to each tensor of the matched data (two-sided corruption)
Only for two domains.
Args:
batch: A batch of latent domains.
Returns:
            A batch where both domains of each tensor are corrupted.
"""
matched_data_dict: LatentsDomainGroupsDT = {}

# Make a copy of the batch
for domain_names, domains in batch.items():
for domain_name, domain in domains.items():
matched_data_dict.setdefault(domain_names, {})[domain_name] = domain
continue
device = groups_device(batch)
batch_size = groups_batch_size(batch)
n_domains = len(self.domain_names)

corruption_matrices = {}

# Check if a fixed or variable corruption vector should be used
for domain_idx in range(n_domains):
if self.fixed_corruption_vector is not None:
corruption_matrix = self.fixed_corruption_vector.expand(
batch_size, self.domain_dim
).to(device)
else:
corruption_matrix = torch.randn(
(batch_size, self.domain_dim), device=device
)

# Normalize the corruption matrices
normalized_corruption_matrix = (
corruption_matrix - corruption_matrix.mean(dim=1, keepdim=True)
) / corruption_matrix.std(dim=1, keepdim=True)

# Get the scaled corruption vector
if self.test_sides_corruption is not None:
scaled_corruption_matrix = (
normalized_corruption_matrix * 5
) * self.test_sides_corruption[self.list_domain_names[domain_idx]]
else:
amount_corruption = (
random.choice(self.corruption_scaling)
if self.corruption_scaling
else 1.0
)
scaled_corruption_matrix = (
normalized_corruption_matrix * 5
) * amount_corruption
corruption_matrices[self.list_domain_names[domain_idx]] = (
scaled_corruption_matrix
)

for domain_names, domains in matched_data_dict.items():
if domain_names == self.domain_names:
for domain_name, domain in domains.items():
if domain_name in corruption_matrices:
domain += corruption_matrices[domain_name]
return matched_data_dict

def calculate_mean_attention(
self,
attention_scores: dict[frozenset[str], dict[str, Tensor]],
) -> dict:
"""
Calculate the mean attention scores for each domain.
Args:
attention_scores: The attention scores for each domain.
Returns:
The mean attention scores for each domain.
"""
# Initialize variables to accumulate mean scores
mean_attention_dict = {}

# Iterate through attention_dicts
for _, scores in attention_scores.items():
            # Check if more than one domain is present
if len(scores) > 1:
for key, values in scores.items():
# Accumulate mean scores for each key
mean_score = values.mean().item()
mean_attention_dict[key] = mean_score
return mean_attention_dict

def generic_step(self, batch: RawDomainGroupsT, mode: str) -> Tensor:
"""
Generic step used by lightning, used for training, validation and testing.
Args:
batch: A batch of latent domains.
            mode: The mode the model is currently in.
Returns:
The loss of the model.
"""
latent_domains = self.gw.encode_domains(batch)
        if self.corrupt_sides:
corrupted_batch = self.apply_two_sided_corruption(latent_domains)
else:
corrupted_batch = self.apply_one_sided_corruption(latent_domains)
prefusion_encodings = self.gw.encode(corrupted_batch)
attention_scores = self.forward(corrupted_batch, prefusion_encodings)
merged_gw_representation = self.gw.fuse(prefusion_encodings, attention_scores)

losses = []
accuracies = []

@@ -160,6 +324,10 @@ def generic_step(self, batch: RawDomainGroupsT, mode: str) -> Tensor:
accuracies[-1],
batch_size=domains.size(0),
)
mean_attention_scores = self.calculate_mean_attention(attention_scores)
for domain_name, score in mean_attention_scores.items():
self.log(f"{mode}/{domain_name}_mean_attention_score", score)

loss = torch.stack(losses).mean()
self.log(f"{mode}/loss", loss, on_step=True, on_epoch=True)
self.log(f"{mode}/accuracy", torch.stack(accuracies).mean())
@@ -187,7 +355,6 @@ def test_step( # type: ignore
self, data: RawDomainGroupT, batch_idx: int, dataloader_idx: int = 0
) -> torch.Tensor:
"""Test step used by lightning"""

batch = {frozenset(data.keys()): data}
for domain in data:
batch[frozenset([domain])] = {domain: data[domain]}
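
For reference, one plausible way to wire up the new constructor arguments; `gw`, `attention`, and `criterion` come from the rest of a shimmer training setup, and the domain names, dimension, and scaling values below are placeholders rather than values taken from this commit:

attention_module = AttentionBase(
    gw=gw,
    attention=attention,
    domain_names={"v_latents", "attr"},  # placeholder names for the two domains
    criterion=criterion,
    domain_dim=12,  # dimensionality of the unimodal latent vectors
    corruption_scaling=[0.0, 0.5, 1.0],  # one factor is drawn at random per batch
    corrupt_sides=True,  # route generic_step through apply_two_sided_corruption
)

Setting corrupt_single_side to one of the domain names instead always corrupts that side, and two_sided_corruption accepts a per-domain scaling dict for testing fixed corruption levels on both sides.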
