From 23757897fa4f2503a596bcf8e4cc3a8b415b6bde Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 27 Oct 2023 14:16:48 +0200 Subject: [PATCH] add missing asarray in axis_angle_to_matrix --- e3nn_jax/_src/rotation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/e3nn_jax/_src/rotation.py b/e3nn_jax/_src/rotation.py index 64113c2d..8a5ecdfe 100644 --- a/e3nn_jax/_src/rotation.py +++ b/e3nn_jax/_src/rotation.py @@ -590,6 +590,7 @@ def axis_angle_to_matrix(axis, angle): Returns: `jax.numpy.ndarray`: array of shape :math:`(..., 3, 3)` """ + angle = jnp.asarray(angle) axis, angle = jnp.broadcast_arrays(axis, angle[..., None]) alpha, beta = xyz_to_angles(axis) R = angles_to_matrix(alpha, beta, jnp.zeros_like(beta))