Skip to content

Commit

Permalink
[Fix] Error when tracking static method (#23)
Browse files Browse the repository at this point in the history
* Fixed error when a decorated method is static

* Changed unit tests as applying the decorator with syntactic sugar does not work with the static method analysis

* Unit test for tracking static methods

---------

Co-authored-by: Cristiano Köhler <[email protected]>
  • Loading branch information
kohlerca and Cristiano Köhler authored Sep 26, 2023
1 parent 08a6dc2 commit ca6ab1c
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 7 deletions.
18 changes: 18 additions & 0 deletions alpaca/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,18 @@ def _is_class_constructor(function_name):
names = function_name.split(".")
return len(names) == 2 and names[-1] == "__init__"

@staticmethod
def _is_static_method(function, function_name):
if type(function).__qualname__ == "method_descriptor":
# Ignore method descriptors
return False
name = function_name.rsplit('.', 1)[-1]
cls = inspect._findclass(function)
if cls is not None:
method = inspect.getattr_static(cls, name)
return isinstance(method, staticmethod)
return False

def _capture_code_and_function_provenance(self, lineno, function):

# 1. Capture Abstract Syntax Tree (AST) of the call to the
Expand Down Expand Up @@ -648,6 +660,12 @@ def wrapped(*args, **kwargs):

return function_output

# If the function is decorated with `staticmethod`, restore the
# decorator (otherwise `self` will be passed as first argument when
# calling the function)
if self._is_static_method(function, function.__qualname__):
return staticmethod(wrapped)

return wrapped

@classmethod
Expand Down
6 changes: 2 additions & 4 deletions alpaca/test/test_code_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,14 @@ def __init__(self, array):

# To test attribute calls
class ObjectWithMethod:
@Provenance(inputs=['self', 'array'])
def add_numbers(self, array):
return np.sum(array)

ObjectWithMethod.add_numbers = Provenance(inputs=['self', 'array'])(ObjectWithMethod.add_numbers)

class CustomObject:
@Provenance(inputs=['data'])
def __init__(self, data):
self.data = data

CustomObject.__init__ = Provenance(inputs=['data'])(CustomObject.__init__)

# Define some test functions to use different relationships

Expand Down
47 changes: 44 additions & 3 deletions alpaca/test/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,14 +121,13 @@ def comprehension_function(param):


class NonIterableContainerOutputObject(object):

@Provenance(inputs=[], container_output=0)
def __init__(self, start):
self._data = np.arange(start+1, start+4)

def __getitem__(self, item):
return self._data[item]

NonIterableContainerOutputObject.__init__ = \
Provenance(inputs=[], container_output=0)(NonIterableContainerOutputObject.__init__)

# Function to help verifying FunctionExecution tuples
def _check_function_execution(actual, exp_function, exp_input, exp_params,
Expand Down Expand Up @@ -1124,10 +1123,17 @@ def __init__(self, coefficient):
def process(self, array, param1, param2):
return array + self.coefficient

@staticmethod
def static_method(array, coefficient):
return array + coefficient


ObjectWithMethod.process = Provenance(inputs=['self', 'array'])(
ObjectWithMethod.process)

ObjectWithMethod.static_method = Provenance(inputs=['array'])(
ObjectWithMethod.static_method)

# Apply decorator to method that uses the descriptor protocol
neo.AnalogSignal.reshape = Provenance(inputs=[0])(neo.AnalogSignal.reshape)

Expand All @@ -1137,6 +1143,41 @@ def process(self, array, param1, param2):

class ProvenanceDecoratorClassMethodsTestCase(unittest.TestCase):

def test_static_method(self):
obj = ObjectWithMethod(2)
activate(clear=True)
res = obj.static_method(TEST_ARRAY, 4)
deactivate()

self.assertEqual(len(Provenance.history), 1)

obj_info = DataObject(
hash=joblib.hash(obj, hash_name='sha1'),
hash_method="joblib_SHA1",
type="test_decorator.ObjectWithMethod",
id=id(obj),
details={'coefficient': 2})

expected_output = DataObject(
hash=joblib.hash(TEST_ARRAY+4, hash_name='sha1'),
hash_method="joblib_SHA1",
type="numpy.ndarray", id=id(res),
details={'shape': (3,), 'dtype': np.int64})

_check_function_execution(
actual=Provenance.history[0],
exp_function=FunctionInfo('ObjectWithMethod.static_method',
'test_decorator', ''),
exp_input={'array': TEST_ARRAY_INFO},
exp_params={'coefficient': 4},
exp_output={0: expected_output},
exp_arg_map=['array', 'coefficient'],
exp_kwarg_map=[],
exp_code_stmnt="res = obj.static_method(TEST_ARRAY, 4)",
exp_return_targets=['res'],
exp_order=1,
test_case=self)

def test_method_descriptor(self):
activate(clear=True)
ansig = neo.AnalogSignal(TEST_ARRAY, units='mV', sampling_rate=1*pq.Hz)
Expand Down

0 comments on commit ca6ab1c

Please sign in to comment.