From 09861e6f18df060d198bc40cac2488e01ac7ab70 Mon Sep 17 00:00:00 2001 From: Gertjan van Zwieten Date: Fri, 20 Oct 2023 13:38:14 +0200 Subject: [PATCH] fix picklability of sliced function array This patch changed the implementation of _takeslice such that it wraps evaluable.Range directly rather than a lambda function, which caused the resulting array to not be picklable. A unit test is added to safeguard against similar mistakes in future. --- nutils/function.py | 9 ++------- tests/test_function.py | 7 +++++++ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/nutils/function.py b/nutils/function.py index a853167ff..0e279369c 100644 --- a/nutils/function.py +++ b/nutils/function.py @@ -2733,12 +2733,6 @@ def get(__array: IntoArray, __axis: int, __index: IntoArray) -> Array: return numpy.take(Array.cast(__array), Array.cast(__index, dtype=int, ndim=0), __axis) -def _range(__length: int, __offset: int) -> Array: - length = Array.cast(__length, dtype=int, ndim=0) - offset = Array.cast(__offset, dtype=int, ndim=0) - return _Wrapper(lambda l, o: evaluable.Range(l) + o, _WithoutPoints(length), _WithoutPoints(offset), shape=(__length,), dtype=int) - - def _takeslice(__array: IntoArray, __s: slice, __axis: int) -> Array: array = Array.cast(__array) s = __s @@ -2749,7 +2743,8 @@ def _takeslice(__array: IntoArray, __s: slice, __axis: int) -> Array: stop = n if s.stop is None else s.stop if s.stop >= 0 else s.stop + n if start == 0 and stop == n: return array - index = _range(stop-start, start) + length = stop - start + index = _Wrapper(evaluable.Range, _WithoutPoints(_Constant(length)), shape=(length,), dtype=int) + start elif isinstance(n, numbers.Integral): index = Array.cast(numpy.arange(*s.indices(int(n)))) else: diff --git a/tests/test_function.py b/tests/test_function.py index 7f7480b0d..5c26206bf 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -6,6 +6,7 @@ import warnings as _builtin_warnings import functools import fractions +import pickle class Array(TestCase): @@ -171,6 +172,12 @@ def test_lower_eval(self): desired = self.n_op(*self.args) self.assertArrayAlmostEqual(actual, desired, decimal=15) + def test_pickle(self): + f = self.op(*self.args) + s = pickle.dumps(f) + f_ = pickle.loads(s) + self.assertEqual(f.as_evaluable_array, f_.as_evaluable_array) + def generate(*shape, real, imag, zero, negative): 'generate array values that cover certain numerical classes'