Skip to content

Commit

Permalink
restore simplification of LoopSum shape args
Browse files Browse the repository at this point in the history
  • Loading branch information
gertjanvanzwieten committed Mar 11, 2021
1 parent 55dff85 commit 3bd57c5
Showing 1 changed file with 41 additions and 27 deletions.
68 changes: 41 additions & 27 deletions nutils/evaluable.py
Original file line number Diff line number Diff line change
Expand Up @@ -3306,46 +3306,60 @@ def _simplified(self):

class LoopSum(Array):

__cache__ = '_serialized'

def prepare_funcdata(arg):
# separate shape from array to make it simplifiable (annotations are
# treated as preprocessor, which means the processed value is returned by
# self.__reduce__)
if isinstance(arg, tuple):
return arg
arg = asarray(arg)
return (arg, *arg.shape)

@types.apply_annotations
def __init__(self, func: asarray, index:types.strict[Argument], length: asarray):
def __init__(self, funcdata:prepare_funcdata, index:types.strict[Argument], length:asindex):
shape = Tuple(funcdata[1:])
if index.dtype != int or index.ndim != 0:
raise ValueError('expected an index with dtype int and dimension zero but got {}'.format(index))
if any(index in n.arguments for n in func.shape):
if index in shape.arguments:
raise ValueError('the shape of the function must not depend on the index')
if index in length.arguments:
raise ValueError('the length of the loop must not depend on the index')

self.func = func
self.func = funcdata[0]
self.index = index
self.length = length
self._invariants = []
self._dependencies = []
_populate_dependencies_sans_invariants(self.func, index, self._invariants, self._dependencies, set())
assert (self._dependencies or self._invariants or [index])[-1] == self.func
axes = tuple(Axis(axis.length) if isinstance(axis, Sparse) else axis for axis in self.func._axes)
super().__init__(args=(shape, length, *self._invariants), shape=axes, dtype=self.func.dtype)

invariants = [*func.shape, length]
dependencies = []
_populate_dependencies_sans_invariants(func, index, invariants, dependencies, set())
indices = {d: i for i, d in enumerate(itertools.chain(invariants, [index], dependencies))}
self._result_index = indices[func]
self._serialized = tuple((dep, tuple(map(indices.__getitem__, dep._Evaluable__args))) for dep in dependencies)

axes = tuple(Axis(axis.length) if isinstance(axis, Sparse) else axis for axis in func._axes)
super().__init__(args=invariants, shape=axes, dtype=func.dtype)
@property
def _serialized(self):
indices = {d: i for i, d in enumerate(itertools.chain([self.index], self._invariants, self._dependencies))}
return tuple((dep, tuple(map(indices.__getitem__, dep._Evaluable__args))) for dep in self._dependencies)

def evalf(self, *args):
result = numpy.zeros(tuple(map(int, args[:self.ndim])), self.dtype)
for index in range(int(args[self.ndim])):
values = list(args)
values.append(numpy.array(index))
values.extend(op.evalf(*[values[i] for i in indices]) for op, indices in self._serialized)
result += values[self._result_index]
def evalf(self, shape, length, *args):
serialized = self._serialized
result = numpy.zeros(shape, self.dtype)
for index in range(length):
values = [numpy.array(index)]
values.extend(args)
values.extend(op.evalf(*[values[i] for i in indices]) for op, indices in serialized)
result += values[-1]
return result

def evalf_withtimes(self, times, *args):
def evalf_withtimes(self, times, shape, length, *args):
serialized = self._serialized
times[self] = subtimes = collections.defaultdict(_Stats)
result = numpy.zeros(tuple(map(int, args[:self.ndim])), self.dtype)
for index in range(int(args[self.ndim])):
values = list(args)
values.append(numpy.array(index))
values.extend(op.evalf_withtimes(subtimes, *[values[i] for i in indices]) for op, indices in self._serialized)
result += values[self._result_index]
result = numpy.zeros(shape, self.dtype)
for index in range(length):
values = [numpy.array(index)]
values.extend(args)
values.extend(op.evalf_withtimes(subtimes, *[values[i] for i in indices]) for op, indices in serialized)
result += values[-1]
return result

def _derivative(self, var, seen):
Expand Down

0 comments on commit 3bd57c5

Please sign in to comment.