From a5f4c46727c51565b3eb9ca94ce111b8ed4816af Mon Sep 17 00:00:00 2001
From: larascipio
Date: Wed, 17 Jul 2024 09:03:48 +0000
Subject: [PATCH] Added some comments

---
 shimmer/modules/attention_module.py | 41 +++++++++++++++++++++++++----
 1 file changed, 36 insertions(+), 5 deletions(-)

diff --git a/shimmer/modules/attention_module.py b/shimmer/modules/attention_module.py
index b5f19f31..3bc6529e 100644
--- a/shimmer/modules/attention_module.py
+++ b/shimmer/modules/attention_module.py
@@ -104,12 +104,22 @@ def configure_optimizers(self) -> OptimizerLRSchedulerConfig:
 
     def forward(
         self,
-        single_domain_input: LatentsDomainGroupsT,
+        corrupted_batch: LatentsDomainGroupsT,
         prefusion_encodings: LatentsDomainGroupsT,
     ) -> LatentsDomainGroupsDT:
+        """
+        Forward pass of the model.
+
+        Args:
+            corrupted_batch: The corrupted batch of latent domain groups.
+            prefusion_encodings: The pre-fusion encodings.
+
+        Returns:
+            The attention scores.
+        """
         return {
             domains: self.attention(latents, prefusion_encodings[domains])
-            for domains, latents in single_domain_input.items()
+            for domains, latents in corrupted_batch.items()
         }
 
     def apply_one_sided_corruption(
@@ -136,6 +146,7 @@
         batch_size = groups_batch_size(batch)
         n_domains = len(self.domain_names)
 
+        # Check if the side that should be corrupted is given
         if self.corrupt_single_side is not None:
             corrupted_domain_index = self.list_domain_names.index(
                 self.corrupt_single_side
@@ -148,6 +159,7 @@
             selected_domains, n_domains
         ).to(device, torch.bool)
 
+        # Check if corruption is fixed or variable
        if self.fixed_corruption_vector is not None:
            corruption_vector = self.fixed_corruption_vector.expand(
                batch_size, self.domain_dim
@@ -166,6 +178,7 @@
         amount_corruption = (
             random.choice(self.corruption_scaling) if self.corruption_scaling else 1.0
         )
+        # Scale the corruption vector based on the amount of corruption
         scaled_corruption_vector = (corruption_vector * 5) * amount_corruption
 
         for _, (domain_names, domains) in enumerate(matched_data_dict.items()):
@@ -192,7 +205,7 @@ def apply_two_sided_corruption(
         Args:
             batch: A batch of latent domains.
         Returns:
-            A batch where either one (of the domains) of each tensor is corrupted.
+            A batch where both sides of each domain group are corrupted.
         """
         matched_data_dict: LatentsDomainGroupsDT = {}
 
@@ -206,6 +219,8 @@
         n_domains = len(self.domain_names)
 
         corruption_matrices = {}
+
+        # Check if a fixed or variable corruption vector should be used
         for domain in range(n_domains):
             if self.fixed_corruption_vector is not None:
                 corruption_matrix = self.fixed_corruption_vector.expand(
@@ -221,7 +236,7 @@
                 corruption_matrix - corruption_matrix.mean(dim=1, keepdim=True)
             ) / corruption_matrix.std(dim=1, keepdim=True)
 
-            # Scale the matrices
+            # Get the scaled corruption matrix
             if self.test_sides_corruption is not None:
                 scaled_corruption_matrix = (
                     normalized_corruption_matrix * 5
@@ -252,9 +267,16 @@
     ) -> dict:
         """
         Calculate the mean attention scores for each domain.
+
+        Args:
+            attention_scores: The attention scores for each domain.
+
+        Returns:
+            The mean attention scores for each domain.
         """
         # Initialize variables to accumulate mean scores
         mean_attention_dict = {}
 
+        # Iterate through the attention scores
         for _, scores in attention_scores.items():
             # Check if more than 1 domains are present
@@ -266,6 +288,16 @@
         return mean_attention_dict
 
     def generic_step(self, batch: RawDomainGroupsT, mode: str) -> Tensor:
+        """
+        Generic step used by Lightning for training, validation, and testing.
+
+        Args:
+            batch: A batch of raw domain groups.
+            mode: The mode the model is currently in.
+
+        Returns:
+            The loss of the model.
+        """
         latent_domains = self.gw.encode_domains(batch)
         if self.corrupt_sides is True:
             corrupted_batch = self.apply_two_sided_corruption(latent_domains)
@@ -323,7 +355,6 @@ def test_step(  # type: ignore
         self, data: RawDomainGroupT, batch_idx: int, dataloader_idx: int = 0
     ) -> torch.Tensor:
         """Test step used by lightning"""
-
         batch = {frozenset(data.keys()): data}
         for domain in data:
             batch[frozenset([domain])] = {domain: data[domain]}