Skip to content

Commit

Permalink
fix hashability of Custom functions
Browse files Browse the repository at this point in the history
This patch changes evalf and partial_derivative of function.Custom and unit
test subclasses to classmethod or staticmethod to avoid hashability issues.
It also updates the documentation to make this change mandatory.
  • Loading branch information
gertjanvanzwieten committed Nov 5, 2021
1 parent 0018abf commit a9d7880
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 11 deletions.
19 changes: 15 additions & 4 deletions nutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,9 +494,18 @@ class Custom(Array):
pointwise axes and the shape of the result passed to :class:`Custom` should
not include the pointwise axes.
For internal reasons, both ``evalf`` and ``partial_derivative`` must be
decorated as ``classmethod`` or ``staticmethod``, meaning that they will not
receive a reference to ``self`` when called. Instead, all relevant data
should be passed to ``evalf`` via the constructor argument ``args``. The
constructor will automatically distinguish between Array and non-Array
arguments, and pass the latter on to ``evalf`` unchanged. The
``partial_derivative`` will not be called for those arguments.
The lowered array does not have a Nutils hash by default. If this is desired,
the methods :meth:`evalf` and :meth:`partial_derivative` can be decorated
with :func:`nutils.types.hashable_function`.
with :func:`nutils.types.hashable_function` in addition to ``classmethod`` or
``staticmethod``.
Parameters
----------
Expand Down Expand Up @@ -620,7 +629,8 @@ def lower(self, points_shape: _PointsShape, transform_chains: _TransformChainsMa
coordinates = {space: evaluable.Transpose.to_end(evaluable.appendaxes(coords, add_points_shape), coords.ndim-1) for space, coords in coordinates.items()}
return _CustomEvaluable(type(self).__name__, self.evalf, self.partial_derivative, args, self.shape[self._npointwise:], self.dtype, self.spaces, types.frozendict(self.arguments), points_shape, tuple(transform_chains.items()), tuple(coordinates.items()))

def evalf(self, *args: Any) -> numpy.ndarray:
@classmethod
def evalf(cls, *args: Any) -> numpy.ndarray:
'''Evaluate this function for the given evaluated arguments.
This function is called with arguments that correspond to the arguments
Expand Down Expand Up @@ -654,7 +664,8 @@ def evalf(self, *args: Any) -> numpy.ndarray:

raise NotImplementedError # pragma: nocover

def partial_derivative(self, iarg: int, *args: Any) -> IntoArray:
@classmethod
def partial_derivative(cls, iarg: int, *args: Any) -> IntoArray:
'''Return the partial derivative of this function to :class:`Custom` constructor argument number ``iarg``.
This method is only called for those arguments that are instances of
Expand All @@ -678,7 +689,7 @@ def partial_derivative(self, iarg: int, *args: Any) -> IntoArray:
The partial derivative of this function to the given argument.
'''

raise NotImplementedError('The partial derivative of {} to argument {} (counting from 0) is not defined.'.format(type(self).__name__, iarg)) # pragma: nocover
raise NotImplementedError('The partial derivative of {} to argument {} (counting from 0) is not defined.'.format(cls.__name__, iarg)) # pragma: nocover

class _CustomEvaluable(evaluable.Array):

Expand Down
17 changes: 10 additions & 7 deletions tests/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,8 @@ class Func(function.Custom):
def __init__(self):
super().__init__(args=(), shape=(3,), dtype=int)

def evalf(self):
@staticmethod
def evalf():
return numpy.array([1,2,3])[None]

self.assertEvalAlmostEqual(Func(), function.Array.cast([1,2,3]))
Expand All @@ -371,10 +372,12 @@ def __init__(self, offset, base1, exp1, base2, exp2):
assert base1.shape == base2.shape
super().__init__(args=(offset, base1, exp1.__index__(), base2, exp2.__index__()), shape=base1.shape, dtype=float)

def evalf(self, offset, base1, exp1, base2, exp2):
@staticmethod
def evalf(offset, base1, exp1, base2, exp2):
return offset + base1**exp1 + base2**exp2

def partial_derivative(self, iarg, offset, base1, exp1, base2, exp2):
@staticmethod
def partial_derivative(iarg, offset, base1, exp1, base2, exp2):
if iarg == 1:
if exp1 == 0:
return function.zeros(base1.shape + base1.shape)
Expand Down Expand Up @@ -403,21 +406,21 @@ def test_deduplication(self):
class A(function.Custom):

@staticmethod
def evalf(self):
def evalf():
pass

@staticmethod
def partial_derivative(self, iarg):
def partial_derivative(iarg):
pass

class B(function.Custom):

@staticmethod
def evalf(self):
def evalf():
pass

@staticmethod
def partial_derivative(self, iarg):
def partial_derivative(iarg):
pass

a = A(args=(function.Argument('a', (2,3)),), shape=(), dtype=float).as_evaluable_array
Expand Down

0 comments on commit a9d7880

Please sign in to comment.