Skip to content

Commit

Permalink
remove shape[:-1] from LoopConcat constructor args
Browse files Browse the repository at this point in the history
The shape of `LoopConcatenate` is used in the loop initialization to allocate
the output array. The length of the concatenation axis is the result of
concatenating and summing (`_SizesToOffset`) the lengths of the same axis of
the concatenation value. Under certain conditions the concatenation length
can be simplified. For this to work, the concatenation length must be a
constructor argument of `LoopConcatenate`, otherwise the length is not picked
up by `deep_replace_property`.

For the same reason the entire shape of `LoopConcatenate` was passed as
constructor argument. Since the shape of an array is very likely already
optimized --- the same is true for `LoopConcatenate.shape`! --- this patch
removes the shape from the constructor of `LoopConcatenate`, except for the
concatenation axis.
  • Loading branch information
joostvanzwieten committed Mar 21, 2024
1 parent 51d9b6b commit aaf54c1
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions nutils/evaluable.py
Original file line number Diff line number Diff line change
Expand Up @@ -4356,16 +4356,17 @@ def _intbounds_impl(self):

class LoopConcatenate(Loop, Array):

def __init__(self, func: Array, start: Array, stop: Array, shape: typing.Tuple[Array, ...], index_name: str, length: Array):
def __init__(self, func: Array, start: Array, stop: Array, concat_length: Array, index_name: str, length: Array):
assert isinstance(func, Array), f'func={func}'
assert _isindex(start), f'start={start}'
assert _isindex(stop), f'stop={stop}'
assert isinstance(shape, tuple) and all(map(_isindex, shape)), f'shape={shape}'
assert _isindex(concat_length), f'concat_length={concat_length}'
self.func = func
self.start = start
self.stop = stop
if not self.func.ndim:
raise ValueError('expected an array with at least one axis')
shape = *func.shape[:-1], concat_length
super().__init__(init_arg=Tuple(shape), body_arg=Tuple((func, start, stop)), index_name=index_name, length=length, shape=shape, dtype=func.dtype)

def evalf_loop_init(self, shape):
Expand Down Expand Up @@ -4942,8 +4943,8 @@ def loop_concatenate(func, index):
offsets = _SizesToOffsets(chunk_sizes)
start = Take(offsets, index)
stop = Take(offsets, index+1)
shape = *func.shape[:-1], Take(offsets, index.length)
return LoopConcatenate(func, start, stop, shape, index.name, index.length)
concat_length = Take(offsets, index.length)
return LoopConcatenate(func, start, stop, concat_length, index.name, index.length)


@util.shallow_replace
Expand Down

0 comments on commit aaf54c1

Please sign in to comment.