diff --git a/nutils/evaluable.py b/nutils/evaluable.py index 215546495..61a4088d9 100644 --- a/nutils/evaluable.py +++ b/nutils/evaluable.py @@ -450,7 +450,7 @@ def _simplified(self): def optimized_for_numpy(self): retval = self.simplified._optimized_for_numpy1() or self retval = retval._deep_flatten_constants() or retval - return retval._combine_loop_concatenates(frozenset()) + return retval._combine_loops(frozenset()) @replace(depthfirst=True, recursive=True) def _optimized_for_numpy1(obj): @@ -469,49 +469,37 @@ def _deep_flatten_constants(self): return self._flatten_constant() @cached_property - def _loop_concatenate_deps(self): + def _loop_deps(self) -> typing.Tuple['Loop', ...]: deps = [] - for arg in self.__args: - deps += [dep for dep in arg._loop_concatenate_deps if dep not in deps] + deps.extend(loop for arg in self.__args for loop in arg._loop_deps if loop not in deps) return tuple(deps) - def _combine_loop_concatenates(self, outer_exclude): + def _combine_loops(self, candidates): + candidates = list(candidates) while True: - exclude = set(outer_exclude) + exclude = set() combine = {} - # Collect all top-level `LoopConcatenate` instances in `combine` and all - # their dependent `LoopConcatenate` instances in `exclude`. - for lc in self._loop_concatenate_deps: - lcs = combine.setdefault(lc.index, []) - if lc not in lcs: - lcs.append(lc) - exclude.update(set(lc._loop_concatenate_deps) - {lc}) - # Combine top-level `LoopConcatenate` instances excluding those in - # `exclude`. + # Collect all top-level loops in `combine` and all their dependent + # loops instances in `exclude`. + for loop in candidates: + loops = combine.setdefault(loop.index, []) + if loop not in loops: + loops.append(loop) + exclude.update(set(loop._loop_deps) - {loop}) + # Combine top-level loop instances excluding those in `exclude`. replacements = {} - for index, lcs in combine.items(): - lcs = [lc for lc in lcs if lc not in exclude] - if not lcs: + for index, loops in combine.items(): + loops = tuple(loop for loop in loops if loop not in exclude) + if len(loops) <= 1: continue - # We're extracting data from `LoopConcatenate` in favor of using - # `loop_concatenate_combined(lcs, ...)` because the later requires - # reapplying simplifications that are already applied in the former. - # For example, in `loop_concatenate_combined` the offsets (used by - # start, stop and the concatenation length) are formed by - # `loop_concatenate`-ing `func.shape[-1]`. If the shape is constant, - # this can be simplified to a `Range`. - data = Tuple(tuple(Tuple(lc.funcdata) for lc in lcs)) - # Combine `LoopConcatenate` instances in `data` excluding - # `outer_exclude` and those that will be processed in a subsequent loop - # (the remainder of `exclude`). The latter consists of loops that are - # invariant w.r.t. the current loop `index`. - data = data._combine_loop_concatenates(exclude) - combined = LoopConcatenateCombined(tuple(map(tuple, data)), index._name, index.length) - for i, lc in enumerate(lcs): - intbounds = dict(zip(('_lower', '_upper'), lc._intbounds)) if lc.dtype == int else {} - replacements[lc] = ArrayFromTuple(combined, i, lc.shape, lc.dtype, **intbounds) + candidates = [loop for loop in candidates if loop not in loops] + combined = LoopTuple(loops, index.name, index.length) + combined = combined._combine_loops(combined._nested_loops) + for i, loop in enumerate(loops): + intbounds = dict(zip(('_lower', '_upper'), loop._intbounds)) if loop.dtype == int else {} + replacements[loop] = ArrayFromTuple(combined, i, loop.shape, loop.dtype, **intbounds) if replacements: - self = replace(lambda key: replacements.get(key) if isinstance(key, LoopConcatenate) else None, recursive=False, depthfirst=False)(self) + self = replace(lambda key: replacements.get(key) if isinstance(key, Loop) else None, recursive=False, depthfirst=False)(self) else: return self @@ -4236,6 +4224,7 @@ class _LoopIndex(Argument): def __init__(self, name: str, length: Array): assert isinstance(name, str), f'name={name!r}' assert _isindex(length), f'length={length!r}' + self.name = name self.length = length super().__init__(name, (), int) @@ -4261,72 +4250,146 @@ def _simplified(self): return Zeros((), int) -class LoopSum(Array): +class Loop(Evaluable): + '''Base class for evaluable loops. - def __init__(self, func: Array, shape: typing.Tuple[Array, ...], index_name: str, length: Array): - assert isinstance(func, Array) and func.dtype != bool, f'func={func!r}' - assert isinstance(shape, tuple) and all(_isindex(n) for n in shape), f'shape={shape!r}' + Subclasses must implement + + * method `evalf_loop_init(init_arg)` and + * method `evalf_loop_body(output, body_arg)`. + ''' + + def __init__(self, index_name: str, length: Array, init_arg: Evaluable, body_arg: Evaluable, *args, **kwargs): assert isinstance(index_name, str), f'index_name={index_name!r}' - assert _isindex(length), f'length={length!r}' - assert func.ndim == len(shape) - self.index = loop_index(index_name, length) - if any(self.index in n.arguments for n in shape): - raise ValueError('the shape of the function must not depend on the index') - self.func = func - self._invariants, self._dependencies = _dependencies_sans_invariants(func, self.index) - super().__init__(args=(Tuple(shape), length, *self._invariants), shape=shape, dtype=func.dtype) + assert isinstance(length, Array), f'length={length!r}' + assert isinstance(init_arg, Evaluable), f'init_arg={init_arg!r}' + assert isinstance(body_arg, Evaluable), f'body_arg={init_arg!r}' + self.index_name = index_name + self.length = length + self.index = _LoopIndex(index_name, length) + self.init_arg = init_arg + self.body_arg = body_arg + if self.index in init_arg.arguments: + raise ValueError('the loop initialization arguments must not depend on the index') + self._invariants, self._dependencies = _dependencies_sans_invariants(body_arg, self.index) + super().__init__(args=(length, init_arg, *self._invariants), *args, **kwargs) @cached_property def _serialized_loop(self): indices = {d: i for i, d in enumerate(itertools.chain([self.index], self._invariants, self._dependencies))} return tuple((dep, tuple(map(indices.__getitem__, dep._Evaluable__args))) for dep in self._dependencies) - # This property is a derivation of `_serialized` where the `Evaluable` - # instances are mapped to the `evalf` methods of the instances. Asserting - # that functions are immutable is difficult and currently - # `types._isimmutable` marks all functions as mutable. Since the - # `types.CacheMeta` machinery asserts immutability of the property, we have - # to resort to a regular `functools.cached_property`. Nevertheless, this - # property should be treated as if it is immutable. @cached_property def _serialized_loop_evalf(self): return tuple((dep.evalf, indices) for dep, indices in self._serialized_loop) - def evalf(self, shape, length, *args): + def evalf(self, length, init_arg, *invariants): serialized_evalf = self._serialized_loop_evalf - result = numpy.zeros(shape, self.dtype) - for index in range(length): - values = [numpy.array(index)] - values.extend(args) - values.extend(op_evalf(*[values[i] for i in indices]) for op_evalf, indices in serialized_evalf) - result += values[-1] - return result + output = self.evalf_loop_init(init_arg) + with parallel.ctxrange('loop {}'.format(self.index.name), int(length)) as indices: + for index in indices: + values = [numpy.array(index)] + values.extend(invariants) + values.extend(op_evalf(*[values[i] for i in indices]) for op_evalf, indices in serialized_evalf) + self.evalf_loop_body(output, values[-1]) + return output - def evalf_withtimes(self, times, shape, length, *args): + def evalf_withtimes(self, times, length, init_arg, *invariants): serialized = self._serialized_loop subtimes = times.setdefault(self, collections.defaultdict(_Stats)) - result = numpy.zeros(shape, self.dtype) + output = self.evalf_loop_init(init_arg) for index in range(length): values = [numpy.array(index)] - values.extend(args) + values.extend(invariants) values.extend(op.evalf_withtimes(subtimes, *[values[i] for i in indices]) for op, indices in serialized) - result += values[-1] - return result + self.evalf_loop_body_withtimes(subtimes, output, values[-1]) + return output - def _derivative(self, var, seen): - return loop_sum(derivative(self.func, var, seen), self.index) + def evalf_loop_body_withtimes(self, times, output, body_arg): + with times[self]: + self.evalf_loop_body(output, body_arg) def _node(self, cache, subgraph, times): - if self in cache: - return cache[self] - subcache = {} - for arg in self._Evaluable__args: - subcache[arg] = arg._node(cache, subgraph, times) + if (cached := cache.get(self)) is not None: + return cached + for arg in itertools.chain(self._invariants, (self.init_arg,)): + arg._node(cache, subgraph, times) + loopcache = cache.copy() + loopcache.pop(self.index, None) loopgraph = Subgraph('Loop', subgraph) - subtimes = times.get(self, collections.defaultdict(_Stats)) - sum_kwargs = {'shape[{}]'.format(i): n._node(cache, subgraph, times) for i, n in enumerate(self.shape)} - sum_kwargs['func'] = self.func._node(subcache, loopgraph, subtimes) - cache[self] = node = RegularNode('LoopSum', (), sum_kwargs, (type(self).__name__, subtimes['sum']), loopgraph) + looptimes = times.get(self, collections.defaultdict(_Stats)) + cache[self] = node = self._node_loop_body(loopcache, loopgraph, looptimes) + return node + + @property + def _loop_deps(self) -> typing.Tuple['Loop', ...]: + deps = [self] + args = itertools.chain(self._invariants, (self.init_arg,)) + deps.extend(loop for arg in args for loop in arg._loop_deps if loop not in deps) + return tuple(deps) + + @cached_property + def _nested_loops(self) -> typing.Tuple['Loop', ...]: + nested = [] + nested.extend(loop for arg in self._dependencies for loop in arg._loop_deps if loop not in nested) + deps = self._loop_deps + return tuple(loop for loop in nested if loop not in deps) + + +class LoopTuple(Loop): + + def __init__(self, loops: typing.Tuple[Loop], index_name: str, length: Array): + assert isinstance(loops, tuple) and all(isinstance(loop, Loop) and loop.index_name == index_name and loop.length == length for loop in loops), f'loops={loops}' + self.loops = loops + super().__init__( + index_name=index_name, + length=length, + init_arg=Tuple(tuple(loop.init_arg for loop in loops)), + body_arg=Tuple(tuple(loop.body_arg for loop in loops)), + ) + + def evalf_loop_init(self, args): + return tuple(loop.evalf_loop_init(arg) for loop, arg in zip(self.loops, args)) + + def evalf_loop_body(self, outputs, args): + for loop, output, arg in zip(self.loops, outputs, args): + loop.evalf_loop_body(output, arg) + + def evalf_loop_body_withtimes(self, times, outputs, args): + for loop, output, arg in zip(self.loops, outputs, args): + loop.evalf_loop_body_withtimes(times, output, arg) + + def _node_loop_body(self, cache, subgraph, times): + if (cached := cache.get(self)) is not None: + return cached + cache[self] = node = TupleNode(tuple(item._node_loop_body(cache, subgraph, times) for item in self.loops), metadata=(type(self).__name__, times[self]), subgraph=subgraph) + return node + + +class LoopSum(Loop, Array): + + def __init__(self, func: Array, shape: typing.Tuple[Array, ...], index_name: str, length: Array): + assert isinstance(func, Array) and func.dtype != bool, f'func={func!r}' + assert func.ndim == len(shape) + self.func = func + super().__init__(init_arg=Tuple(shape), body_arg=func, index_name=index_name, length=length, shape=shape, dtype=func.dtype) + + def evalf_loop_init(self, shape): + return parallel.shzeros(tuple(n.__index__() for n in shape), dtype=self.dtype) + + @staticmethod + def evalf_loop_body(output, func): + output += func + + def _derivative(self, var, seen): + return loop_sum(derivative(self.func, var, seen), self.index) + + def _node_loop_body(self, cache, subgraph, times): + if (cached := cache.get(self)) is not None: + return cached + kwargs = {'shape[{}]'.format(i): n._node(cache, subgraph, times) for i, n in enumerate(self.shape)} + kwargs['func'] = self.func._node(cache, subgraph, times) + cache[self] = node = RegularNode('LoopSum', (), kwargs, (type(self).__name__, times[self]), subgraph) return node def _simplified(self): @@ -4406,39 +4469,40 @@ def _intbounds_impl(self): return 0, (0 if n == 0 or m == 0 else n * m) -class LoopConcatenate(Array): +class LoopConcatenate(Loop, Array): - def __init__(self, funcdata: typing.Tuple[Array, ...], index_name: str, length: Array): - assert isinstance(funcdata, tuple) and all(isinstance(d, Array) for d in funcdata), f'funcdata={funcdata!r}' - assert isinstance(index_name, str), f'index_name={index_name!r}' - assert _isindex(length), f'length={length!r}' - self.funcdata = funcdata - self.func, self.start, stop, *shape = funcdata - self.index = loop_index(index_name, length) + def __init__(self, func: Array, start: Array, stop: Array, shape: typing.Tuple[Array, ...], index_name: str, length: Array): + assert isinstance(func, Array), f'func={func}' + assert isinstance(start, Array), f'start={start}' + assert isinstance(stop, Array), f'stop={stop}' + assert func.ndim == len(shape) + self.func = func + self.start = start + self.stop = stop if not self.func.ndim: raise ValueError('expected an array with at least one axis') - if any(self.index in n.arguments for n in shape): - raise ValueError('the shape of the function must not depend on the index') - self._lcc = LoopConcatenateCombined((self.funcdata,), index_name, length) - super().__init__(args=(self._lcc,), shape=tuple(shape), dtype=self.func.dtype) + super().__init__(init_arg=Tuple(shape), body_arg=Tuple((func, start, stop)), index_name=index_name, length=length, shape=shape, dtype=func.dtype) - @staticmethod - def evalf(arg): - return arg[0] + def evalf_loop_init(self, shape): + return parallel.shempty(tuple(n.__index__() for n in shape), dtype=self.dtype) - def evalf_withtimes(self, times, arg): - with times[self]: - return arg[0] + @staticmethod + def evalf_loop_body(output, arg): + func, start, stop = arg + output[..., start:stop] = func def _derivative(self, var, seen): return Transpose.from_end(loop_concatenate(Transpose.to_end(derivative(self.func, var, seen), self.ndim-1), self.index), self.ndim-1) - def _node(self, cache, subgraph, times): - if self in cache: - return cache[self] - else: - cache[self] = node = self._lcc._node(cache, subgraph, times)[0] - return node + def _node_loop_body(self, cache, subgraph, times): + if (cached := cache.get(self)) is not None: + return cached + kwargs = {'shape[{}]'.format(i): n._node(cache, subgraph, times) for i, n in enumerate(self.shape)} + kwargs['start'] = self.start._node(cache, subgraph, times) + kwargs['stop'] = self.stop._node(cache, subgraph, times) + kwargs['func'] = self.func._node(cache, subgraph, times) + cache[self] = node = RegularNode('LoopConcatenate', (), kwargs, (type(self).__name__, times[self]), subgraph) + return node def _simplified(self): if iszero(self.func): @@ -4485,93 +4549,10 @@ def _assparse(self): chunks.append(tuple(loop_concatenate(_flat(arr), self.index) for arr in (*indices, last_index, values))) return tuple(chunks) - @property - def _loop_concatenate_deps(self): - return (self,) + super()._loop_concatenate_deps - def _intbounds_impl(self): return self.func._intbounds -class LoopConcatenateCombined(Evaluable): - - def __init__(self, funcdatas: typing.Tuple[typing.Tuple[Array, ...], ...], index_name: str, length: Array): - assert isinstance(funcdatas, tuple) and all(isinstance(funcdata, tuple) and all(isinstance(d, Array) for d in funcdata) for funcdata in funcdatas), f'funcdatas={funcdatas!r}' - assert isinstance(index_name, str), f'index_name={index_name}' - assert _isindex(length), f'length={length!r}' - self._funcdatas = funcdatas - self._funcs = tuple(func for func, start, stop, *shape in funcdatas) - self._index_name = index_name - self._index = loop_index(index_name, length) - if any(not func.ndim for func in self._funcs): - raise ValueError('expected an array with at least one axis') - shapes = tuple(Tuple(tuple(shape)) for func, start, stop, *shape in funcdatas) - if any(self._index in shape.arguments for shape in shapes): - raise ValueError('the shape of the function must not depend on the index') - self._invariants, self._dependencies = _dependencies_sans_invariants( - Tuple(tuple(Tuple((start, stop, func)) for func, start, stop, *shape in funcdatas)), self._index) - super().__init__(args=(Tuple(shapes), length, *self._invariants)) - - @cached_property - def _serialized_loop(self): - indices = {d: i for i, d in enumerate(itertools.chain([self._index], self._invariants, self._dependencies))} - return tuple((dep, tuple(map(indices.__getitem__, dep._Evaluable__args))) for dep in self._dependencies) - - # This property is a derivation of `_serialized` where the `Evaluable` - # instances are mapped to the `evalf` methods of the instances. Asserting - # that functions are immutable is difficult and currently - # `types._isimmutable` marks all functions as mutable. Since the - # `types.CacheMeta` machinery asserts immutability of the property, we have - # to resort to a regular `functools.cached_property`. Nevertheless, this - # property should be treated as if it is immutable. - @cached_property - def _serialized_loop_evalf(self): - return tuple((dep.evalf, indices) for dep, indices in self._serialized_loop) - - def evalf(self, shapes, length, *args): - serialized_evalf = self._serialized_loop_evalf - results = [parallel.shempty(tuple(map(int, shape)), dtype=func.dtype) for func, shape in zip(self._funcs, shapes)] - with parallel.ctxrange('loop {}'.format(self._index_name), int(length)) as indices: - for index in indices: - values = [numpy.array(index)] - values.extend(args) - values.extend(op_evalf(*[values[i] for i in indices]) for op_evalf, indices in serialized_evalf) - for result, (start, stop, block) in zip(results, values[-1]): - result[..., start:stop] = block - return tuple(results) - - def evalf_withtimes(self, times, shapes, length, *args): - serialized = self._serialized_loop - subtimes = times.setdefault(self, collections.defaultdict(_Stats)) - results = [parallel.shempty(tuple(map(int, shape)), dtype=func.dtype) for func, shape in zip(self._funcs, shapes)] - for index in range(length): - values = [numpy.array(index)] - values.extend(args) - values.extend(op.evalf_withtimes(subtimes, *[values[i] for i in indices]) for op, indices in serialized) - for func, result, (start, stop, block) in zip(self._funcs, results, values[-1]): - with subtimes['concat', func]: - result[..., start:stop] = block - return tuple(results) - - def _node(self, cache, subgraph, times): - if (self, 'tuple') in cache: - return cache[self, 'tuple'] - subcache = {} - for arg in self._invariants: - subcache[arg] = arg._node(cache, subgraph, times) - loopgraph = Subgraph('Loop', subgraph) - subtimes = times.get(self, collections.defaultdict(_Stats)) - concats = [] - for func, start, stop, *shape in self._funcdatas: - concat_kwargs = {'shape[{}]'.format(i): n._node(cache, subgraph, times) for i, n in enumerate(shape)} - concat_kwargs['start'] = start._node(subcache, loopgraph, subtimes) - concat_kwargs['stop'] = stop._node(subcache, loopgraph, subtimes) - concat_kwargs['func'] = func._node(subcache, loopgraph, subtimes) - concats.append(RegularNode('LoopConcatenate', (), concat_kwargs, (type(self).__name__, subtimes['concat', func]), loopgraph)) - cache[self, 'tuple'] = node = TupleNode(tuple(concats), (type(self).__name__, times[self]), subgraph) - return node - - class SearchSorted(Array): '''Find index of evaluable array into sorted numpy array.''' @@ -5061,10 +5042,10 @@ def loop_sum(func, index): func = asarray(func) if not isinstance(index, _LoopIndex): raise TypeError(f'expected _LoopIndex, got {index!r}') - return LoopSum(func, func.shape, index._name, index.length) + return LoopSum(func, func.shape, index.name, index.length) -def _loop_concatenate_data(func, index): +def loop_concatenate(func, index): func = asarray(func) if not isinstance(index, _LoopIndex): raise TypeError(f'expected _LoopIndex, got {index!r}') @@ -5076,20 +5057,8 @@ def _loop_concatenate_data(func, index): offsets = _SizesToOffsets(chunk_sizes) start = Take(offsets, index) stop = Take(offsets, index+1) - return (func, start, stop, *func.shape[:-1], Take(offsets, index.length)) - - -def loop_concatenate(func, index): - funcdata = _loop_concatenate_data(func, index) - return LoopConcatenate(funcdata, index._name, index.length) - - -def loop_concatenate_combined(funcs, index): - unique_funcs = [] - unique_funcs.extend(func for func in funcs if func not in unique_funcs) - unique_func_data = tuple(_loop_concatenate_data(func, index) for func in unique_funcs) - loop = LoopConcatenateCombined(unique_func_data, index._name, index.length) - return tuple(ArrayFromTuple(loop, unique_funcs.index(func), tuple(shape), func.dtype) for func, start, stop, *shape in unique_func_data) + shape = *func.shape[:-1], Take(offsets, index.length) + return LoopConcatenate(func, start, stop, shape, index.name, index.length) @replace diff --git a/tests/test_evaluable.py b/tests/test_evaluable.py index 3a5022c7c..13beaad58 100644 --- a/tests/test_evaluable.py +++ b/tests/test_evaluable.py @@ -593,7 +593,6 @@ def _check(name, op, n_op, *arg_values, hasgrad=True, zerograd=False, ndim=2): _check('loopsum6', lambda: evaluable.loop_sum(evaluable.Guard(evaluable.constant(1) + evaluable.loop_index('index', 4)), evaluable.loop_index('index', 4)) * evaluable.loop_sum(evaluable.loop_index('index', 4), evaluable.loop_index('index', 4)), lambda: numpy.array(60)) _check('loopconcatenate1', lambda a: evaluable.loop_concatenate(a+evaluable.prependaxes(evaluable.astype(evaluable.loop_index('index', 3), float), a.shape), evaluable.loop_index('index', 3)), lambda a: a+numpy.arange(3)[None], ANY(3, 1)) _check('loopconcatenate2', lambda: evaluable.loop_concatenate(evaluable.Elemwise(tuple(types.arraydata(numpy.arange(48).reshape(4, 4, 3)[:, :, a:b]) for a, b in util.pairwise([0, 2, 3])), evaluable.loop_index('index', 2), int), evaluable.loop_index('index', 2)), lambda: numpy.arange(48).reshape(4, 4, 3)) -_check('loopconcatenatecombined', lambda a: evaluable.loop_concatenate_combined([a+evaluable.prependaxes(evaluable.astype(evaluable.loop_index('index', 3), float), a.shape)], evaluable.loop_index('index', 3))[0], lambda a: a+numpy.arange(3)[None], ANY(3, 1), hasgrad=False) _check('legendre', lambda a: evaluable.Legendre(evaluable.asarray(a), 5), lambda a: numpy.moveaxis(numpy.polynomial.legendre.legval(a, numpy.eye(6)), 0, -1), ANY(3, 4, 3)) _check('polyval_1d_p0', lambda c, x: evaluable.Polyval(c, x), poly.eval_outer, POS(1), ANY(4, 1), ndim=1) @@ -935,35 +934,6 @@ def test_loop_concatenate(self): ' ├ %B2\n' ' └ 1\n') - @unittest.skipIf(sys.version_info < (3, 6), 'test requires dicts maintaining insertion order') - def test_loop_concatenatecombined(self): - i = evaluable.loop_index('i', 2) - f, = evaluable.loop_concatenate_combined([evaluable.InsertAxis(i, evaluable.constant(1))], i) - self.assertEqual(f.asciitree(richoutput=True), - 'SUBGRAPHS\n' - 'A\n' - '└ B = Loop\n' - 'NODES\n' - '%B0 = LoopConcatenate\n' - '├ shape[0] = %A0 = Take; i:; [2,2]\n' - '│ ├ %A1 = _SizesToOffsets; i:3; [0,2]\n' - '│ │ └ %A2 = InsertAxis; i:(2); [1,1]\n' - '│ │ ├ 1\n' - '│ │ └ 2\n' - '│ └ 2\n' - '├ start = %B1 = Take; i:; [0,2]\n' - '│ ├ %A1\n' - '│ └ %B2 = LoopIndex\n' - '│ └ length = 2\n' - '├ stop = %B3 = Take; i:; [0,2]\n' - '│ ├ %A1\n' - '│ └ %B4 = Add; i:; [1,2]\n' - '│ ├ %B2\n' - '│ └ 1\n' - '└ func = %B5 = InsertAxis; i:(1); [0,1]\n' - ' ├ %B2\n' - ' └ 1\n') - class simplify(TestCase): @@ -1105,51 +1075,41 @@ def _simplified(self): t.simplified -class combine_loop_concatenates(TestCase): +class combine_loops(TestCase): def test_same_index(self): i = evaluable.loop_index('i', 3) - A = evaluable.LoopConcatenate((evaluable.InsertAxis(i, evaluable.constant(1)), i, i+1, evaluable.constant(3)), i._name, i.length) - B = evaluable.LoopConcatenate((evaluable.InsertAxis(i, evaluable.constant(2)), i*2, i*2+2, evaluable.constant(6)), i._name, i.length) - actual = evaluable.Tuple((A, B))._combine_loop_concatenates(set()) - L = evaluable.LoopConcatenateCombined(((evaluable.InsertAxis(i, evaluable.constant(1)), i, i+1, evaluable.constant(3)), (evaluable.InsertAxis(i, evaluable.constant(2)), i*2, i*2+2, evaluable.constant(6))), i._name, i.length) - desired = evaluable.Tuple((evaluable.ArrayFromTuple(L, 0, (evaluable.constant(3),), int, **dict(zip(('_lower', '_upper'), A._intbounds))), evaluable.ArrayFromTuple(L, 1, (evaluable.constant(6),), int, **dict(zip(('_lower', '_upper'), B._intbounds))))) + A = evaluable.loop_concatenate(evaluable.InsertAxis(i, evaluable.constant(1)), i) + B = evaluable.loop_concatenate(evaluable.InsertAxis(i, evaluable.constant(2)), i) + actual = evaluable.Tuple((A, B))._combine_loops(set()) + L = evaluable.LoopTuple((A, B), i.name, i.length) + desired = evaluable.Tuple((evaluable.ArrayFromTuple(L, 0, A.shape, A.dtype, **dict(zip(('_lower', '_upper'), A._intbounds))), evaluable.ArrayFromTuple(L, 1, B.shape, B.dtype, **dict(zip(('_lower', '_upper'), B._intbounds))))) self.assertEqual(actual, desired) def test_different_index(self): i = evaluable.loop_index('i', 3) j = evaluable.loop_index('j', 3) - A = evaluable.LoopConcatenate((evaluable.InsertAxis(i, evaluable.constant(1)), i, i+1, evaluable.constant(3)), i._name, i.length) - B = evaluable.LoopConcatenate((evaluable.InsertAxis(j, evaluable.constant(1)), j, j+1, evaluable.constant(3)), j._name, j.length) - actual = evaluable.Tuple((A, B))._combine_loop_concatenates(set()) - L1 = evaluable.LoopConcatenateCombined(((evaluable.InsertAxis(i, evaluable.constant(1)), i, i+evaluable.constant(1), evaluable.constant(3)),), i._name, i.length) - L2 = evaluable.LoopConcatenateCombined(((evaluable.InsertAxis(j, evaluable.constant(1)), j, j+evaluable.constant(1), evaluable.constant(3)),), j._name, j.length) - desired = evaluable.Tuple((evaluable.ArrayFromTuple(L1, 0, (evaluable.constant(3),), int, **dict(zip(('_lower', '_upper'), A._intbounds))), evaluable.ArrayFromTuple(L2, 0, (evaluable.constant(3),), int, **dict(zip(('_lower', '_upper'), B._intbounds))))) + A = evaluable.loop_concatenate(evaluable.InsertAxis(i, evaluable.constant(1)), i) + B = evaluable.loop_concatenate(evaluable.InsertAxis(j, evaluable.constant(1)), j) + desired = evaluable.Tuple((A, B)) + actual = desired._combine_loops(set()) self.assertEqual(actual, desired) def test_nested_invariant(self): i = evaluable.loop_index('i', 3) - A = evaluable.LoopConcatenate((evaluable.InsertAxis(i, evaluable.constant(1)), i, i+1, evaluable.constant(3)), i._name, i.length) - B = evaluable.LoopConcatenate((A, i*3, i*3+3, evaluable.constant(9)), i._name, i.length) - actual = evaluable.Tuple((A, B))._combine_loop_concatenates(set()) - L1 = evaluable.LoopConcatenateCombined(((evaluable.InsertAxis(i, evaluable.constant(1)), i, i+1, evaluable.constant(3)),), i._name, i.length) - A_ = evaluable.ArrayFromTuple(L1, 0, (evaluable.constant(3),), int, **dict(zip(('_lower', '_upper'), A._intbounds))) - L2 = evaluable.LoopConcatenateCombined(((A_, i*3, i*3+3, evaluable.constant(9)),), i._name, i.length) - self.assertIn(A_, L2._Evaluable__args) - desired = evaluable.Tuple((A_, evaluable.ArrayFromTuple(L2, 0, (evaluable.constant(9),), int, **dict(zip(('_lower', '_upper'), B._intbounds))))) + A = evaluable.loop_concatenate(evaluable.InsertAxis(i, evaluable.constant(1)), i) + B = evaluable.loop_concatenate(A, i) + desired = evaluable.Tuple((A, B)) + actual = desired._combine_loops(set()) self.assertEqual(actual, desired) def test_nested_variant(self): i = evaluable.loop_index('i', 3) j = evaluable.loop_index('j', 3) - A = evaluable.LoopConcatenate((evaluable.InsertAxis(i+j, evaluable.constant(1)), i, i+1, evaluable.constant(3)), i._name, i.length) - B = evaluable.LoopConcatenate((A, j*3, j*3+3, evaluable.constant(9)), j._name, j.length) - actual = evaluable.Tuple((A, B))._combine_loop_concatenates(set()) - L1 = evaluable.LoopConcatenateCombined(((evaluable.InsertAxis(i+j, evaluable.constant(1)), i, i+1, evaluable.constant(3)),), i._name, i.length) - A_ = evaluable.ArrayFromTuple(L1, 0, (evaluable.constant(3),), int, **dict(zip(('_lower', '_upper'), A._intbounds))) - L2 = evaluable.LoopConcatenateCombined(((A_, j*3, j*3+3, evaluable.constant(9)),), j._name, j.length) - self.assertNotIn(A_, L2._Evaluable__args) - desired = evaluable.Tuple((A_, evaluable.ArrayFromTuple(L2, 0, (evaluable.constant(9),), int, **dict(zip(('_lower', '_upper'), B._intbounds))))) + A = evaluable.loop_concatenate(evaluable.InsertAxis(i+j, evaluable.constant(1)), i) + B = evaluable.loop_concatenate(A, j) + desired = evaluable.Tuple((A, B)) + actual = desired._combine_loops(set()) self.assertEqual(actual, desired)