Skip to content

Commit

Permalink
add _dependencies_sans_invariants
Browse files Browse the repository at this point in the history
  • Loading branch information
gertjanvanzwieten committed Mar 11, 2021
1 parent 0a27364 commit ac733b4
Showing 1 changed file with 11 additions and 12 deletions.
23 changes: 11 additions & 12 deletions nutils/evaluable.py
Original file line number Diff line number Diff line change
Expand Up @@ -3329,10 +3329,7 @@ def __init__(self, funcdata:prepare_funcdata, index:types.strict[Argument], leng
self.func = funcdata[0]
self.index = index
self.length = length
self._invariants = []
self._dependencies = []
_populate_dependencies_sans_invariants(self.func, index, self._invariants, self._dependencies, set())
assert (self._dependencies or self._invariants or [index])[-1] == self.func
self._invariants, self._dependencies = _dependencies_sans_invariants(self.func, index)
axes = tuple(Axis(axis.length) if isinstance(axis, Sparse) else axis for axis in self.func._axes)
super().__init__(args=(shape, length, *self._invariants), shape=axes, dtype=self.func.dtype)

Expand Down Expand Up @@ -3521,11 +3518,8 @@ def __init__(self, funcdatas:types.tuple[asarrays], index:types.strict[Argument]
raise ValueError('the length of the loop must not depend on the index')
self._index = index
self._length = length
self._invariants = []
self._dependencies = []
result = Tuple([Tuple([start, stop, func]) for func, start, stop, *shape in funcdatas])
_populate_dependencies_sans_invariants(result, index, self._invariants, self._dependencies, set())
assert (self._dependencies or self._invariants)[-1] == result
self._invariants, self._dependencies = _dependencies_sans_invariants(
Tuple([Tuple([start, stop, func]) for func, start, stop, *shape in funcdatas]), index)
super().__init__(args=(Tuple(shapes), length, *self._invariants))

@property
Expand Down Expand Up @@ -3641,13 +3635,18 @@ def _inflate_scalar(arg, shape):
def _isunique(array):
return numpy.unique(array).size == array.size

def _dependencies_sans_invariants(func, arg):
invariants = []
dependencies = []
_populate_dependencies_sans_invariants(func, arg, invariants, dependencies, {arg})
assert (dependencies or invariants or [arg])[-1] == func
return tuple(invariants), tuple(dependencies)

def _populate_dependencies_sans_invariants(func, arg, invariants, dependencies, cache):
if func in cache:
return
cache.add(func)
if func == arg:
pass
elif arg in func.arguments:
if arg in func.arguments:
for child in func._Evaluable__args:
_populate_dependencies_sans_invariants(child, arg, invariants, dependencies, cache)
dependencies.append(func)
Expand Down

0 comments on commit ac733b4

Please sign in to comment.