From bbd8fc9d0543acd275550ec9ea170637007ac785 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Tue, 26 Sep 2023 10:54:44 +0200 Subject: [PATCH] Unit test for tracking static methods --- alpaca/test/test_decorator.py | 42 +++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/alpaca/test/test_decorator.py b/alpaca/test/test_decorator.py index 40d2a74..7c8ea89 100644 --- a/alpaca/test/test_decorator.py +++ b/alpaca/test/test_decorator.py @@ -1123,10 +1123,17 @@ def __init__(self, coefficient): def process(self, array, param1, param2): return array + self.coefficient + @staticmethod + def static_method(array, coefficient): + return array + coefficient + ObjectWithMethod.process = Provenance(inputs=['self', 'array'])( ObjectWithMethod.process) +ObjectWithMethod.static_method = Provenance(inputs=['array'])( + ObjectWithMethod.static_method) + # Apply decorator to method that uses the descriptor protocol neo.AnalogSignal.reshape = Provenance(inputs=[0])(neo.AnalogSignal.reshape) @@ -1136,6 +1143,41 @@ def process(self, array, param1, param2): class ProvenanceDecoratorClassMethodsTestCase(unittest.TestCase): + def test_static_method(self): + obj = ObjectWithMethod(2) + activate(clear=True) + res = obj.static_method(TEST_ARRAY, 4) + deactivate() + + self.assertEqual(len(Provenance.history), 1) + + obj_info = DataObject( + hash=joblib.hash(obj, hash_name='sha1'), + hash_method="joblib_SHA1", + type="test_decorator.ObjectWithMethod", + id=id(obj), + details={'coefficient': 2}) + + expected_output = DataObject( + hash=joblib.hash(TEST_ARRAY+4, hash_name='sha1'), + hash_method="joblib_SHA1", + type="numpy.ndarray", id=id(res), + details={'shape': (3,), 'dtype': np.int64}) + + _check_function_execution( + actual=Provenance.history[0], + exp_function=FunctionInfo('ObjectWithMethod.static_method', + 'test_decorator', ''), + exp_input={'array': TEST_ARRAY_INFO}, + exp_params={'coefficient': 4}, + exp_output={0: expected_output}, + exp_arg_map=['array', 'coefficient'], + exp_kwarg_map=[], + exp_code_stmnt="res = obj.static_method(TEST_ARRAY, 4)", + exp_return_targets=['res'], + exp_order=1, + test_case=self) + def test_method_descriptor(self): activate(clear=True) ansig = neo.AnalogSignal(TEST_ARRAY, units='mV', sampling_rate=1*pq.Hz)