From a9d7880d90bd3c9655ab66ec54eafc4b6adf1577 Mon Sep 17 00:00:00 2001 From: Gertjan van Zwieten Date: Fri, 22 Oct 2021 14:22:23 +0200 Subject: [PATCH] fix hashability of Custom functions This patch changes evalf and partial_derivative of function.Custom and unit test subclasses to classmethod or staticmethod to avoid hashability issues. It also updates the documentation to make this change mandatory. --- nutils/function.py | 19 +++++++++++++++---- tests/test_function.py | 17 ++++++++++------- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/nutils/function.py b/nutils/function.py index b1ea6cce9..55ec94546 100644 --- a/nutils/function.py +++ b/nutils/function.py @@ -494,9 +494,18 @@ class Custom(Array): pointwise axes and the shape of the result passed to :class:`Custom` should not include the pointwise axes. + For internal reasons, both ``evalf`` and ``partial_derivative`` must be + decorated as ``classmethod`` or ``staticmethod``, meaning that they will not + receive a reference to ``self`` when called. Instead, all relevant data + should be passed to ``evalf`` via the constructor argument ``args``. The + constructor will automatically distinguish between Array and non-Array + arguments, and pass the latter on to ``evalf`` unchanged. The + ``partial_derivative`` will not be called for those arguments. + The lowered array does not have a Nutils hash by default. If this is desired, the methods :meth:`evalf` and :meth:`partial_derivative` can be decorated - with :func:`nutils.types.hashable_function`. + with :func:`nutils.types.hashable_function` in addition to ``classmethod`` or + ``staticmethod``. Parameters ---------- @@ -620,7 +629,8 @@ def lower(self, points_shape: _PointsShape, transform_chains: _TransformChainsMa coordinates = {space: evaluable.Transpose.to_end(evaluable.appendaxes(coords, add_points_shape), coords.ndim-1) for space, coords in coordinates.items()} return _CustomEvaluable(type(self).__name__, self.evalf, self.partial_derivative, args, self.shape[self._npointwise:], self.dtype, self.spaces, types.frozendict(self.arguments), points_shape, tuple(transform_chains.items()), tuple(coordinates.items())) - def evalf(self, *args: Any) -> numpy.ndarray: + @classmethod + def evalf(cls, *args: Any) -> numpy.ndarray: '''Evaluate this function for the given evaluated arguments. This function is called with arguments that correspond to the arguments @@ -654,7 +664,8 @@ def evalf(self, *args: Any) -> numpy.ndarray: raise NotImplementedError # pragma: nocover - def partial_derivative(self, iarg: int, *args: Any) -> IntoArray: + @classmethod + def partial_derivative(cls, iarg: int, *args: Any) -> IntoArray: '''Return the partial derivative of this function to :class:`Custom` constructor argument number ``iarg``. This method is only called for those arguments that are instances of @@ -678,7 +689,7 @@ def partial_derivative(self, iarg: int, *args: Any) -> IntoArray: The partial derivative of this function to the given argument. ''' - raise NotImplementedError('The partial derivative of {} to argument {} (counting from 0) is not defined.'.format(type(self).__name__, iarg)) # pragma: nocover + raise NotImplementedError('The partial derivative of {} to argument {} (counting from 0) is not defined.'.format(cls.__name__, iarg)) # pragma: nocover class _CustomEvaluable(evaluable.Array): diff --git a/tests/test_function.py b/tests/test_function.py index ff299f0ae..79ce97f0c 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -356,7 +356,8 @@ class Func(function.Custom): def __init__(self): super().__init__(args=(), shape=(3,), dtype=int) - def evalf(self): + @staticmethod + def evalf(): return numpy.array([1,2,3])[None] self.assertEvalAlmostEqual(Func(), function.Array.cast([1,2,3])) @@ -371,10 +372,12 @@ def __init__(self, offset, base1, exp1, base2, exp2): assert base1.shape == base2.shape super().__init__(args=(offset, base1, exp1.__index__(), base2, exp2.__index__()), shape=base1.shape, dtype=float) - def evalf(self, offset, base1, exp1, base2, exp2): + @staticmethod + def evalf(offset, base1, exp1, base2, exp2): return offset + base1**exp1 + base2**exp2 - def partial_derivative(self, iarg, offset, base1, exp1, base2, exp2): + @staticmethod + def partial_derivative(iarg, offset, base1, exp1, base2, exp2): if iarg == 1: if exp1 == 0: return function.zeros(base1.shape + base1.shape) @@ -403,21 +406,21 @@ def test_deduplication(self): class A(function.Custom): @staticmethod - def evalf(self): + def evalf(): pass @staticmethod - def partial_derivative(self, iarg): + def partial_derivative(iarg): pass class B(function.Custom): @staticmethod - def evalf(self): + def evalf(): pass @staticmethod - def partial_derivative(self, iarg): + def partial_derivative(iarg): pass a = A(args=(function.Argument('a', (2,3)),), shape=(), dtype=float).as_evaluable_array