From 01906a2e8a433a8bdd8dfe5af2fa9f8998460186 Mon Sep 17 00:00:00 2001 From: Gertjan van Zwieten Date: Wed, 13 Jan 2021 17:07:55 +0100 Subject: [PATCH] add cached __index__ method to constant int scalar --- nutils/evaluable.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/nutils/evaluable.py b/nutils/evaluable.py index f66dd46d1..894bfa43f 100644 --- a/nutils/evaluable.py +++ b/nutils/evaluable.py @@ -233,7 +233,7 @@ def wrapped(target, *funcargs, **funckwargs): class Evaluable(types.Singleton): 'Base class' - __slots__ = '__args', + __slots__ = '__args' __cache__ = 'dependencies', 'arguments', 'ordereddeps', 'dependencytree', 'optimized_for_numpy', '_loop_concatenate_deps' @types.apply_annotations @@ -807,7 +807,7 @@ class Array(Evaluable, metaclass=_ArrayMeta): The dtype of the array elements. ''' - __slots__ = '_axes', 'dtype' + __slots__ = '_axes', 'dtype', '__index' __cache__ = 'blocks', 'assparse', '_assparse', '_as_canonical_length' __array_priority__ = 1. # http://stackoverflow.com/questions/7042496/numpy-coercion-problem-for-left-sided-binary-operator/7057530#7057530 @@ -863,6 +863,15 @@ def __iter__(self): raise TypeError('iteration over a 0-d array') return (self[i,...] for i in range(self.shape[0])) + def __index__(self): + try: + index = self.__index + except AttributeError: + if self.ndim or self.dtype not in (int, bool) or not self.isconstant: + raise TypeError('cannot convert {!r} to int'.format(self)) + index = self.__index = int(self.simplified.eval()) + return index + size = property(lambda self: util.product(self.shape) if self.ndim else 1) T = property(lambda self: transpose(self)) @@ -877,6 +886,7 @@ def __iter__(self): __pow__ = power __abs__ = lambda self: abs(self) __mod__ = lambda self, other: mod(self, other) + __int__ = __index__ __str__ = __repr__ = lambda self: '{}.{}<{}>'.format(type(self).__module__, type(self).__name__, self._shape_str(form=str)) _shape_str = lambda self, form: '{}:{}'.format(self.dtype.__name__[0] if hasattr(self, 'dtype') else '?', ','.join(map(form, self._axes)) if hasattr(self, '_axes') else '?')