diff --git a/nutils/evaluable.py b/nutils/evaluable.py index e8f981e5a..3492d09b8 100644 --- a/nutils/evaluable.py +++ b/nutils/evaluable.py @@ -4356,16 +4356,17 @@ def _intbounds_impl(self): class LoopConcatenate(Loop, Array): - def __init__(self, func: Array, start: Array, stop: Array, shape: typing.Tuple[Array, ...], index_name: str, length: Array): + def __init__(self, func: Array, start: Array, stop: Array, concat_length: Array, index_name: str, length: Array): assert isinstance(func, Array), f'func={func}' assert _isindex(start), f'start={start}' assert _isindex(stop), f'stop={stop}' - assert isinstance(shape, tuple) and all(map(_isindex, shape)), f'shape={shape}' + assert _isindex(concat_length), f'concat_length={concat_length}' self.func = func self.start = start self.stop = stop if not self.func.ndim: raise ValueError('expected an array with at least one axis') + shape = *func.shape[:-1], concat_length super().__init__(init_arg=Tuple(shape), body_arg=Tuple((func, start, stop)), index_name=index_name, length=length, shape=shape, dtype=func.dtype) def evalf_loop_init(self, shape): @@ -4942,8 +4943,8 @@ def loop_concatenate(func, index): offsets = _SizesToOffsets(chunk_sizes) start = Take(offsets, index) stop = Take(offsets, index+1) - shape = *func.shape[:-1], Take(offsets, index.length) - return LoopConcatenate(func, start, stop, shape, index.name, index.length) + concat_length = Take(offsets, index.length) + return LoopConcatenate(func, start, stop, concat_length, index.name, index.length) @util.shallow_replace