restore simplification of LoopConcat shape args
gertjanvanzwieten committed Mar 11, 2021
1 parent 3bd57c5 commit 0a27364
Showing 2 changed files with 62 additions and 84 deletions.
126 changes: 52 additions & 74 deletions nutils/evaluable.py
@@ -465,7 +465,7 @@ def _combine_loop_concatenates(self, outer_exclude):
# start, stop and the concatenation length) are formed by
# `loop_concatenate`-ing `func.shape[-1]`. If the shape is constant,
# this can be simplified to a `Range`.
data = Tuple((Tuple((lc.func, lc.start, lc.stop, lc.shape[-1])) for lc in lcs))
data = Tuple((Tuple(lc.funcdata) for lc in lcs))
# Combine `LoopConcatenate` instances in `data` excluding
# `outer_exclude` and those that will be processed in a subsequent loop
# (the remainder of `exclude`). The latter consists of loops that are
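The `Range` simplification mentioned in the comment above is plain arithmetic once `func.shape[-1]` is a constant `n`: the offsets become multiples of `n`, so start, stop and the total concatenation length follow from the loop index without evaluating the loop. A minimal numpy illustration (the variable names are placeholders, not nutils objects):

    import numpy

    n, length = 4, 3                             # constant chunk size and loop length
    offsets = n * numpy.arange(length + 1)       # the Range form: [0, 4, 8, 12]
    index = 2
    start, stop = n * index, n * (index + 1)     # 8, 12
    assert start == offsets[index] and stop == offsets[index + 1]
    total = n * length                           # 12, length of the concatenated axis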
@@ -3435,26 +3435,20 @@ def _simplified(self):
class LoopConcatenate(Array):

@types.apply_annotations
def __init__(self, func:asarray, start:asindex, stop:asindex, cc_length:asindex, index:types.strict[Argument], length:asindex):
if index.dtype != int or index.ndim != 0:
raise ValueError('expected an index with dtype int and dimension zero but got {}'.format(index))
if not func.ndim:
def __init__(self, funcdata:asarrays, index:types.strict[Argument], length:asindex):
self.funcdata = funcdata
self.func, self.start, stop, *shape = funcdata
if not self.func.ndim:
raise ValueError('expected an array with at least one axis')
if any(index in n.arguments for n in func.shape[:-1]):
if any(index in n.arguments for n in shape):
raise ValueError('the shape of the function must not depend on the index')
if index in length.arguments:
raise ValueError('the length of the loop must not depend on the index')
if index in cc_length.arguments:
raise ValueError('the length of the concatenation axis must not depend on the index')

self.func = func
self.index = index
self.length = length
self.start = start
self.stop = stop
self._lcc = LoopConcatenateCombined(((func, start, stop, cc_length),), index, length)
axes = (*(Axis(axis.length) if isinstance(axis, Sparse) else axis for axis in func._axes[:-1]), Axis(cc_length))
super().__init__(args=[self._lcc], shape=axes, dtype=func.dtype)
self._lcc = LoopConcatenateCombined((self.funcdata,), index, length)
axes = (*(Axis(axis.length) if isinstance(axis, Sparse) else axis for axis in self.func._axes[:-1]), Axis(shape[-1]))
super().__init__(args=[self._lcc], shape=axes, dtype=self.func.dtype)

def evalf(self, arg):
return arg[0]
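For reference, the `funcdata` tuple that the new constructor receives has the layout `(func, start, stop, *shape)`, with the final shape entry being the length of the concatenation axis. A trivial standalone sketch of the unpacking, using plain Python values in place of evaluable arrays:

    funcdata = ('func', 'start', 'stop', 2, 5)   # (func, start, stop, *shape); 5 is the concat-axis length
    func, start, stop, *shape = funcdata         # mirrors the unpacking in __init__
    assert shape == [2, 5] and shape[-1] == 5    # shape[-1] becomes the trailing Axis length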
@@ -3510,73 +3504,58 @@ def _loop_concatenate_deps(self):

class LoopConcatenateCombined(Evaluable):

__cache__ = '_serialized'

@types.apply_annotations
def __init__(self, funcdata:types.tuple[asarrays], index:types.strict[Argument], length:asindex):
def __init__(self, funcdatas:types.tuple[asarrays], index:types.strict[Argument], length:asindex):
self._funcdatas = funcdatas
if index.dtype != int or index.ndim != 0:
raise ValueError('expected an index with dtype int and dimension zero but got {}'.format(index))
if any(not func.ndim for func, start, stop, cc_length in funcdata):
self._funcs = tuple(func for func, start, stop, *shape in funcdatas)
if any(not func.ndim for func in self._funcs):
raise ValueError('expected an array with at least one axis')
if any(index in n.arguments for func, start, stop, cc_length in funcdata for n in (*func.shape[:-1], cc_length)):
shapes = [Tuple(shape) for func, start, stop, *shape in funcdatas]
if any(index in shape.arguments for shape in shapes):
raise ValueError('the shape of the function must not depend on the index')
if index in length.arguments:
raise ValueError('the length of the loop must not depend on the index')

self._funcs, self._starts, self._stops, self._cc_lengths = zip(*funcdata)
self._index = index
self._length = length
self._invariants = []
self._dependencies = []
result = Tuple([Tuple([start, stop, func]) for func, start, stop, *shape in funcdatas])
_populate_dependencies_sans_invariants(result, index, self._invariants, self._dependencies, set())
assert (self._dependencies or self._invariants)[-1] == result
super().__init__(args=(Tuple(shapes), length, *self._invariants))

invariants = []
for func, start, stop, cc_length in funcdata:
invariants.extend(func.shape[:-1])
invariants.append(cc_length)
invariants.append(length)

dependencies = []
cache = set()
for obj in itertools.chain(self._starts, self._stops, self._funcs):
_populate_dependencies_sans_invariants(obj, index, invariants, dependencies, cache)

indices = {d: i for i, d in enumerate(itertools.chain(invariants, [index], dependencies))}
self._start_indices = tuple(map(indices.__getitem__, self._starts))
self._stop_indices = tuple(map(indices.__getitem__, self._stops))
self._result_indices = tuple(map(indices.__getitem__, self._funcs))
self._serialized = tuple((dep, tuple(map(indices.__getitem__, dep._Evaluable__args))) for dep in dependencies)
self._invariants = tuple(invariants)

super().__init__(args=invariants)
@property
def _serialized(self):
indices = {d: i for i, d in enumerate(itertools.chain([self._index], self._invariants, self._dependencies))}
return tuple((dep, tuple(map(indices.__getitem__, dep._Evaluable__args))) for dep in self._dependencies)
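The lazily computed `_serialized` tuple encodes the loop body as a flat evaluation plan: every node gets a position in a values list, ordered as loop index, then invariants, then index-dependent nodes, and each dependent node records the positions of its arguments. A standalone sketch of the scheme, not using the nutils `Evaluable` machinery:

    class Node:
        def __init__(self, op, *args):
            self.op = op        # callable producing this node's value
            self.args = args    # argument nodes

    def serialize(index_node, invariants, dependencies):
        pos = {n: i for i, n in enumerate([index_node, *invariants, *dependencies])}
        return [(dep, tuple(pos[a] for a in dep.args)) for dep in dependencies]

    def run_iteration(serialized, index_value, invariant_values):
        values = [index_value, *invariant_values]
        for node, arg_positions in serialized:
            values.append(node.op(*(values[i] for i in arg_positions)))
        return values[-1]       # by construction the final dependency is the result tuple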

def evalf(self, *args):
i = 0
results = []
for func in self._funcs:
results.append(parallel.shempty(tuple(map(int, args[i:i+func.ndim])), dtype=func.dtype))
i += func.ndim
length = int(args[i])
with parallel.ctxrange('loop', length) as indices:
def evalf(self, shapes, length, *args):
serialized = self._serialized
results = [parallel.shempty(tuple(map(int, shape)), dtype=func.dtype) for func, shape in zip(self._funcs, shapes)]
with parallel.ctxrange('loop', int(length)) as indices:
for index in indices:
values = list(args)
values.append(numpy.array(index))
values.extend(op.evalf(*[values[i] for i in indices]) for op, indices in self._serialized)
for result, result_id, start_id, stop_id in zip(results, self._result_indices, self._start_indices, self._stop_indices):
result[...,int(values[start_id]):int(values[stop_id])] = values[result_id]
values = [numpy.array(index)]
values.extend(args)
values.extend(op.evalf(*[values[i] for i in indices]) for op, indices in serialized)
for result, (start, stop, block) in zip(results, values[-1]):
result[...,start:stop] = block
return tuple(results)
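The write pattern in `evalf` above, allocate once from the precomputed shapes and copy each iteration's block into its `[start:stop)` slice of the last axis, can be illustrated with numpy alone. The helper below is a hypothetical stand-in, not part of nutils:

    import numpy

    def concatenate_blocks(blocks, total_length):
        # blocks: sequence of (start, stop, array) with array.shape[-1] == stop - start
        first = blocks[0][2]
        result = numpy.empty((*first.shape[:-1], total_length), dtype=first.dtype)
        for start, stop, block in blocks:
            result[..., start:stop] = block
        return result

    chunks = [(0, 1, numpy.array([[0.]])), (1, 3, numpy.array([[1., 2.]]))]
    print(concatenate_blocks(chunks, 3))   # [[0. 1. 2.]]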

def evalf_withtimes(self, times, *args):
def evalf_withtimes(self, times, shapes, length, *args):
serialized = self._serialized
times[self] = subtimes = collections.defaultdict(_Stats)
i = 0
results = []
for func in self._funcs:
with subtimes['concat', func]:
results.append(numpy.empty(tuple(map(int, args[i:i+func.ndim])), dtype=func.dtype))
i += func.ndim
length = int(args[i])
results = [parallel.shempty(tuple(map(int, shape)), dtype=func.dtype) for func, shape in zip(self._funcs, shapes)]
for index in range(length):
values = list(args)
values.append(numpy.array(index))
values.extend(op.evalf_withtimes(subtimes, *[values[i] for i in indices]) for op, indices in self._serialized)
for func, result, result_id, start_id, stop_id in zip(self._funcs, results, self._result_indices, self._start_indices, self._stop_indices):
values = [numpy.array(index)]
values.extend(args)
values.extend(op.evalf_withtimes(subtimes, *[values[i] for i in indices]) for op, indices in serialized)
for func, result, (start, stop, block) in zip(self._funcs, results, values[-1]):
with subtimes['concat', func]:
result[...,int(values[start_id]):int(values[stop_id])] = values[result_id]
result[...,start:stop] = block
return tuple(results)
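The `evalf_withtimes` variant follows the same plan but wraps each operation and each concatenation write in a per-node stats context so time can be attributed to individual nodes. A rough sketch of such a context, with `_Stats` as an illustrative stand-in rather than the nutils class of that name:

    import collections, time

    class _Stats:                       # illustrative stand-in only
        def __init__(self):
            self.ncalls = 0
            self.seconds = 0.
        def __enter__(self):
            self._t0 = time.perf_counter()
            return self
        def __exit__(self, *exc):
            self.ncalls += 1
            self.seconds += time.perf_counter() - self._t0

    subtimes = collections.defaultdict(_Stats)
    with subtimes['concat', 'func0']:   # keyed per function, as in the loop above
        pass                            # the real code copies result[..., start:stop] = block here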

def _node_tuple(self, cache, subgraph, times):
@@ -3589,8 +3568,8 @@ def _node_tuple(self, cache, subgraph, times):
subcache[self._index] = RegularNode('LoopIndex', (), dict(length=self._length._node(cache, subgraph, times)), (type(self).__name__, _Stats()), loopgraph)
subtimes = times.get(self, collections.defaultdict(_Stats))
concats = []
for func, start, stop, cc_length in zip(self._funcs, self._starts, self._stops, self._cc_lengths):
concat_kwargs = {'shape[{}]'.format(i): n._node(cache, subgraph, times) for i, n in enumerate((*func.shape[:-1], cc_length))}
for func, start, stop, *shape in self._funcdatas:
concat_kwargs = {'shape[{}]'.format(i): n._node(cache, subgraph, times) for i, n in enumerate(shape)}
concat_kwargs['start'] = start._node(subcache, loopgraph, subtimes)
concat_kwargs['stop'] = stop._node(subcache, loopgraph, subtimes)
concat_kwargs['func'] = func._node(subcache, loopgraph, subtimes)
@@ -3995,23 +3974,22 @@ def _loop_concatenate_data(func, index, length):
else:
chunk_sizes = loop_concatenate(InsertAxis(func.shape[-1], 1), index, length)
offsets = _SizesToOffsets(chunk_sizes)
start = _take(offsets, index, 0)
stop = _take(offsets, index+1, 0)
cc_length = _take(offsets, length, 0)
return func, start, stop, cc_length
start = Take(offsets, index)
stop = Take(offsets, index+1)
return (func, start, stop, *func.shape[:-1], Take(offsets, length))
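When the chunk size varies per iteration, `_loop_concatenate_data` goes through the offsets bookkeeping shown above: per-iteration sizes are accumulated into offsets (the role of `_SizesToOffsets`), after which start, stop and the total length are single takes from that array. In plain numpy terms:

    import numpy

    chunk_sizes = numpy.array([1, 3, 2])                            # func.shape[-1] per iteration
    offsets = numpy.concatenate([[0], numpy.cumsum(chunk_sizes)])   # [0, 1, 4, 6]

    index = 1
    start, stop = offsets[index], offsets[index + 1]                # 1, 4 -- Take(offsets, index), Take(offsets, index+1)
    total = offsets[len(chunk_sizes)]                               # 6 -- Take(offsets, length)
    print(start, stop, total)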

def loop_concatenate(func, index, length):
length = asindex(length)
func, start, stop, cc_length = _loop_concatenate_data(func, index, length)
return LoopConcatenate(func, start, stop, cc_length, index, length)
funcdata = _loop_concatenate_data(func, index, length)
return LoopConcatenate(funcdata, index, length)

def loop_concatenate_combined(funcs, index, length):
length = asindex(length)
unique_funcs = []
unique_funcs.extend(func for func in funcs if func not in unique_funcs)
unique_func_data = tuple(_loop_concatenate_data(func, index, length) for func in unique_funcs)
loop = LoopConcatenateCombined(unique_func_data, index, length)
return tuple(ArrayFromTuple(loop, unique_funcs.index(func), (*func.shape[:-1], cc_length), func.dtype) for func, start, stop, cc_length in unique_func_data)
return tuple(ArrayFromTuple(loop, unique_funcs.index(func), shape, func.dtype) for func, start, stop, *shape in unique_func_data)

@replace
def replace_arguments(value, arguments):
20 changes: 10 additions & 10 deletions tests/test_evaluable.py
@@ -823,17 +823,17 @@ class combine_loop_concatenates(TestCase):

def test_same_index_same_length(self):
i = evaluable.Argument('i', (), int)
A = evaluable.LoopConcatenate(evaluable.InsertAxis(i, 1), i, i+1, 3, i, 3)
B = evaluable.LoopConcatenate(evaluable.InsertAxis(i, 2), i*2, i*2+2, 6, i, 3)
A = evaluable.LoopConcatenate((evaluable.InsertAxis(i, 1), i, i+1, 3,), i, 3)
B = evaluable.LoopConcatenate((evaluable.InsertAxis(i, 2), i*2, i*2+2, 6,), i, 3)
actual = evaluable.Tuple((A, B))._combine_loop_concatenates(set())
L = evaluable.LoopConcatenateCombined(((evaluable.InsertAxis(i, 1), i, i+1, 3), (evaluable.InsertAxis(i, 2), i*2, i*2+2, 6)), i, 3)
desired = evaluable.Tuple((evaluable.ArrayFromTuple(L, 0, (3,), int), evaluable.ArrayFromTuple(L, 1, (6,), int)))
self.assertEqual(actual, desired)
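This test and the ones that follow pin down the grouping behaviour (inferred here from the expected outcomes, not from nutils internals): loops sharing both index and length are merged into one `LoopConcatenateCombined`, while differing lengths or indices keep them apart. Schematically:

    import collections

    def group(loops):
        # loops: iterable of (funcdata, index, length) descriptors
        grouped = collections.defaultdict(list)
        for funcdata, index, length in loops:
            grouped[index, length].append(funcdata)
        return grouped

    merged = group([(('f1',), 'i', 3), (('f2',), 'i', 3), (('f3',), 'i', 4)])
    assert len(merged['i', 3]) == 2 and len(merged['i', 4]) == 1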

def test_same_index_different_length(self):
i = evaluable.Argument('i', (), int)
A = evaluable.LoopConcatenate(evaluable.InsertAxis(i, 1), i, i+1, 3, i, 3)
B = evaluable.LoopConcatenate(evaluable.InsertAxis(i, 1), i, i+1, 4, i, 4)
A = evaluable.LoopConcatenate((evaluable.InsertAxis(i, 1), i, i+1, 3,), i, 3)
B = evaluable.LoopConcatenate((evaluable.InsertAxis(i, 1), i, i+1, 4,), i, 4)
actual = evaluable.Tuple((A, B))._combine_loop_concatenates(set())
L1 = evaluable.LoopConcatenateCombined(((evaluable.InsertAxis(i, 1), i, i+1, 3),), i, 3)
L2 = evaluable.LoopConcatenateCombined(((evaluable.InsertAxis(i, 1), i, i+1, 4),), i, 4)
@@ -843,8 +843,8 @@ def test_same_index_different_length(self):
def test_different_index(self):
i = evaluable.Argument('i', (), int)
j = evaluable.Argument('j', (), int)
A = evaluable.LoopConcatenate(evaluable.InsertAxis(i, 1), i, i+1, 3, i, 3)
B = evaluable.LoopConcatenate(evaluable.InsertAxis(j, 1), j, j+1, 3, j, 3)
A = evaluable.LoopConcatenate((evaluable.InsertAxis(i, 1), i, i+1, 3,), i, 3)
B = evaluable.LoopConcatenate((evaluable.InsertAxis(j, 1), j, j+1, 3,), j, 3)
actual = evaluable.Tuple((A, B))._combine_loop_concatenates(set())
L1 = evaluable.LoopConcatenateCombined(((evaluable.InsertAxis(i, 1), i, i+1, 3),), i, 3)
L2 = evaluable.LoopConcatenateCombined(((evaluable.InsertAxis(j, 1), j, j+1, 3),), j, 3)
@@ -853,8 +853,8 @@ def test_different_index(self):

def test_nested_invariant(self):
i = evaluable.Argument('i', (), int)
A = evaluable.LoopConcatenate(evaluable.InsertAxis(i, 1), i, i+1, 3, i, 3)
B = evaluable.LoopConcatenate(A, i*3, i*3+3, 9, i, 3)
A = evaluable.LoopConcatenate((evaluable.InsertAxis(i, 1), i, i+1, 3,), i, 3)
B = evaluable.LoopConcatenate((A, i*3, i*3+3, 9,), i, 3)
actual = evaluable.Tuple((A, B))._combine_loop_concatenates(set())
L1 = evaluable.LoopConcatenateCombined(((evaluable.InsertAxis(i, 1), i, i+1, 3),), i, 3)
A_ = evaluable.ArrayFromTuple(L1, 0, (3,), int)
@@ -866,8 +866,8 @@ def test_nested_invariant(self):
def test_nested_variant(self):
i = evaluable.Argument('i', (), int)
j = evaluable.Argument('j', (), int)
A = evaluable.LoopConcatenate(evaluable.InsertAxis(i+j, 1), i, i+1, 3, i, 3)
B = evaluable.LoopConcatenate(A, j*3, j*3+3, 9, j, 3)
A = evaluable.LoopConcatenate((evaluable.InsertAxis(i+j, 1), i, i+1, 3,), i, 3)
B = evaluable.LoopConcatenate((A, j*3, j*3+3, 9,), j, 3)
actual = evaluable.Tuple((A, B))._combine_loop_concatenates(set())
L1 = evaluable.LoopConcatenateCombined(((evaluable.InsertAxis(i+j, 1), i, i+1, 3),), i, 3)
A_ = evaluable.ArrayFromTuple(L1, 0, (3,), int)
