Skip to content

Commit

Permalink
optimize wigner D
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger committed Oct 17, 2023
1 parent 2d1166d commit 6f3a826
Showing 1 changed file with 11 additions and 16 deletions.
27 changes: 11 additions & 16 deletions e3nn_jax/_src/irreps.py
Original file line number Diff line number Diff line change
Expand Up @@ -1053,24 +1053,19 @@ def rot_y(phi):
return M

R = []
if l < len(Jd):
if a is not None:
R += [rot_y(a)]
if b is not None:
if a is not None:
R += [rot_y(a)]

if b is not None:
if l < len(Jd):
J = Jd[l]
R += [J @ rot_y(b) @ J]
if c is not None:
R += [rot_y(c)]
else:
X = generators(l)
exp = jax.scipy.linalg.expm

if a is not None:
R += [exp(a * X[1])]
if b is not None:
R += [exp(b * X[0])]
if c is not None:
R += [exp(c * X[1])]
else:
X = generators(l)
R += [jax.scipy.linalg.expm(b * X[0])]

if c is not None:
R += [rot_y(c)]

if len(R) == 0:
return jnp.eye(2 * l + 1)
Expand Down

0 comments on commit 6f3a826

Please sign in to comment.