Skip to content

Commit

Permalink
more chunks=
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger committed Sep 13, 2023
1 parent cc5551f commit 0ee3c80
Showing 1 changed file with 30 additions and 22 deletions.
52 changes: 30 additions & 22 deletions e3nn_jax/_src/irreps_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
import operator
import warnings
from typing import Any, Callable, List, Optional, Tuple, Union
from attr import attrs, attrib

import jax
import jax.numpy as jnp
import jax.scipy
import numpy as np
from attr import attrib, attrs
from jax.tree_util import tree_map

import e3nn_jax as e3nn
from e3nn_jax import Irreps
Expand Down Expand Up @@ -278,7 +279,12 @@ def eq(mul: int, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
return IrrepsArray(self.irreps, self.array == other)

def __neg__(self: "IrrepsArray") -> "IrrepsArray":
return IrrepsArray(self.irreps, -self.array, zero_flags=self.zero_flags)
return IrrepsArray(
self.irreps,
-self.array,
zero_flags=self.zero_flags,
chunks=tree_map(lambda x: -x, self._chunks),
)

def __add__(
self: "IrrepsArray", other: Union["IrrepsArray", jnp.ndarray, float, int]
Expand Down Expand Up @@ -386,14 +392,11 @@ def __mul__(
f"IrrepsArray({self.irreps}, shape={self.shape}) * scalar(shape={other.shape}) is not equivariant."
)

chunks = None
if self._chunks is not None:
chunks = [
x * other[..., None] if x is not None else None for x in self._chunks
]

return IrrepsArray(
self.irreps, self.array * other, zero_flags=self.zero_flags, chunks=chunks
self.irreps,
self.array * other,
zero_flags=self.zero_flags,
chunks=tree_map(lambda x: x * other[..., None], self._chunks),
)

def __rmul__(
Expand Down Expand Up @@ -433,14 +436,11 @@ def __truediv__(
f"IrrepsArray({self.irreps}, shape={self.shape}) / scalar(shape={other.shape}) is not equivariant."
)

chunks = None
if self._chunks is not None:
chunks = [
x / other[..., None] if x is not None else None for x in self._chunks
]

return IrrepsArray(
self.irreps, self.array / other, zero_flags=self.zero_flags, chunks=chunks
self.irreps,
self.array / other,
zero_flags=self.zero_flags,
chunks=tree_map(lambda x: x / other[..., None], self._chunks),
)

def __rtruediv__(
Expand All @@ -462,13 +462,21 @@ def __rtruediv__(

def __pow__(self, exponent) -> "IrrepsArray": # noqa: D105
if all(ir == "0e" for _, ir in self.irreps):
return IrrepsArray(self.irreps, self.array**exponent)
return IrrepsArray(
self.irreps,
self.array**exponent,
chunks=tree_map(lambda x: x**exponent, self._chunks),
)

if exponent % 1.0 == 0.0 and self.irreps.lmax == 0:
irreps = self.irreps
if exponent % 2.0 == 0.0:
irreps = [(mul, "0e") for mul, ir in self.irreps]
return IrrepsArray(irreps, array=self.array**exponent)
return IrrepsArray(
irreps,
array=self.array**exponent,
chunks=tree_map(lambda x: x**exponent, self._chunks),
)

raise ValueError(
f"IrrepsArray({self.irreps}, shape={self.shape}) ** scalar is not equivariant."
Expand Down Expand Up @@ -629,7 +637,7 @@ def reshape(self, shape) -> "IrrepsArray":
self.irreps,
self.array.reshape(shape[:-1] + (self.irreps.dim,)),
zero_flags=self.zero_flags,
chunks=jax.tree_util.tree_map(
chunks=tree_map(
lambda x: x.reshape(shape[:-1] + x.shape[-2:]), self._chunks
),
)
Expand All @@ -647,7 +655,7 @@ def astype(self, dtype) -> "IrrepsArray":
irreps=self.irreps,
array=self.array.astype(dtype),
zero_flags=self.zero_flags,
chunks=jax.tree_util.tree_map(lambda x: x.astype(dtype), self._chunks),
chunks=tree_map(lambda x: x.astype(dtype), self._chunks),
)

def remove_nones(self) -> "IrrepsArray":
Expand Down Expand Up @@ -1018,15 +1026,15 @@ def transform_by_angles(
}
if inverse:
D = {ir: jnp.swapaxes(D[ir], -2, -1) for ir in D}
new_list = [
new_chunks = [
jnp.reshape(
jnp.einsum("ij,...uj->...ui", D[ir], x), self.shape[:-1] + (mul, ir.dim)
)
if x is not None
else None
for (mul, ir), x in zip(self.irreps, self.chunks)
]
return e3nn.from_chunks(self.irreps, new_list, self.shape[:-1], self.dtype)
return e3nn.from_chunks(self.irreps, new_chunks, self.shape[:-1], self.dtype)

def transform_by_quaternion(self, q: jnp.ndarray, k: int = 0) -> "IrrepsArray":
r"""Rotate data by a rotation given by a quaternion.
Expand Down

0 comments on commit 0ee3c80

Please sign in to comment.