From ac733b405111c1404a67cfd802ba5b9a05b651b2 Mon Sep 17 00:00:00 2001 From: Gertjan van Zwieten Date: Fri, 29 Jan 2021 11:13:43 +0100 Subject: [PATCH] add _dependencies_sans_invariants --- nutils/evaluable.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/nutils/evaluable.py b/nutils/evaluable.py index 8ee3090f0..93e7a79bf 100644 --- a/nutils/evaluable.py +++ b/nutils/evaluable.py @@ -3329,10 +3329,7 @@ def __init__(self, funcdata:prepare_funcdata, index:types.strict[Argument], leng self.func = funcdata[0] self.index = index self.length = length - self._invariants = [] - self._dependencies = [] - _populate_dependencies_sans_invariants(self.func, index, self._invariants, self._dependencies, set()) - assert (self._dependencies or self._invariants or [index])[-1] == self.func + self._invariants, self._dependencies = _dependencies_sans_invariants(self.func, index) axes = tuple(Axis(axis.length) if isinstance(axis, Sparse) else axis for axis in self.func._axes) super().__init__(args=(shape, length, *self._invariants), shape=axes, dtype=self.func.dtype) @@ -3521,11 +3518,8 @@ def __init__(self, funcdatas:types.tuple[asarrays], index:types.strict[Argument] raise ValueError('the length of the loop must not depend on the index') self._index = index self._length = length - self._invariants = [] - self._dependencies = [] - result = Tuple([Tuple([start, stop, func]) for func, start, stop, *shape in funcdatas]) - _populate_dependencies_sans_invariants(result, index, self._invariants, self._dependencies, set()) - assert (self._dependencies or self._invariants)[-1] == result + self._invariants, self._dependencies = _dependencies_sans_invariants( + Tuple([Tuple([start, stop, func]) for func, start, stop, *shape in funcdatas]), index) super().__init__(args=(Tuple(shapes), length, *self._invariants)) @property @@ -3641,13 +3635,18 @@ def _inflate_scalar(arg, shape): def _isunique(array): return numpy.unique(array).size == array.size +def _dependencies_sans_invariants(func, arg): + invariants = [] + dependencies = [] + _populate_dependencies_sans_invariants(func, arg, invariants, dependencies, {arg}) + assert (dependencies or invariants or [arg])[-1] == func + return tuple(invariants), tuple(dependencies) + def _populate_dependencies_sans_invariants(func, arg, invariants, dependencies, cache): if func in cache: return cache.add(func) - if func == arg: - pass - elif arg in func.arguments: + if arg in func.arguments: for child in func._Evaluable__args: _populate_dependencies_sans_invariants(child, arg, invariants, dependencies, cache) dependencies.append(func)