diff --git a/e3nn_jax/_src/so3.py b/e3nn_jax/_src/so3.py index aa6b6f4..27ee0db 100644 --- a/e3nn_jax/_src/so3.py +++ b/e3nn_jax/_src/so3.py @@ -13,9 +13,9 @@ def change_basis_real_to_complex(l: int) -> np.ndarray: for m in range(1, l + 1): q[l + m, l + abs(m)] = (-1) ** m / np.sqrt(2) q[l + m, l - abs(m)] = 1j * (-1) ** m / np.sqrt(2) - return ( - -1j - ) ** l * q # Added factor of 1j**l to make the Clebsch-Gordan coefficients real + + # Added factor of 1j**l to make the Clebsch-Gordan coefficients real + return (-1j) ** l * q def clebsch_gordan(l1: int, l2: int, l3: int) -> np.ndarray: diff --git a/e3nn_jax/_src/so3_test.py b/e3nn_jax/_src/so3_test.py index 72a9f63..ac3838d 100644 --- a/e3nn_jax/_src/so3_test.py +++ b/e3nn_jax/_src/so3_test.py @@ -24,6 +24,12 @@ def test_clebsch_gordan_symmetry(): clebsch_gordan(1, 2, 3), jnp.swapaxes(jnp.swapaxes(clebsch_gordan(2, 3, 1), 0, 2), 1, 2), ) + assert jnp.allclose( + clebsch_gordan(3, 2, 4), -jnp.swapaxes(clebsch_gordan(3, 4, 2), 1, 2) + ) + assert jnp.allclose( + clebsch_gordan(2, 3, 4), -jnp.swapaxes(clebsch_gordan(4, 3, 2), 0, 2) + ) def unique_triplets(lmax):