Skip to content

Commit

Permalink
replace Loop.{init,body}_arg with {init,body}_args
Browse files Browse the repository at this point in the history
To reduce the unnecessary usage of `evaluable.Tuple`, this patch replaces the
loop init and body arg (singular) with init and body args (plural), similar to
the `args` parameter of `Evaluable`.
  • Loading branch information
joostvanzwieten committed May 6, 2024
1 parent a329e58 commit 811c221
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions nutils/evaluable.py
Original file line number Diff line number Diff line change
Expand Up @@ -4142,19 +4142,19 @@ class Loop(Evaluable):
* method ``evalf_loop_body(output, body_arg)``.
'''

def __init__(self, loop_id: _LoopId, length: Array, init_arg: Evaluable, body_arg: Evaluable, *args, **kwargs):
def __init__(self, loop_id: _LoopId, length: Array, init_args: typing.Tuple[Evaluable, ...], body_args: typing.Tuple[Evaluable, ...], *args, **kwargs):
assert isinstance(loop_id, _LoopId), f'loop_id={loop_id!r}'
assert isinstance(length, Array), f'length={length!r}'
assert isinstance(init_arg, Evaluable), f'init_arg={init_arg!r}'
assert isinstance(body_arg, Evaluable), f'body_arg={init_arg!r}'
assert isinstance(init_args, tuple) and all(isinstance(arg, Evaluable) for arg in init_args), f'init_args={init_args!r}'
assert isinstance(body_args, tuple) and all(isinstance(arg, Evaluable) for arg in body_args), f'body_args={body_args!r}'
self.loop_id = loop_id
self.length = length
self.index = _LoopIndex(loop_id, length)
self.init_arg = init_arg
self.body_arg = body_arg
if self.index in init_arg.arguments:
self.init_args = init_args
self.body_args = body_args
if any(self.index in arg.arguments for arg in init_args):
raise ValueError('the loop initialization arguments must not depend on the index')
super().__init__(args=(length, init_arg, body_arg), *args, **kwargs)
super().__init__(args=(length, *init_args, *body_args), *args, **kwargs)

@cached_property
def _loop_block_id(self):
Expand All @@ -4172,7 +4172,7 @@ def _node(self, cache, subgraph, times, unique_loop_ids):
return cached

# Populate the `cache` with objects that do not depend on `self.index`.
stack = [self.init_arg, self.body_arg]
stack = [*self.init_args, *self.body_args]
while stack:
func = stack.pop()
if self.index in func.arguments:
Expand All @@ -4199,8 +4199,8 @@ def arguments(self):
@property
def _loops(self):
deps = util.IDSet([self])
deps |= self.init_arg._loops
deps |= self.body_arg._loops
for arg in itertools.chain(self.init_args, self.body_args):
deps |= arg._loops
return deps.view()


Expand All @@ -4210,7 +4210,7 @@ def __init__(self, func: Array, shape: typing.Tuple[Array, ...], loop_id: _LoopI
assert isinstance(func, Array) and func.dtype != bool, f'func={func!r}'
assert func.ndim == len(shape)
self.func = func
super().__init__(init_arg=Tuple(shape), body_arg=func, loop_id=loop_id, length=length, shape=shape, dtype=func.dtype)
super().__init__(init_args=shape, body_args=(func,), loop_id=loop_id, length=length, shape=shape, dtype=func.dtype)

def _compile(self, builder):
out, alloc_block_id = builder.new_empty_array_for_evaluable(self)
Expand Down Expand Up @@ -4335,7 +4335,7 @@ def __init__(self, func: Array, start: Array, stop: Array, concat_length: Array,
if not self.func.ndim:
raise ValueError('expected an array with at least one axis')
shape = *func.shape[:-1], concat_length
super().__init__(init_arg=Tuple(shape), body_arg=Tuple((func, start, stop)), loop_id=loop_id, length=length, shape=shape, dtype=func.dtype)
super().__init__(init_args=shape, body_args=(func, start, stop), loop_id=loop_id, length=length, shape=shape, dtype=func.dtype)

def _compile(self, builder):
out, alloc_block_id = builder.new_empty_array_for_evaluable(self)
Expand Down

0 comments on commit 811c221

Please sign in to comment.