diff --git a/e3nn_jax/_src/rotation.py b/e3nn_jax/_src/rotation.py index 64113c2..8a5ecdf 100644 --- a/e3nn_jax/_src/rotation.py +++ b/e3nn_jax/_src/rotation.py @@ -590,6 +590,7 @@ def axis_angle_to_matrix(axis, angle): Returns: `jax.numpy.ndarray`: array of shape :math:`(..., 3, 3)` """ + angle = jnp.asarray(angle) axis, angle = jnp.broadcast_arrays(axis, angle[..., None]) alpha, beta = xyz_to_angles(axis) R = angles_to_matrix(alpha, beta, jnp.zeros_like(beta))