From d1916e0c31df5578b2c897726ccb9f3da3cf31d0 Mon Sep 17 00:00:00 2001 From: Joost van Zwieten Date: Thu, 18 Apr 2024 16:12:10 +0200 Subject: [PATCH] wip: cache const eval --- nutils/evaluable.py | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/nutils/evaluable.py b/nutils/evaluable.py index 1e71fcb0b..d00f83c6b 100644 --- a/nutils/evaluable.py +++ b/nutils/evaluable.py @@ -89,7 +89,7 @@ def _equals_scalar_constant(arg: 'Array', value: Dtype): assert isinstance(arg, Array) and arg.ndim == 0, f'arg={arg!r}' assert arg.dtype == type(value), f'arg.dtype={arg.dtype}, type(value)={type(value)}' if arg.isconstant and not arg._loops: - return arg.eval() == value + return arg._const_value == value def _equals_simplified(arg1: 'Array', arg2: 'Array'): @@ -105,7 +105,7 @@ def _equals_simplified(arg1: 'Array', arg2: 'Array'): if arg1.arguments != arg2.arguments: return False if arg1.isconstant: # implies arg2.isconstant - return numpy.all(arg1.eval() == arg2.eval()) + return numpy.all(arg1._const_value == arg2._const_value) def equalshape(N: typing.Tuple['Array', ...], M: typing.Tuple['Array', ...]): @@ -176,6 +176,15 @@ def asciitree(self, richoutput=False): def __str__(self): return self.__class__.__name__ + @cached_property + def _const_value(self): + args = tuple(map(_pyast.Variable('args').get_item, map(_pyast.LiteralInt, range(len(self.__args) + 1)))) + try: + expr = self._compile_expression(*args) + except NotImplementedError: + return self.eval() + return eval(expr.py_expr, dict(numpy=numpy, numeric=numeric, parallel=parallel, poly=poly, collections=collections, multiprocessing=multiprocessing, args=(self, *(arg._const_value for arg in self.__args)))) + @property def eval(self): '''Evaluate function on a specified element, point set.''' @@ -571,7 +580,7 @@ def __index__(self): except AttributeError: if self.ndim or self.dtype not in (int, bool) or not self.isconstant: raise TypeError('cannot convert {!r} to int'.format(self)) - index = self.__index = int(self.simplified.eval()) + index = self.__index = int(self.simplified._const_value) return index T = property(lambda self: transpose(self, tuple(range(self.ndim-1, -1, -1)))) @@ -941,7 +950,7 @@ def _takediag(self, axis1, axis2): def _take(self, index, axis): if index.isconstant: - index_ = index.eval() + index_ = index._const_value return constant(self.value.take(index_, axis)) def _power(self, n): @@ -1605,7 +1614,7 @@ def _inflations(self): and self.shape[axis].isconstant and all(dofmap.isconstant and not dofmap._loops for dofmap in dofmaps)): mask = numpy.zeros(int(self.shape[axis]), dtype=bool) for dofmap in dofmaps: - mask[dofmap.eval()] = True + mask[dofmap._const_value] = True if mask.all(): # axis adds up to dense continue inflations.append((axis, types.frozendict((dofmap, util.sum(parts[dofmap] for parts in (parts1, parts2) if dofmap in parts)) for dofmap in dofmaps))) @@ -2017,7 +2026,7 @@ def _compile_expression(self, py_self, func, power): def _derivative(self, var, seen): if self.power.isconstant: - p = self.power.eval() + p = self.power._const_value return einsum('A,A,AB->AB', constant(p), power(self.func, p - (p != 0)), derivative(self.func, var, seen)) if self.dtype == complex: raise NotImplementedError('The complex derivative is not implemented.') @@ -2976,7 +2985,7 @@ def _unravel(self, axis, shape): return Inflate(unravel(self.func, axis, shape), self.dofmap, self.length) def _sign(self): - if self.dofmap.isconstant and _isunique(self.dofmap.eval()): + if self.dofmap.isconstant and _isunique(self.dofmap._const_value): return Inflate(Sign(self.func), self.dofmap, self.length) @cached_property @@ -3043,7 +3052,7 @@ def __init__(self, inflateidx, takeidx): def _simplified(self): if self.isconstant: - return Tuple(tuple(map(constant, self.eval()))) + return Tuple(tuple(map(constant, self._const_value))) def __iter__(self): shape = ArrayFromTuple(self, index=2, shape=(), dtype=int), @@ -3210,7 +3219,7 @@ def _compile_expression(self, py_self, where): def _simplified(self): if self.isconstant: - return constant(self.eval()) + return constant(self._const_value) def _intbounds_impl(self): return 0, self.shape[0]._intbounds[1] @@ -4031,7 +4040,7 @@ def _simplified(self): if isinstance(lower_length, int) and lower_length == upper_length and -lower_length <= lower_index and upper_index < 0: return self.index + lower_length if self.length.isconstant and self.index.isconstant: - return constant(self.eval()) + return constant(self._const_value) def _intbounds_impl(self): lower_length, upper_length = self.length._intbounds @@ -4964,7 +4973,7 @@ def take(arg: Array, index: Array, axis: int): assert _equals_simplified(index.shape[0], length) index = Find(index) elif index.isconstant: - index_ = index.eval() + index_ = index._const_value ineg = numpy.less(index_, 0) if not length.isconstant: if ineg.any():