Skip to content

Commit

Permalink
fix picklability of sliced function array
Browse files Browse the repository at this point in the history
This patch changed the implementation of _takeslice such that it wraps
evaluable.Range directly rather than a lambda function, which caused the
resulting array to not be picklable. A unit test is added to safeguard against
similar mistakes in future.
  • Loading branch information
gertjanvanzwieten committed Nov 29, 2023
1 parent d024e38 commit 3929717
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
9 changes: 2 additions & 7 deletions nutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -2733,12 +2733,6 @@ def get(__array: IntoArray, __axis: int, __index: IntoArray) -> Array:
return numpy.take(Array.cast(__array), Array.cast(__index, dtype=int, ndim=0), __axis)


def _range(__length: int, __offset: int) -> Array:
length = Array.cast(__length, dtype=int, ndim=0)
offset = Array.cast(__offset, dtype=int, ndim=0)
return _Wrapper(lambda l, o: evaluable.Range(l) + o, _WithoutPoints(length), _WithoutPoints(offset), shape=(__length,), dtype=int)


def _takeslice(__array: IntoArray, __s: slice, __axis: int) -> Array:
array = Array.cast(__array)
s = __s
Expand All @@ -2749,7 +2743,8 @@ def _takeslice(__array: IntoArray, __s: slice, __axis: int) -> Array:
stop = n if s.stop is None else s.stop if s.stop >= 0 else s.stop + n
if start == 0 and stop == n:
return array
index = _range(stop-start, start)
length = stop - start
index = _Wrapper(evaluable.Range, _WithoutPoints(_Constant(length)), shape=(length,), dtype=int) + start
elif isinstance(n, numbers.Integral):
index = Array.cast(numpy.arange(*s.indices(int(n))))
else:
Expand Down
7 changes: 7 additions & 0 deletions tests/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import warnings as _builtin_warnings
import functools
import fractions
import pickle


class Array(TestCase):
Expand Down Expand Up @@ -171,6 +172,12 @@ def test_lower_eval(self):
desired = self.n_op(*self.args)
self.assertArrayAlmostEqual(actual, desired, decimal=15)

def test_pickle(self):
f = self.op(*self.args)
s = pickle.dumps(f)
f_ = pickle.loads(s)
self.assertEqual(f.as_evaluable_array, f_.as_evaluable_array)


def generate(*shape, real, imag, zero, negative):
'generate array values that cover certain numerical classes'
Expand Down

0 comments on commit 3929717

Please sign in to comment.