diff --git a/shimmer/modules/attention_module.py b/shimmer/modules/attention_module.py index a0f89ad7..b5f19f31 100644 --- a/shimmer/modules/attention_module.py +++ b/shimmer/modules/attention_module.py @@ -204,7 +204,7 @@ def apply_two_sided_corruption( device = group_device(domains) batch_size = groups_batch_size(batch) n_domains = len(self.domain_names) - print(f"batch: {batch}") + corruption_matrices = {} for domain in range(n_domains): if self.fixed_corruption_vector is not None: @@ -239,14 +239,11 @@ def apply_two_sided_corruption( scaled_corruption_matrix ) - print(f"corruption_matrices: {corruption_matrices}") - for domain_names, domains in matched_data_dict.items(): if domain_names == self.domain_names: for domain_name, domain in domains.items(): if domain_name in corruption_matrices: domain += corruption_matrices[domain_name] - print(f"matched_data_dict: {matched_data_dict}") return matched_data_dict def calculate_mean_attention(