Skip to content

Commit

Permalink
Add apply() for SO3Signal.
Browse files Browse the repository at this point in the history
  • Loading branch information
ameya98 committed Dec 5, 2024
1 parent 0f879f0 commit 08a7e81
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
4 changes: 4 additions & 0 deletions e3nn_jax/_src/so3grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ def __mul__(self, other: Union[float, "SO3Signal"]) -> "SO3Signal":
def __truediv__(self, other: float) -> "SO3Signal":
return self * (1 / other)

def apply(self, func: Callable[..., jnp.ndarray]) -> "SO3Signal":
"""Apply a pointwise function to the signal."""
return SO3Signal(self.s2_signals.apply(func))

def vmap_over_batch_dims(
self, func: Callable[..., jnp.ndarray]
) -> Callable[..., jnp.ndarray]:
Expand Down
19 changes: 19 additions & 0 deletions tests/_src/so3grid_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,22 @@ def test_argmax(seed: int):
R_argmax, _ = sig.argmax()

assert jnp.allclose(func(R_argmax), func(R_argmax_expected), rtol=1e-2)


def test_apply():
sig = SO3Signal.from_function(
lambda R: jnp.trace(R @ R),
res_beta=40,
res_alpha=39,
res_theta=40,
quadrature="gausslegendre",
)
sig_applied = sig.apply(jnp.exp)
sig_expected = SO3Signal.from_function(
lambda R: jnp.exp(jnp.trace(R @ R)),
res_beta=40,
res_alpha=39,
res_theta=40,
quadrature="gausslegendre",
)
assert jnp.allclose(sig_applied.grid_values, sig_expected.grid_values)

0 comments on commit 08a7e81

Please sign in to comment.