From e597314e9bd2ef00ae1084310b48f4877739f981 Mon Sep 17 00:00:00 2001 From: Christy Sauper Date: Wed, 30 Oct 2024 13:55:00 -0700 Subject: [PATCH] Fix assorted unbound variables [4/n] (#1365) Summary: Fix unbound variables that flake8 is complaining about Reviewed By: cyrjano Differential Revision: D64261231 --- tests/attr/test_class_summarizer.py | 4 +++- tests/attr/test_input_layer_wrapper.py | 11 +++++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/tests/attr/test_class_summarizer.py b/tests/attr/test_class_summarizer.py index 725358616..86f9a9f1f 100644 --- a/tests/attr/test_class_summarizer.py +++ b/tests/attr/test_class_summarizer.py @@ -57,7 +57,9 @@ def test_classes(self) -> None: for batch_size in [None, 1, 4]: for sizes, classes in zip(sizes_to_test, list_of_classes): - def create_batch_labels(batch_idx): + def create_batch_labels( + batch_idx, batch_size=batch_size, classes=classes + ): if batch_size is None: # batch_size = 1 return classes[batch_idx] diff --git a/tests/attr/test_input_layer_wrapper.py b/tests/attr/test_input_layer_wrapper.py index d768253c0..053858c23 100644 --- a/tests/attr/test_input_layer_wrapper.py +++ b/tests/attr/test_input_layer_wrapper.py @@ -45,6 +45,7 @@ class InputLayerMeta(type): def __new__(metacls, name: str, bases: Tuple, attrs: Dict): + global layer_methods_to_test_with_equiv for ( layer_method, equiv_method, @@ -56,7 +57,7 @@ def __new__(metacls, name: str, bases: Tuple, attrs: Dict): + f"_{equiv_method.__name__}_{multi_layer}" ) attrs[test_name] = ( - lambda self: self.layer_method_with_input_layer_patches( + lambda self, layer_method=layer_method, equiv_method=equiv_method, multi_layer=multi_layer: self.layer_method_with_input_layer_patches( # noqa: E501 layer_method, equiv_method, multi_layer ) ) @@ -107,8 +108,14 @@ def layer_method_with_input_layer_patches( real_attributions = equivalent_method.attribute(*args_to_use, target=0) - if not isinstance(a1, tuple): + if isinstance(a1, list): + a1 = tuple(a1) + elif not isinstance(a1, tuple): a1 = (a1,) + + if isinstance(a2, list): + a2 = tuple(a2) + elif not isinstance(a2, tuple): a2 = (a2,) if not isinstance(real_attributions, tuple):