diff --git a/nutils/evaluable.py b/nutils/evaluable.py index d6e302b90..e39a6dbba 100644 --- a/nutils/evaluable.py +++ b/nutils/evaluable.py @@ -4142,19 +4142,19 @@ class Loop(Evaluable): * method ``evalf_loop_body(output, body_arg)``. ''' - def __init__(self, loop_id: _LoopId, length: Array, init_arg: Evaluable, body_arg: Evaluable, *args, **kwargs): + def __init__(self, loop_id: _LoopId, length: Array, init_args: typing.Tuple[Evaluable, ...], body_args: typing.Tuple[Evaluable, ...], *args, **kwargs): assert isinstance(loop_id, _LoopId), f'loop_id={loop_id!r}' assert isinstance(length, Array), f'length={length!r}' - assert isinstance(init_arg, Evaluable), f'init_arg={init_arg!r}' - assert isinstance(body_arg, Evaluable), f'body_arg={init_arg!r}' + assert isinstance(init_args, tuple) and all(isinstance(arg, Evaluable) for arg in init_args), f'init_args={init_args!r}' + assert isinstance(body_args, tuple) and all(isinstance(arg, Evaluable) for arg in body_args), f'body_args={body_args!r}' self.loop_id = loop_id self.length = length self.index = _LoopIndex(loop_id, length) - self.init_arg = init_arg - self.body_arg = body_arg - if self.index in init_arg.arguments: + self.init_args = init_args + self.body_args = body_args + if any(self.index in arg.arguments for arg in init_args): raise ValueError('the loop initialization arguments must not depend on the index') - super().__init__(args=(length, init_arg, body_arg), *args, **kwargs) + super().__init__(args=(length, *init_args, *body_args), *args, **kwargs) @cached_property def _loop_block_id(self): @@ -4172,7 +4172,7 @@ def _node(self, cache, subgraph, times, unique_loop_ids): return cached # Populate the `cache` with objects that do not depend on `self.index`. - stack = [self.init_arg, self.body_arg] + stack = [*self.init_args, *self.body_args] while stack: func = stack.pop() if self.index in func.arguments: @@ -4199,8 +4199,8 @@ def arguments(self): @property def _loops(self): deps = util.IDSet([self]) - deps |= self.init_arg._loops - deps |= self.body_arg._loops + for arg in itertools.chain(self.init_args, self.body_args): + deps |= arg._loops return deps.view() @@ -4210,7 +4210,7 @@ def __init__(self, func: Array, shape: typing.Tuple[Array, ...], loop_id: _LoopI assert isinstance(func, Array) and func.dtype != bool, f'func={func!r}' assert func.ndim == len(shape) self.func = func - super().__init__(init_arg=Tuple(shape), body_arg=func, loop_id=loop_id, length=length, shape=shape, dtype=func.dtype) + super().__init__(init_args=shape, body_args=(func,), loop_id=loop_id, length=length, shape=shape, dtype=func.dtype) def _compile(self, builder): out, alloc_block_id = builder.new_empty_array_for_evaluable(self) @@ -4335,7 +4335,7 @@ def __init__(self, func: Array, start: Array, stop: Array, concat_length: Array, if not self.func.ndim: raise ValueError('expected an array with at least one axis') shape = *func.shape[:-1], concat_length - super().__init__(init_arg=Tuple(shape), body_arg=Tuple((func, start, stop)), loop_id=loop_id, length=length, shape=shape, dtype=func.dtype) + super().__init__(init_args=shape, body_args=(func, start, stop), loop_id=loop_id, length=length, shape=shape, dtype=func.dtype) def _compile(self, builder): out, alloc_block_id = builder.new_empty_array_for_evaluable(self)