Skip to content

Commit

Permalink
add back complicated algo to rechunk
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger committed Sep 13, 2023
1 parent 0ee3c80 commit 55709c3
Showing 1 changed file with 90 additions and 2 deletions.
92 changes: 90 additions & 2 deletions e3nn_jax/_src/irreps_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,8 +1116,96 @@ def rechunk(self, irreps: IntoIrreps) -> "IrrepsArray":
)
zero_flags = [bool(np.all(zero_flags[s])) for s in irreps.slices()]

# TODO: split and merge chunks
return IrrepsArray(irreps, self.array, zero_flags=zero_flags)
new_chunks = None
if self._chunks is not None:
leading_shape = self.shape[:-1]

new_chunks = []
current_array = 0

while len(new_chunks) < len(irreps) and irreps[len(new_chunks)].mul == 0:
new_chunks.append(None)

for mul_ir, y in zip(self.irreps, self.chunks):
mul, _ = mul_ir

while mul > 0:
if isinstance(current_array, int):
current_mul = current_array
else:
current_mul = current_array.shape[-2]

needed_mul = irreps[len(new_chunks)].mul - current_mul

if mul <= needed_mul:
x = y
m = mul
mul = 0
elif mul > needed_mul:
if y is None:
x = None
else:
x, y = jnp.split(y, [needed_mul], axis=-2)
m = needed_mul
mul -= needed_mul

if x is None:
if isinstance(current_array, int):
current_array += m
else:
current_array = jnp.concatenate(
[
current_array,
jnp.zeros(
leading_shape + (m, mul_ir.ir.dim), self.dtype
),
],
axis=-2,
)
else:
if isinstance(current_array, int):
if current_array == 0:
current_array = x
else:
current_array = jnp.concatenate(
[
jnp.zeros(
leading_shape
+ (current_array, mul_ir.ir.dim),
self.dtype,
),
x,
],
axis=-2,
)
else:
current_array = jnp.concatenate([current_array, x], axis=-2)

if isinstance(current_array, int):
if current_array == irreps[len(new_chunks)].mul:
new_chunks.append(None)
current_array = 0
else:
if current_array.shape[-2] == irreps[len(new_chunks)].mul:
new_chunks.append(current_array)
current_array = 0

while (
len(new_chunks) < len(irreps)
and irreps[len(new_chunks)].mul == 0
):
new_chunks.append(None)

assert current_array == 0

assert len(new_chunks) == len(irreps)
for (mul, ir), x, z in zip(irreps, new_chunks, zero_flags):
if z:
assert x is None
else:
assert x.shape[-2:] == (mul, ir.dim)

return IrrepsArray(irreps, self.array, zero_flags=zero_flags, chunks=new_chunks)

def broadcast_to(self, shape) -> "IrrepsArray":
"""Broadcast the array to a new shape."""
Expand Down

0 comments on commit 55709c3

Please sign in to comment.