Skip to content

Commit

Permalink
Added some comments
Browse files Browse the repository at this point in the history
  • Loading branch information
larascipio authored and bdvllrs committed Jul 30, 2024
1 parent c7bb498 commit 61d0c7e
Showing 1 changed file with 36 additions and 5 deletions.
41 changes: 36 additions & 5 deletions shimmer/modules/attention_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,22 @@ def configure_optimizers(self) -> OptimizerLRSchedulerConfig:

def forward(
    self,
    corrupted_batch: LatentsDomainGroupsT,
    prefusion_encodings: LatentsDomainGroupsT,
) -> LatentsDomainGroupsDT:
    """
    Forward pass of the model.

    Computes attention scores for each domain group by applying the
    attention module to the (corrupted) latents together with their
    pre-fusion encodings.

    Args:
        corrupted_batch: The input to the model — latents per domain group
            (possibly corrupted upstream; see the corruption helpers).
        prefusion_encodings: The pre-fusion encodings, keyed by the same
            domain groups as ``corrupted_batch``.

    Returns:
        The attention scores, one entry per domain group.
    """
    # NOTE(review): the diff's stale pre-commit lines (old parameter name
    # `single_domain_input` and its loop) were dropped; this is the
    # post-commit version. Keys of both dicts are assumed to match —
    # a missing key in `prefusion_encodings` would raise KeyError.
    return {
        domains: self.attention(latents, prefusion_encodings[domains])
        for domains, latents in corrupted_batch.items()
    }

def apply_one_sided_corruption(
Expand All @@ -136,6 +146,7 @@ def apply_one_sided_corruption(
batch_size = groups_batch_size(batch)
n_domains = len(self.domain_names)

# Check if the side that should be corrupted is given
if self.corrupt_single_side is not None:
corrupted_domain_index = self.list_domain_names.index(
self.corrupt_single_side
Expand All @@ -148,6 +159,7 @@ def apply_one_sided_corruption(
selected_domains, n_domains
).to(device, torch.bool)

# Check if corruption is fixed or variable
if self.fixed_corruption_vector is not None:
corruption_vector = self.fixed_corruption_vector.expand(
batch_size, self.domain_dim
Expand All @@ -166,6 +178,7 @@ def apply_one_sided_corruption(
amount_corruption = (
random.choice(self.corruption_scaling) if self.corruption_scaling else 1.0
)

# Scale the corruption vector based on the amount of corruption
scaled_corruption_vector = (corruption_vector * 5) * amount_corruption
for _, (domain_names, domains) in enumerate(matched_data_dict.items()):
Expand All @@ -192,7 +205,7 @@ def apply_two_sided_corruption(
Args:
batch: A batch of latent domains.
Returns:
A batch where either one (of the domains) of each tensor is corrupted.
A batch where both sides of each domain group are corrupted.
"""
matched_data_dict: LatentsDomainGroupsDT = {}

Expand All @@ -206,6 +219,8 @@ def apply_two_sided_corruption(
n_domains = len(self.domain_names)

corruption_matrices = {}

# Check if a fixed or variable corruption vector should be used
for domain in range(n_domains):
if self.fixed_corruption_vector is not None:
corruption_matrix = self.fixed_corruption_vector.expand(
Expand All @@ -221,7 +236,7 @@ def apply_two_sided_corruption(
corruption_matrix - corruption_matrix.mean(dim=1, keepdim=True)
) / corruption_matrix.std(dim=1, keepdim=True)

# Scale the matrices
# Get the scaled corruption vector
if self.test_sides_corruption is not None:
scaled_corruption_matrix = (
normalized_corruption_matrix * 5
Expand Down Expand Up @@ -252,9 +267,16 @@ def calculate_mean_attention(
) -> dict:
"""
Calculate the mean attention scores for each domain.
Args:
attention_scores: The attention scores for each domain.
Returns:
The mean attention scores for each domain.
"""
# Initialize variables to accumulate mean scores
mean_attention_dict = {}

# Iterate through attention_dicts
for _, scores in attention_scores.items():
# Check if more than 1 domains are present
Expand All @@ -266,6 +288,16 @@ def calculate_mean_attention(
return mean_attention_dict

def generic_step(self, batch: RawDomainGroupsT, mode: str) -> Tensor:
"""
Generic step used by Lightning for training, validation, and testing.
Args:
batch: A batch of latent domains.
mode: The mode in which the model is currently in.
Returns:
The loss of the model.
"""
latent_domains = self.gw.encode_domains(batch)
if self.corrupt_sides is True:
corrupted_batch = self.apply_two_sided_corruption(latent_domains)
Expand Down Expand Up @@ -323,7 +355,6 @@ def test_step( # type: ignore
self, data: RawDomainGroupT, batch_idx: int, dataloader_idx: int = 0
) -> torch.Tensor:
"""Test step used by lightning"""

batch = {frozenset(data.keys()): data}
for domain in data:
batch[frozenset([domain])] = {domain: data[domain]}
Expand Down

0 comments on commit 61d0c7e

Please sign in to comment.