diff --git a/nutils/evaluable.py b/nutils/evaluable.py index 14e47b3fd..8ee3090f0 100644 --- a/nutils/evaluable.py +++ b/nutils/evaluable.py @@ -465,7 +465,7 @@ def _combine_loop_concatenates(self, outer_exclude): # start, stop and the concatenation length) are formed by # `loop_concatenate`-ing `func.shape[-1]`. If the shape is constant, # this can be simplified to a `Range`. - data = Tuple((Tuple((lc.func, lc.start, lc.stop, lc.shape[-1])) for lc in lcs)) + data = Tuple((Tuple(lc.funcdata) for lc in lcs)) # Combine `LoopConcatenate` instances in `data` excluding # `outer_exclude` and those that will be processed in a subsequent loop # (the remainder of `exclude`). The latter consists of loops that are @@ -3435,26 +3435,20 @@ def _simplified(self): class LoopConcatenate(Array): @types.apply_annotations - def __init__(self, func:asarray, start:asindex, stop:asindex, cc_length:asindex, index:types.strict[Argument], length:asindex): - if index.dtype != int or index.ndim != 0: - raise ValueError('expected an index with dtype int and dimension zero but got {}'.format(index)) - if not func.ndim: + def __init__(self, funcdata:asarrays, index:types.strict[Argument], length:asindex): + self.funcdata = funcdata + self.func, self.start, stop, *shape = funcdata + if not self.func.ndim: raise ValueError('expected an array with at least one axis') - if any(index in n.arguments for n in func.shape[:-1]): + if any(index in n.arguments for n in shape): raise ValueError('the shape of the function must not depend on the index') if index in length.arguments: raise ValueError('the length of the loop must not depend on the index') - if index in cc_length.arguments: - raise ValueError('the length of the concatenation axis must not depend on the index') - - self.func = func self.index = index self.length = length - self.start = start - self.stop = stop - self._lcc = LoopConcatenateCombined(((func, start, stop, cc_length),), index, length) - axes = (*(Axis(axis.length) if isinstance(axis, Sparse) else axis for axis in func._axes[:-1]), Axis(cc_length)) - super().__init__(args=[self._lcc], shape=axes, dtype=func.dtype) + self._lcc = LoopConcatenateCombined((self.funcdata,), index, length) + axes = (*(Axis(axis.length) if isinstance(axis, Sparse) else axis for axis in self.func._axes[:-1]), Axis(shape[-1])) + super().__init__(args=[self._lcc], shape=axes, dtype=self.func.dtype) def evalf(self, arg): return arg[0] @@ -3510,73 +3504,58 @@ def _loop_concatenate_deps(self): class LoopConcatenateCombined(Evaluable): + __cache__ = '_serialized' + @types.apply_annotations - def __init__(self, funcdata:types.tuple[asarrays], index:types.strict[Argument], length:asindex): + def __init__(self, funcdatas:types.tuple[asarrays], index:types.strict[Argument], length:asindex): + self._funcdatas = funcdatas if index.dtype != int or index.ndim != 0: raise ValueError('expected an index with dtype int and dimension zero but got {}'.format(index)) - if any(not func.ndim for func, start, stop, cc_length in funcdata): + self._funcs = tuple(func for func, start, stop, *shape in funcdatas) + if any(not func.ndim for func in self._funcs): raise ValueError('expected an array with at least one axis') - if any(index in n.arguments for func, start, stop, cc_length in funcdata for n in (*func.shape[:-1], cc_length)): + shapes = [Tuple(shape) for func, start, stop, *shape in funcdatas] + if any(index in shape.arguments for shape in shapes): raise ValueError('the shape of the function must not depend on the index') if index in length.arguments: raise ValueError('the length of the loop must not depend on the index') - - self._funcs, self._starts, self._stops, self._cc_lengths = zip(*funcdata) self._index = index self._length = length + self._invariants = [] + self._dependencies = [] + result = Tuple([Tuple([start, stop, func]) for func, start, stop, *shape in funcdatas]) + _populate_dependencies_sans_invariants(result, index, self._invariants, self._dependencies, set()) + assert (self._dependencies or self._invariants)[-1] == result + super().__init__(args=(Tuple(shapes), length, *self._invariants)) - invariants = [] - for func, start, stop, cc_length in funcdata: - invariants.extend(func.shape[:-1]) - invariants.append(cc_length) - invariants.append(length) - - dependencies = [] - cache = set() - for obj in itertools.chain(self._starts, self._stops, self._funcs): - _populate_dependencies_sans_invariants(obj, index, invariants, dependencies, cache) - - indices = {d: i for i, d in enumerate(itertools.chain(invariants, [index], dependencies))} - self._start_indices = tuple(map(indices.__getitem__, self._starts)) - self._stop_indices = tuple(map(indices.__getitem__, self._stops)) - self._result_indices = tuple(map(indices.__getitem__, self._funcs)) - self._serialized = tuple((dep, tuple(map(indices.__getitem__, dep._Evaluable__args))) for dep in dependencies) - self._invariants = tuple(invariants) - - super().__init__(args=invariants) + @property + def _serialized(self): + indices = {d: i for i, d in enumerate(itertools.chain([self._index], self._invariants, self._dependencies))} + return tuple((dep, tuple(map(indices.__getitem__, dep._Evaluable__args))) for dep in self._dependencies) - def evalf(self, *args): - i = 0 - results = [] - for func in self._funcs: - results.append(parallel.shempty(tuple(map(int, args[i:i+func.ndim])), dtype=func.dtype)) - i += func.ndim - length = int(args[i]) - with parallel.ctxrange('loop', length) as indices: + def evalf(self, shapes, length, *args): + serialized = self._serialized + results = [parallel.shempty(tuple(map(int, shape)), dtype=func.dtype) for func, shape in zip(self._funcs, shapes)] + with parallel.ctxrange('loop', int(length)) as indices: for index in indices: - values = list(args) - values.append(numpy.array(index)) - values.extend(op.evalf(*[values[i] for i in indices]) for op, indices in self._serialized) - for result, result_id, start_id, stop_id in zip(results, self._result_indices, self._start_indices, self._stop_indices): - result[...,int(values[start_id]):int(values[stop_id])] = values[result_id] + values = [numpy.array(index)] + values.extend(args) + values.extend(op.evalf(*[values[i] for i in indices]) for op, indices in serialized) + for result, (start, stop, block) in zip(results, values[-1]): + result[...,start:stop] = block return tuple(results) - def evalf_withtimes(self, times, *args): + def evalf_withtimes(self, times, shapes, length, *args): + serialized = self._serialized times[self] = subtimes = collections.defaultdict(_Stats) - i = 0 - results = [] - for func in self._funcs: - with subtimes['concat', func]: - results.append(numpy.empty(tuple(map(int, args[i:i+func.ndim])), dtype=func.dtype)) - i += func.ndim - length = int(args[i]) + results = [parallel.shempty(tuple(map(int, shape)), dtype=func.dtype) for func, shape in zip(self._funcs, shapes)] for index in range(length): - values = list(args) - values.append(numpy.array(index)) - values.extend(op.evalf_withtimes(subtimes, *[values[i] for i in indices]) for op, indices in self._serialized) - for func, result, result_id, start_id, stop_id in zip(self._funcs, results, self._result_indices, self._start_indices, self._stop_indices): + values = [numpy.array(index)] + values.extend(args) + values.extend(op.evalf_withtimes(subtimes, *[values[i] for i in indices]) for op, indices in serialized) + for func, result, (start, stop, block) in zip(self._funcs, results, values[-1]): with subtimes['concat', func]: - result[...,int(values[start_id]):int(values[stop_id])] = values[result_id] + result[...,start:stop] = block return tuple(results) def _node_tuple(self, cache, subgraph, times): @@ -3589,8 +3568,8 @@ def _node_tuple(self, cache, subgraph, times): subcache[self._index] = RegularNode('LoopIndex', (), dict(length=self._length._node(cache, subgraph, times)), (type(self).__name__, _Stats()), loopgraph) subtimes = times.get(self, collections.defaultdict(_Stats)) concats = [] - for func, start, stop, cc_length in zip(self._funcs, self._starts, self._stops, self._cc_lengths): - concat_kwargs = {'shape[{}]'.format(i): n._node(cache, subgraph, times) for i, n in enumerate((*func.shape[:-1], cc_length))} + for func, start, stop, *shape in self._funcdatas: + concat_kwargs = {'shape[{}]'.format(i): n._node(cache, subgraph, times) for i, n in enumerate(shape)} concat_kwargs['start'] = start._node(subcache, loopgraph, subtimes) concat_kwargs['stop'] = stop._node(subcache, loopgraph, subtimes) concat_kwargs['func'] = func._node(subcache, loopgraph, subtimes) @@ -3995,15 +3974,14 @@ def _loop_concatenate_data(func, index, length): else: chunk_sizes = loop_concatenate(InsertAxis(func.shape[-1], 1), index, length) offsets = _SizesToOffsets(chunk_sizes) - start = _take(offsets, index, 0) - stop = _take(offsets, index+1, 0) - cc_length = _take(offsets, length, 0) - return func, start, stop, cc_length + start = Take(offsets, index) + stop = Take(offsets, index+1) + return (func, start, stop, *func.shape[:-1], Take(offsets, length)) def loop_concatenate(func, index, length): length = asindex(length) - func, start, stop, cc_length = _loop_concatenate_data(func, index, length) - return LoopConcatenate(func, start, stop, cc_length, index, length) + funcdata = _loop_concatenate_data(func, index, length) + return LoopConcatenate(funcdata, index, length) def loop_concatenate_combined(funcs, index, length): length = asindex(length) @@ -4011,7 +3989,7 @@ def loop_concatenate_combined(funcs, index, length): unique_funcs.extend(func for func in funcs if func not in unique_funcs) unique_func_data = tuple(_loop_concatenate_data(func, index, length) for func in unique_funcs) loop = LoopConcatenateCombined(unique_func_data, index, length) - return tuple(ArrayFromTuple(loop, unique_funcs.index(func), (*func.shape[:-1], cc_length), func.dtype) for func, start, stop, cc_length in unique_func_data) + return tuple(ArrayFromTuple(loop, unique_funcs.index(func), shape, func.dtype) for func, start, stop, *shape in unique_func_data) @replace def replace_arguments(value, arguments): diff --git a/tests/test_evaluable.py b/tests/test_evaluable.py index db1e20b12..d0bb745ce 100644 --- a/tests/test_evaluable.py +++ b/tests/test_evaluable.py @@ -823,8 +823,8 @@ class combine_loop_concatenates(TestCase): def test_same_index_same_length(self): i = evaluable.Argument('i', (), int) - A = evaluable.LoopConcatenate(evaluable.InsertAxis(i, 1), i, i+1, 3, i, 3) - B = evaluable.LoopConcatenate(evaluable.InsertAxis(i, 2), i*2, i*2+2, 6, i, 3) + A = evaluable.LoopConcatenate((evaluable.InsertAxis(i, 1), i, i+1, 3,), i, 3) + B = evaluable.LoopConcatenate((evaluable.InsertAxis(i, 2), i*2, i*2+2, 6,), i, 3) actual = evaluable.Tuple((A, B))._combine_loop_concatenates(set()) L = evaluable.LoopConcatenateCombined(((evaluable.InsertAxis(i, 1), i, i+1, 3), (evaluable.InsertAxis(i, 2), i*2, i*2+2, 6)), i, 3) desired = evaluable.Tuple((evaluable.ArrayFromTuple(L, 0, (3,), int), evaluable.ArrayFromTuple(L, 1, (6,), int))) @@ -832,8 +832,8 @@ def test_same_index_same_length(self): def test_same_index_different_length(self): i = evaluable.Argument('i', (), int) - A = evaluable.LoopConcatenate(evaluable.InsertAxis(i, 1), i, i+1, 3, i, 3) - B = evaluable.LoopConcatenate(evaluable.InsertAxis(i, 1), i, i+1, 4, i, 4) + A = evaluable.LoopConcatenate((evaluable.InsertAxis(i, 1), i, i+1, 3,), i, 3) + B = evaluable.LoopConcatenate((evaluable.InsertAxis(i, 1), i, i+1, 4,), i, 4) actual = evaluable.Tuple((A, B))._combine_loop_concatenates(set()) L1 = evaluable.LoopConcatenateCombined(((evaluable.InsertAxis(i, 1), i, i+1, 3),), i, 3) L2 = evaluable.LoopConcatenateCombined(((evaluable.InsertAxis(i, 1), i, i+1, 4),), i, 4) @@ -843,8 +843,8 @@ def test_same_index_different_length(self): def test_different_index(self): i = evaluable.Argument('i', (), int) j = evaluable.Argument('j', (), int) - A = evaluable.LoopConcatenate(evaluable.InsertAxis(i, 1), i, i+1, 3, i, 3) - B = evaluable.LoopConcatenate(evaluable.InsertAxis(j, 1), j, j+1, 3, j, 3) + A = evaluable.LoopConcatenate((evaluable.InsertAxis(i, 1), i, i+1, 3,), i, 3) + B = evaluable.LoopConcatenate((evaluable.InsertAxis(j, 1), j, j+1, 3,), j, 3) actual = evaluable.Tuple((A, B))._combine_loop_concatenates(set()) L1 = evaluable.LoopConcatenateCombined(((evaluable.InsertAxis(i, 1), i, i+1, 3),), i, 3) L2 = evaluable.LoopConcatenateCombined(((evaluable.InsertAxis(j, 1), j, j+1, 3),), j, 3) @@ -853,8 +853,8 @@ def test_different_index(self): def test_nested_invariant(self): i = evaluable.Argument('i', (), int) - A = evaluable.LoopConcatenate(evaluable.InsertAxis(i, 1), i, i+1, 3, i, 3) - B = evaluable.LoopConcatenate(A, i*3, i*3+3, 9, i, 3) + A = evaluable.LoopConcatenate((evaluable.InsertAxis(i, 1), i, i+1, 3,), i, 3) + B = evaluable.LoopConcatenate((A, i*3, i*3+3, 9,), i, 3) actual = evaluable.Tuple((A, B))._combine_loop_concatenates(set()) L1 = evaluable.LoopConcatenateCombined(((evaluable.InsertAxis(i, 1), i, i+1, 3),), i, 3) A_ = evaluable.ArrayFromTuple(L1, 0, (3,), int) @@ -866,8 +866,8 @@ def test_nested_invariant(self): def test_nested_variant(self): i = evaluable.Argument('i', (), int) j = evaluable.Argument('j', (), int) - A = evaluable.LoopConcatenate(evaluable.InsertAxis(i+j, 1), i, i+1, 3, i, 3) - B = evaluable.LoopConcatenate(A, j*3, j*3+3, 9, j, 3) + A = evaluable.LoopConcatenate((evaluable.InsertAxis(i+j, 1), i, i+1, 3,), i, 3) + B = evaluable.LoopConcatenate((A, j*3, j*3+3, 9,), j, 3) actual = evaluable.Tuple((A, B))._combine_loop_concatenates(set()) L1 = evaluable.LoopConcatenateCombined(((evaluable.InsertAxis(i+j, 1), i, i+1, 3),), i, 3) A_ = evaluable.ArrayFromTuple(L1, 0, (3,), int)