Skip to content

Commit

Permalink
broadcast_to chunks & __attrs_post_init__
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger committed Sep 11, 2023
1 parent 89e9d97 commit 6d96a76
Showing 1 changed file with 28 additions and 8 deletions.
36 changes: 28 additions & 8 deletions e3nn_jax/_src/irreps_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _is_none_slice(x):
return isinstance(x, slice) and x == slice(None)


@attrs(frozen=False, repr=False)
@attrs(frozen=False, init=True, repr=False)
class IrrepsArray:
r"""Array with a representation of rotations.
Expand Down Expand Up @@ -82,17 +82,30 @@ class IrrepsArray:
)
_chunks: Optional[List[Optional[jnp.ndarray]]] = attrib(default=None, kw_only=True)

def __post_init__(self):
def __attrs_post_init__(self):
if hasattr(self.array, "shape"):
if self.array.shape[-1] != self.irreps.dim:
raise ValueError(
f"IrrepsArray: Array shape {self.array.shape} incompatible with irreps {self.irreps}. "
f"{self.array.shape[-1]} != {self.irreps.dim}"
)
if self.zero_flags is not None:
if len(self.zero_flags) != len(self.irreps):
if self._chunks is not None:
for (mul, ir), chunk in zip(self.irreps, self._chunks):
if chunk is not None and chunk.shape[-2:] != (mul, ir.dim):
raise ValueError(
f"IrrepsArray: chunk shape {chunk.shape} incompatible with mul={mul} and ir.dim={ir.dim}"
)

if self._zero_flags is not None:
if len(self._zero_flags) != len(self.irreps):
raise ValueError(
f"IrrepsArray: len(zero_flags) != len(irreps), {len(self.zero_flags)} != {len(self.irreps)}"
f"IrrepsArray: len(zero_flags) != len(irreps), {len(self._zero_flags)} != {len(self.irreps)}"
)

if self._chunks is not None:
if len(self._chunks) != len(self.irreps):
raise ValueError(
f"IrrepsArray: len(chunks) != len(irreps), {len(self._chunks)} != {len(self.irreps)}"
)

@staticmethod
Expand Down Expand Up @@ -1079,16 +1092,17 @@ def rechunk(self, irreps: IntoIrreps) -> "IrrepsArray":
return self

if len(self.irreps) == 0:
zero_flags = []
zero_flags = np.empty((0,), dtype=bool)
else:
zero_flags = np.concatenate(
[
z * np.ones(mul * ir.dim, dtype=bool)
for z, (mul, ir) in zip(self.zero_flags, self.irreps)
]
)
zero_flags = [bool(np.all(zero_flags[s])) for s in irreps.slices()]
zero_flags = [bool(np.all(zero_flags[s])) for s in irreps.slices()]

# TODO: split and merge chunks
return IrrepsArray(irreps, self.array, zero_flags=zero_flags)

def broadcast_to(self, shape) -> "IrrepsArray":
Expand All @@ -1099,7 +1113,13 @@ def broadcast_to(self, shape) -> "IrrepsArray":
assert shape[-1] == self.irreps.dim or shape[-1] == -1
leading_shape = shape[:-1]
array = jnp.broadcast_to(self.array, leading_shape + (self.irreps.dim,))
return IrrepsArray(self.irreps, array, zero_flags=self.zero_flags)
chunks = [
None if x is None else jnp.broadcast_to(x, leading_shape + x.shape[-2:])
for x in self.chunks
]
return IrrepsArray(
self.irreps, array, zero_flags=self.zero_flags, chunks=chunks
)


# We purposefully do not register zero_flags
Expand Down

0 comments on commit 6d96a76

Please sign in to comment.