Skip to content

Commit

Permalink
add cached __index__ method to constant int scalar
Browse files Browse the repository at this point in the history
  • Loading branch information
gertjanvanzwieten committed Mar 11, 2021
1 parent f6f51d8 commit 01906a2
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions nutils/evaluable.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def wrapped(target, *funcargs, **funckwargs):
class Evaluable(types.Singleton):
'Base class'

__slots__ = '__args',
__slots__ = '__args'
__cache__ = 'dependencies', 'arguments', 'ordereddeps', 'dependencytree', 'optimized_for_numpy', '_loop_concatenate_deps'

@types.apply_annotations
Expand Down Expand Up @@ -807,7 +807,7 @@ class Array(Evaluable, metaclass=_ArrayMeta):
The dtype of the array elements.
'''

__slots__ = '_axes', 'dtype'
__slots__ = '_axes', 'dtype', '__index'
__cache__ = 'blocks', 'assparse', '_assparse', '_as_canonical_length'

__array_priority__ = 1. # http://stackoverflow.com/questions/7042496/numpy-coercion-problem-for-left-sided-binary-operator/7057530#7057530
Expand Down Expand Up @@ -863,6 +863,15 @@ def __iter__(self):
raise TypeError('iteration over a 0-d array')
return (self[i,...] for i in range(self.shape[0]))

def __index__(self):
try:
index = self.__index
except AttributeError:
if self.ndim or self.dtype not in (int, bool) or not self.isconstant:
raise TypeError('cannot convert {!r} to int'.format(self))
index = self.__index = int(self.simplified.eval())
return index

size = property(lambda self: util.product(self.shape) if self.ndim else 1)
T = property(lambda self: transpose(self))

Expand All @@ -877,6 +886,7 @@ def __iter__(self):
__pow__ = power
__abs__ = lambda self: abs(self)
__mod__ = lambda self, other: mod(self, other)
__int__ = __index__
__str__ = __repr__ = lambda self: '{}.{}<{}>'.format(type(self).__module__, type(self).__name__, self._shape_str(form=str))
_shape_str = lambda self, form: '{}:{}'.format(self.dtype.__name__[0] if hasattr(self, 'dtype') else '?', ','.join(map(form, self._axes)) if hasattr(self, '_axes') else '?')

Expand Down

0 comments on commit 01906a2

Please sign in to comment.