Skip to content

Commit

Permalink
Unit test for tracking static methods
Browse files Browse the repository at this point in the history
  • Loading branch information
Cristiano Köhler committed Sep 26, 2023
1 parent adf8a9f commit bbd8fc9
Showing 1 changed file with 42 additions and 0 deletions.
42 changes: 42 additions & 0 deletions alpaca/test/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1123,10 +1123,17 @@ def __init__(self, coefficient):
def process(self, array, param1, param2):
return array + self.coefficient

@staticmethod
def static_method(array, coefficient):
return array + coefficient


ObjectWithMethod.process = Provenance(inputs=['self', 'array'])(
ObjectWithMethod.process)

ObjectWithMethod.static_method = Provenance(inputs=['array'])(
ObjectWithMethod.static_method)

# Apply decorator to method that uses the descriptor protocol
neo.AnalogSignal.reshape = Provenance(inputs=[0])(neo.AnalogSignal.reshape)

Expand All @@ -1136,6 +1143,41 @@ def process(self, array, param1, param2):

class ProvenanceDecoratorClassMethodsTestCase(unittest.TestCase):

def test_static_method(self):
obj = ObjectWithMethod(2)
activate(clear=True)
res = obj.static_method(TEST_ARRAY, 4)
deactivate()

self.assertEqual(len(Provenance.history), 1)

obj_info = DataObject(
hash=joblib.hash(obj, hash_name='sha1'),
hash_method="joblib_SHA1",
type="test_decorator.ObjectWithMethod",
id=id(obj),
details={'coefficient': 2})

expected_output = DataObject(
hash=joblib.hash(TEST_ARRAY+4, hash_name='sha1'),
hash_method="joblib_SHA1",
type="numpy.ndarray", id=id(res),
details={'shape': (3,), 'dtype': np.int64})

_check_function_execution(
actual=Provenance.history[0],
exp_function=FunctionInfo('ObjectWithMethod.static_method',
'test_decorator', ''),
exp_input={'array': TEST_ARRAY_INFO},
exp_params={'coefficient': 4},
exp_output={0: expected_output},
exp_arg_map=['array', 'coefficient'],
exp_kwarg_map=[],
exp_code_stmnt="res = obj.static_method(TEST_ARRAY, 4)",
exp_return_targets=['res'],
exp_order=1,
test_case=self)

def test_method_descriptor(self):
activate(clear=True)
ansig = neo.AnalogSignal(TEST_ARRAY, units='mV', sampling_rate=1*pq.Hz)
Expand Down

0 comments on commit bbd8fc9

Please sign in to comment.