diff --git a/e3nn_jax/_src/irreps_array.py b/e3nn_jax/_src/irreps_array.py index 1401d34..66c2fc4 100644 --- a/e3nn_jax/_src/irreps_array.py +++ b/e3nn_jax/_src/irreps_array.py @@ -1116,8 +1116,96 @@ def rechunk(self, irreps: IntoIrreps) -> "IrrepsArray": ) zero_flags = [bool(np.all(zero_flags[s])) for s in irreps.slices()] - # TODO: split and merge chunks - return IrrepsArray(irreps, self.array, zero_flags=zero_flags) + new_chunks = None + if self._chunks is not None: + leading_shape = self.shape[:-1] + + new_chunks = [] + current_array = 0 + + while len(new_chunks) < len(irreps) and irreps[len(new_chunks)].mul == 0: + new_chunks.append(None) + + for mul_ir, y in zip(self.irreps, self.chunks): + mul, _ = mul_ir + + while mul > 0: + if isinstance(current_array, int): + current_mul = current_array + else: + current_mul = current_array.shape[-2] + + needed_mul = irreps[len(new_chunks)].mul - current_mul + + if mul <= needed_mul: + x = y + m = mul + mul = 0 + elif mul > needed_mul: + if y is None: + x = None + else: + x, y = jnp.split(y, [needed_mul], axis=-2) + m = needed_mul + mul -= needed_mul + + if x is None: + if isinstance(current_array, int): + current_array += m + else: + current_array = jnp.concatenate( + [ + current_array, + jnp.zeros( + leading_shape + (m, mul_ir.ir.dim), self.dtype + ), + ], + axis=-2, + ) + else: + if isinstance(current_array, int): + if current_array == 0: + current_array = x + else: + current_array = jnp.concatenate( + [ + jnp.zeros( + leading_shape + + (current_array, mul_ir.ir.dim), + self.dtype, + ), + x, + ], + axis=-2, + ) + else: + current_array = jnp.concatenate([current_array, x], axis=-2) + + if isinstance(current_array, int): + if current_array == irreps[len(new_chunks)].mul: + new_chunks.append(None) + current_array = 0 + else: + if current_array.shape[-2] == irreps[len(new_chunks)].mul: + new_chunks.append(current_array) + current_array = 0 + + while ( + len(new_chunks) < len(irreps) + and irreps[len(new_chunks)].mul == 0 + ): + new_chunks.append(None) + + assert current_array == 0 + + assert len(new_chunks) == len(irreps) + for (mul, ir), x, z in zip(irreps, new_chunks, zero_flags): + if z: + assert x is None + else: + assert x.shape[-2:] == (mul, ir.dim) + + return IrrepsArray(irreps, self.array, zero_flags=zero_flags, chunks=new_chunks) def broadcast_to(self, shape) -> "IrrepsArray": """Broadcast the array to a new shape."""