Skip to content

Commit

Permalink
add missing asarray in axis_angle_to_matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger committed Oct 27, 2023
1 parent 6f3a826 commit 2375789
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions e3nn_jax/_src/rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,7 @@ def axis_angle_to_matrix(axis, angle):
Returns:
`jax.numpy.ndarray`: array of shape :math:`(..., 3, 3)`
"""
angle = jnp.asarray(angle)
axis, angle = jnp.broadcast_arrays(axis, angle[..., None])
alpha, beta = xyz_to_angles(axis)
R = angles_to_matrix(alpha, beta, jnp.zeros_like(beta))
Expand Down

0 comments on commit 2375789

Please sign in to comment.