From ca6ab1c7f776c1b129d95c9eb818960d2a1d8948 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= <42555442+kohlerca@users.noreply.github.com> Date: Tue, 26 Sep 2023 11:19:55 +0200 Subject: [PATCH] [Fix] Error when tracking static method (#23) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fixed error when a decorated method is static * Changed unit tests as applying the decorator with syntactic sugar does not work with the static method analysis * Unit test for tracking static methods --------- Co-authored-by: Cristiano Köhler --- alpaca/decorator.py | 18 ++++++++++++ alpaca/test/test_code_analysis.py | 6 ++-- alpaca/test/test_decorator.py | 47 +++++++++++++++++++++++++++++-- 3 files changed, 64 insertions(+), 7 deletions(-) diff --git a/alpaca/decorator.py b/alpaca/decorator.py index f502176..76cc3f2 100644 --- a/alpaca/decorator.py +++ b/alpaca/decorator.py @@ -320,6 +320,18 @@ def _is_class_constructor(function_name): names = function_name.split(".") return len(names) == 2 and names[-1] == "__init__" + @staticmethod + def _is_static_method(function, function_name): + if type(function).__qualname__ == "method_descriptor": + # Ignore method descriptors + return False + name = function_name.rsplit('.', 1)[-1] + cls = inspect._findclass(function) + if cls is not None: + method = inspect.getattr_static(cls, name) + return isinstance(method, staticmethod) + return False + def _capture_code_and_function_provenance(self, lineno, function): # 1. Capture Abstract Syntax Tree (AST) of the call to the @@ -648,6 +660,12 @@ def wrapped(*args, **kwargs): return function_output + # If the function is decorated with `staticmethod`, restore the + # decorator (otherwise `self` will be passed as first argument when + # calling the function) + if self._is_static_method(function, function.__qualname__): + return staticmethod(wrapped) + return wrapped @classmethod diff --git a/alpaca/test/test_code_analysis.py b/alpaca/test/test_code_analysis.py index 0060f20..f98a947 100644 --- a/alpaca/test/test_code_analysis.py +++ b/alpaca/test/test_code_analysis.py @@ -60,16 +60,14 @@ def __init__(self, array): # To test attribute calls class ObjectWithMethod: - @Provenance(inputs=['self', 'array']) def add_numbers(self, array): return np.sum(array) - +ObjectWithMethod.add_numbers = Provenance(inputs=['self', 'array'])(ObjectWithMethod.add_numbers) class CustomObject: - @Provenance(inputs=['data']) def __init__(self, data): self.data = data - +CustomObject.__init__ = Provenance(inputs=['data'])(CustomObject.__init__) # Define some test functions to use different relationships diff --git a/alpaca/test/test_decorator.py b/alpaca/test/test_decorator.py index de1769e..7c8ea89 100644 --- a/alpaca/test/test_decorator.py +++ b/alpaca/test/test_decorator.py @@ -121,14 +121,13 @@ def comprehension_function(param): class NonIterableContainerOutputObject(object): - - @Provenance(inputs=[], container_output=0) def __init__(self, start): self._data = np.arange(start+1, start+4) def __getitem__(self, item): return self._data[item] - +NonIterableContainerOutputObject.__init__ = \ + Provenance(inputs=[], container_output=0)(NonIterableContainerOutputObject.__init__) # Function to help verifying FunctionExecution tuples def _check_function_execution(actual, exp_function, exp_input, exp_params, @@ -1124,10 +1123,17 @@ def __init__(self, coefficient): def process(self, array, param1, param2): return array + self.coefficient + @staticmethod + def static_method(array, coefficient): + return array + coefficient + ObjectWithMethod.process = Provenance(inputs=['self', 'array'])( ObjectWithMethod.process) +ObjectWithMethod.static_method = Provenance(inputs=['array'])( + ObjectWithMethod.static_method) + # Apply decorator to method that uses the descriptor protocol neo.AnalogSignal.reshape = Provenance(inputs=[0])(neo.AnalogSignal.reshape) @@ -1137,6 +1143,41 @@ def process(self, array, param1, param2): class ProvenanceDecoratorClassMethodsTestCase(unittest.TestCase): + def test_static_method(self): + obj = ObjectWithMethod(2) + activate(clear=True) + res = obj.static_method(TEST_ARRAY, 4) + deactivate() + + self.assertEqual(len(Provenance.history), 1) + + obj_info = DataObject( + hash=joblib.hash(obj, hash_name='sha1'), + hash_method="joblib_SHA1", + type="test_decorator.ObjectWithMethod", + id=id(obj), + details={'coefficient': 2}) + + expected_output = DataObject( + hash=joblib.hash(TEST_ARRAY+4, hash_name='sha1'), + hash_method="joblib_SHA1", + type="numpy.ndarray", id=id(res), + details={'shape': (3,), 'dtype': np.int64}) + + _check_function_execution( + actual=Provenance.history[0], + exp_function=FunctionInfo('ObjectWithMethod.static_method', + 'test_decorator', ''), + exp_input={'array': TEST_ARRAY_INFO}, + exp_params={'coefficient': 4}, + exp_output={0: expected_output}, + exp_arg_map=['array', 'coefficient'], + exp_kwarg_map=[], + exp_code_stmnt="res = obj.static_method(TEST_ARRAY, 4)", + exp_return_targets=['res'], + exp_order=1, + test_case=self) + def test_method_descriptor(self): activate(clear=True) ansig = neo.AnalogSignal(TEST_ARRAY, units='mV', sampling_rate=1*pq.Hz)