Skip to content

Commit

Permalink
Vectorize slerp and remove custom code for create_group (#8)
Browse files Browse the repository at this point in the history
Move business logic contained in the `Slerp` class to `_slerp` function
Remove the business logic for `create_group`, the `scipy` library does this already
  • Loading branch information
chrisflesher authored Jun 6, 2024
1 parent 5677b59 commit 40073f3
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 154 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "jax-scipy-spatial"
version = "0.2.3"
version = "0.2.4"
description = "Scipy spatial API for JAX"
readme = "README.md"
authors = [
Expand Down
172 changes: 19 additions & 153 deletions src/jax_scipy_spatial/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import jax
import jax.numpy as jnp
from scipy.constants import golden
try:
from jax._src.numpy.util import implements
except ImportError:
Expand Down Expand Up @@ -62,35 +61,8 @@ def align_vectors(cls, a: jax.Array, b: jax.Array, weights: typing.Optional[jax.
@classmethod
def create_group(cls, group: str, axis: str = 'Z', dtype=float):
"""Create a 3D rotation group."""
if not isinstance(group, str):
raise ValueError("`group` argument must be a string")
permitted_axes = ['x', 'y', 'z', 'X', 'Y', 'Z']
if axis not in permitted_axes:
raise ValueError("`axis` must be one of " + ", ".join(permitted_axes))
if group in ['I', 'O', 'T']:
symbol = group
order = 1
elif group[:1] in ['C', 'D'] and group[1:].isdigit():
symbol = group[:1]
order = int(group[1:])
else:
raise ValueError("`group` must be one of 'I', 'O', 'T', 'Dn', 'Cn'")
axis_index = _elementary_basis_index(axis.lower())
if order < 1:
raise ValueError("Group order must be positive")
if symbol == 'I':
quat = _create_icosahedral_group()
elif symbol == 'O':
quat = _create_octahedral_group()
elif symbol == 'T':
quat = _create_tetrahedral_group()
elif symbol == 'D':
quat = _create_dicyclic_group(order, axis=axis_index)
elif symbol == 'C':
quat = _create_cyclic_group(order, axis=axis_index)
else:
assert False
return cls.from_quat(quat)
quat = scipy.spatial.transform.Rotation.create_group(group, axis).as_quat()
return cls.from_quat(jnp.array(quat, dtype=dtype))

@classmethod
def concatenate(cls, rotations: typing.Sequence):
Expand Down Expand Up @@ -260,34 +232,18 @@ def __init__(self, times: jax.Array, rotations: Rotation):
if times.shape[0] != len(rotations):
raise ValueError("Expected number of rotations to be equal to number of timestamps given, got "
"{} rotations and {} timestamps.".format(len(rotations), times.shape[0]))
timedelta = jnp.diff(times)
# if jnp.any(timedelta <= 0): # this causes a concretization error...
# raise ValueError("Times must be in strictly increasing order.")
new_rotations = Rotation(rotations.as_quat()[:-1])
self._times = times
self._timedelta = timedelta
self._rotations = new_rotations
self._rotvecs = (new_rotations.inv() * Rotation(rotations.as_quat()[1:])).as_rotvec()
self._rotations = rotations

def __call__(self, times: jax.Array) -> jax.Array:
def __call__(self, times: jax.Array) -> Rotation:
"""Interpolate rotations."""
compute_times = jnp.asarray(times, dtype=self._times.dtype)
if compute_times.ndim > 1:
raise ValueError("`times` must be at most 1-dimensional.")
single_time = compute_times.ndim == 0
compute_times = jnp.atleast_1d(compute_times)
ind = jnp.maximum(jnp.searchsorted(self._times, compute_times) - 1, 0)
alpha = (compute_times - self._times[ind]) / self._timedelta[ind]
result = (self._rotations[ind] * Rotation.from_rotvec(self._rotvecs[ind] * alpha[:, None]))
if single_time:
return result[0]
return result
return Rotation(_slerp(times, self._times, self._rotations.as_quat()))


jax.tree_util.register_pytree_node(
Slerp,
lambda obj: ((obj._times, obj._timedelta, obj._rotations, obj._rotvecs), None),
lambda aux, children: Rotation(*children),
lambda obj: ((obj._times, obj._rotations), None),
lambda aux, children: Slerp(*children),
)


Expand Down Expand Up @@ -394,108 +350,6 @@ def _compute_euler_from_quat(quat: jax.Array, axes: jax.Array, extrinsic: bool,
return jnp.where(degrees, jnp.rad2deg(angles), angles)


def _create_cyclic_group(n: int, axis: int = 2) -> jax.Array:
thetas = jnp.linspace(0, 2 * jnp.pi, n, endpoint=False)
rv = jnp.vstack([thetas, jnp.zeros(n), jnp.zeros(n)]).T
return _from_rotvec(jnp.roll(rv, axis, axis=1), False)


def _create_dicyclic_group(n: int, axis: int = 2) -> jax.Array:
g1 = _as_rotvec(_create_cyclic_group(n, axis), False)
thetas = jnp.linspace(0, jnp.pi, n, endpoint=False)
rv = jnp.pi * jnp.vstack([jnp.zeros(n), jnp.cos(thetas), jnp.sin(thetas)]).T
g2 = jnp.roll(rv, axis, axis=1)
return _from_rotvec(jnp.concatenate((g1, g2)), False)


def _create_icosahedral_group() -> jax.Array:
g1 = _create_tetrahedral_group()
a = 0.5
b = 0.5 / golden
c = golden / 2
g2 = jnp.array([[+a, +b, +c, 0],
[+a, +b, -c, 0],
[+a, +c, 0, +b],
[+a, +c, 0, -b],
[+a, -b, +c, 0],
[+a, -b, -c, 0],
[+a, -c, 0, +b],
[+a, -c, 0, -b],
[+a, 0, +b, +c],
[+a, 0, +b, -c],
[+a, 0, -b, +c],
[+a, 0, -b, -c],
[+b, +a, 0, +c],
[+b, +a, 0, -c],
[+b, +c, +a, 0],
[+b, +c, -a, 0],
[+b, -a, 0, +c],
[+b, -a, 0, -c],
[+b, -c, +a, 0],
[+b, -c, -a, 0],
[+b, 0, +c, +a],
[+b, 0, +c, -a],
[+b, 0, -c, +a],
[+b, 0, -c, -a],
[+c, +a, +b, 0],
[+c, +a, -b, 0],
[+c, +b, 0, +a],
[+c, +b, 0, -a],
[+c, -a, +b, 0],
[+c, -a, -b, 0],
[+c, -b, 0, +a],
[+c, -b, 0, -a],
[+c, 0, +a, +b],
[+c, 0, +a, -b],
[+c, 0, -a, +b],
[+c, 0, -a, -b],
[0, +a, +c, +b],
[0, +a, +c, -b],
[0, +a, -c, +b],
[0, +a, -c, -b],
[0, +b, +a, +c],
[0, +b, +a, -c],
[0, +b, -a, +c],
[0, +b, -a, -c],
[0, +c, +b, +a],
[0, +c, +b, -a],
[0, +c, -b, +a],
[0, +c, -b, -a]])
return jnp.concatenate((g1, g2))


def _create_octahedral_group() -> jax.Array:
g1 = _create_tetrahedral_group()
c = jnp.sqrt(2) / 2
g2 = jnp.array([[+c, 0, 0, +c],
[0, +c, 0, +c],
[0, 0, +c, +c],
[0, 0, -c, +c],
[0, -c, 0, +c],
[-c, 0, 0, +c],
[0, +c, +c, 0],
[0, -c, +c, 0],
[+c, 0, +c, 0],
[-c, 0, +c, 0],
[+c, +c, 0, 0],
[-c, +c, 0, 0]])
return jnp.concatenate((g1, g2))


def _create_tetrahedral_group() -> jax.Array:
g1 = jnp.eye(4)
c = 0.5
g2 = jnp.array([[c, -c, -c, +c],
[c, -c, +c, +c],
[c, +c, -c, +c],
[c, +c, +c, +c],
[c, -c, -c, -c],
[c, -c, +c, -c],
[c, +c, -c, -c],
[c, +c, +c, -c]])
return jnp.concatenate((g1, g2))


def _elementary_basis_index(axis: str) -> int:
if axis == 'x':
return 0
Expand Down Expand Up @@ -616,6 +470,18 @@ def _reduce(p: jax.Array, l: jax.Array, r: jax.Array) -> jax.Array:
return reduced, left_best, right_best


@functools.partial(jnp.vectorize, signature='(),(m),(m,n)->(n)')
def _slerp(time: jax.Array, times: jax.Array, quats: jax.Array) -> jax.Array:
rotations = Rotation(quats)
clipped_time = jnp.clip(time, times[0], times[-1])
ind = jnp.clip(jnp.searchsorted(times, clipped_time, side='right'), 1, times.size - 1)
timedelta = times[ind] - times[ind-1]
alpha = jnp.where(timedelta == 0., 0., (clipped_time - times[ind-1]) / timedelta)
rotvec = (rotations[ind] * rotations[ind-1].inv()).as_rotvec()
result = Rotation.from_rotvec(rotvec * alpha) * rotations[ind-1]
return result.as_quat()


def _split_quaternion(q: jax.Array) -> jax.Array:
q = jnp.atleast_2d(q)
return q[:, -1], q[:, :-1]
Expand Down

0 comments on commit 40073f3

Please sign in to comment.