From d01a108439e384184ae92604d7928f70ea425568 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Sun, 5 May 2024 23:34:50 -0700 Subject: [PATCH] Support batch axes + broadcasting for `viser.transforms` --- examples/08_smpl_visualizer.py | 11 +- src/viser/transforms/_base.py | 105 +++++++---- src/viser/transforms/_se2.py | 172 ++++++++++------- src/viser/transforms/_se3.py | 162 ++++++++-------- src/viser/transforms/_so2.py | 90 +++++---- src/viser/transforms/_so3.py | 247 ++++++++++++++----------- src/viser/transforms/hints/__init__.py | 4 +- src/viser/transforms/utils/__init__.py | 4 +- src/viser/transforms/utils/_utils.py | 57 +++++- 9 files changed, 499 insertions(+), 353 deletions(-) diff --git a/examples/08_smpl_visualizer.py b/examples/08_smpl_visualizer.py index aa7e9bff9..2d3377735 100644 --- a/examples/08_smpl_visualizer.py +++ b/examples/08_smpl_visualizer.py @@ -101,13 +101,10 @@ def main(model_path: Path) -> None: # Compute SMPL outputs. smpl_outputs = model.get_outputs( betas=np.array([x.value for x in gui_elements.gui_betas]), - joint_rotmats=np.stack( - [ - tf.SO3.exp(np.array(x.value)).as_matrix() - for x in gui_elements.gui_joints - ], - axis=0, - ), + joint_rotmats=tf.SO3.exp( + # (num_joints, 3) + np.array([x.value for x in gui_elements.gui_joints]) + ).as_matrix(), ) server.add_mesh_simple( "/human", diff --git a/src/viser/transforms/_base.py b/src/viser/transforms/_base.py index f78b52b4a..04806e76c 100644 --- a/src/viser/transforms/_base.py +++ b/src/viser/transforms/_base.py @@ -1,15 +1,12 @@ import abc -from typing import ClassVar, Generic, Type, TypeVar, Union, overload +from typing import ClassVar, Generic, Tuple, TypeVar, Union, overload import numpy as onp import numpy.typing as onpt -from typing_extensions import Self, final, override +from typing_extensions import Self, final, get_args, override from . import hints -GroupType = TypeVar("GroupType", bound="MatrixLieGroup") -SEGroupType = TypeVar("SEGroupType", bound="SEBase") - class MatrixLieGroup(abc.ABC): """Interface definition for matrix Lie groups.""" @@ -29,27 +26,35 @@ class MatrixLieGroup(abc.ABC): space_dim: ClassVar[int] """Dimension of coordinates that can be transformed.""" - def __init__(self, parameters: onpt.NDArray[onp.floating], /): + def __init__( + # Notes: + # - For the constructor signature to be consistent with subclasses, `parameters` + # should be marked as positional-only. But this isn't possible in Python 3.7. + # - This method is implicitly overriden by the dataclass decorator and + # should _not_ be marked abstract. + self, + parameters: onp.ndarray, + ): """Construct a group object from its underlying parameters.""" raise NotImplementedError() # Shared implementations. @overload - def __matmul__(self, other: hints.Array) -> onpt.NDArray[onp.floating]: ... + def __matmul__(self, other: Self) -> Self: ... @overload - def __matmul__(self: GroupType, other: GroupType) -> GroupType: ... + def __matmul__(self, other: hints.Array) -> onpt.NDArray[onp.floating]: ... def __matmul__( - self: GroupType, other: Union[GroupType, hints.Array] - ) -> Union[GroupType, onpt.NDArray[onp.floating]]: + self, other: Union[Self, hints.Array] + ) -> Union[Self, onpt.NDArray[onp.floating]]: """Overload for the `@` operator. Switches between the group action (`.apply()`) and multiplication (`.multiply()`) based on the type of `other`. """ - if isinstance(other, onp.ndarray): + if isinstance(other, (onp.ndarray, onp.ndarray)): return self.apply(target=other) elif isinstance(other, MatrixLieGroup): assert self.space_dim == other.space_dim @@ -61,16 +66,19 @@ def __matmul__( @classmethod @abc.abstractmethod - def identity(cls: Type[GroupType]) -> GroupType: + def identity(cls, batch_axes: Tuple[int, ...] = ()) -> Self: """Returns identity element. + Args: + batch_axes: Any leading batch axes for the output transform. + Returns: Identity element. """ @classmethod @abc.abstractmethod - def from_matrix(cls: Type[GroupType], matrix: hints.Array) -> GroupType: + def from_matrix(cls, matrix: hints.Array) -> Self: """Get group member from matrix representation. Args: @@ -104,7 +112,7 @@ def apply(self, target: hints.Array) -> onpt.NDArray[onp.floating]: """ @abc.abstractmethod - def multiply(self: Self, other: Self) -> Self: + def multiply(self, other: Self) -> Self: """Composes this transformation with another. Returns: @@ -113,7 +121,7 @@ def multiply(self: Self, other: Self) -> Self: @classmethod @abc.abstractmethod - def exp(cls: Type[GroupType], tangent: hints.Array) -> GroupType: + def exp(cls, tangent: hints.Array) -> Self: """Computes `expm(wedge(tangent))`. Args: @@ -149,7 +157,7 @@ def adjoint(self) -> onpt.NDArray[onp.floating]: """ @abc.abstractmethod - def inverse(self: Self) -> Self: + def inverse(self) -> Self: """Computes the inverse of our transform. Returns: @@ -157,25 +165,34 @@ def inverse(self: Self) -> Self: """ @abc.abstractmethod - def normalize(self: Self) -> Self: + def normalize(self) -> Self: """Normalize/projects values and returns. Returns: - GroupType: Normalized group member. + Normalized group member. """ - # @classmethod - # @abc.abstractmethod - # def sample_uniform(cls: Type[GroupType], key: hints.KeyArray) -> GroupType: - # """Draw a uniform sample from the group. Translations (if applicable) are in the - # range [-1, 1]. + # @classmethod + # @abc.abstractmethod + # def sample_uniform(cls, key: onp.ndarray, batch_axes: Tuple[int, ...] = ()) -> Self: + # """Draw a uniform sample from the group. Translations (if applicable) are in the + # range [-1, 1]. # - # Args: - # key: PRNG key, as returned by `jax.random.PRNGKey()`. + # Args: + # key: PRNG key, as returned by `jax.random.PRNGKey()`. + # batch_axes: Any leading batch axes for the output transforms. Each + # sampled transform will be different. # - # Returns: - # Sampled group member. - # """ + # Returns: + # Sampled group member. + # """ + + @final + def get_batch_axes(self) -> Tuple[int, ...]: + """Return any leading batch axes in contained parameters. If an array of shape + `(100, 4)` is placed in the wxyz field of an SO3 object, for example, this will + return `(100,)`.""" + return self.parameters().shape[:-1] class SOBase(MatrixLieGroup): @@ -197,10 +214,10 @@ class SEBase(Generic[ContainedSOType], MatrixLieGroup): @classmethod @abc.abstractmethod def from_rotation_and_translation( - cls: Type[SEGroupType], + cls, rotation: ContainedSOType, translation: hints.Array, - ) -> SEGroupType: + ) -> Self: """Construct a rigid transform from a rotation and a translation. Args: @@ -213,18 +230,24 @@ def from_rotation_and_translation( @final @classmethod - def from_rotation(cls: Type[SEGroupType], rotation: ContainedSOType) -> SEGroupType: + def from_rotation(cls, rotation: ContainedSOType) -> Self: return cls.from_rotation_and_translation( rotation=rotation, - translation=onp.zeros(cls.space_dim, dtype=rotation.parameters().dtype), + translation=onp.zeros( + (*rotation.get_batch_axes(), cls.space_dim), + dtype=rotation.parameters().dtype, + ), ) + @final @classmethod - @abc.abstractmethod - def from_translation( - cls: Type[SEGroupType], translation: onpt.NDArray[onp.floating] - ) -> SEGroupType: - """Construct a transform from a translation term.""" + def from_translation(cls, translation: hints.Array) -> Self: + # Extract rotation class from type parameter. + assert len(cls.__orig_bases__) == 1 # type: ignore + return cls.from_rotation_and_translation( + rotation=get_args(cls.__orig_bases__[0])[0].identity(), # type: ignore + translation=translation, + ) @abc.abstractmethod def rotation(self) -> ContainedSOType: @@ -241,9 +264,9 @@ def translation(self) -> onpt.NDArray[onp.floating]: def apply(self, target: hints.Array) -> onpt.NDArray[onp.floating]: return self.rotation() @ target + self.translation() # type: ignore - @override @final - def multiply(self: SEGroupType, other: SEGroupType) -> SEGroupType: + @override + def multiply(self, other: Self) -> Self: return type(self).from_rotation_and_translation( rotation=self.rotation() @ other.rotation(), translation=(self.rotation() @ other.translation()) + self.translation(), @@ -251,7 +274,7 @@ def multiply(self: SEGroupType, other: SEGroupType) -> SEGroupType: @final @override - def inverse(self: SEGroupType) -> SEGroupType: + def inverse(self) -> Self: R_inv = self.rotation().inverse() return type(self).from_rotation_and_translation( rotation=R_inv, @@ -260,7 +283,7 @@ def inverse(self: SEGroupType) -> SEGroupType: @final @override - def normalize(self: SEGroupType) -> SEGroupType: + def normalize(self) -> Self: return type(self).from_rotation_and_translation( rotation=self.rotation().normalize(), translation=self.translation(), diff --git a/src/viser/transforms/_se2.py b/src/viser/transforms/_se2.py index f7a7b15df..45dfd7b77 100644 --- a/src/viser/transforms/_se2.py +++ b/src/viser/transforms/_se2.py @@ -1,5 +1,5 @@ import dataclasses -from typing import cast +from typing import Tuple, cast import numpy as onp import numpy.typing as onpt @@ -7,7 +7,7 @@ from . import _base, hints from ._so2 import SO2 -from .utils import get_epsilon, register_lie_group +from .utils import broadcast_leading_axes, get_epsilon, register_lie_group @register_lie_group( @@ -16,11 +16,10 @@ tangent_dim=3, space_dim=2, ) -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class SE2(_base.SEBase[SO2]): - """Special Euclidean group for proper rigid transforms in 2D. - - Ported to numpy from `jaxlie.SE2`. + """Special Euclidean group for proper rigid transforms in 2D. Broadcasting + rules are the same as for numpy. Internal parameterization is `(cos, sin, x, y)`. Tangent parameterization is `(vx, vy, omega)`. @@ -28,8 +27,8 @@ class SE2(_base.SEBase[SO2]): # SE2-specific. - unit_complex_xy: onpt.NDArray[onp.floating] - """Internal parameters. `(cos, sin, x, y)`.""" + unit_complex_xy: onp.ndarray + """Internal parameters. `(cos, sin, x, y)`. Shape should be `(*, 3)`.""" @override def __repr__(self) -> str: @@ -45,7 +44,7 @@ def from_xy_theta(x: hints.Scalar, y: hints.Scalar, theta: hints.Scalar) -> "SE2 """ cos = onp.cos(theta) sin = onp.sin(theta) - return SE2(unit_complex_xy=onp.array([cos, sin, x, y])) + return SE2(unit_complex_xy=onp.stack([cos, sin, x, y], axis=-1)) # SE-specific. @@ -56,16 +55,14 @@ def from_rotation_and_translation( rotation: SO2, translation: hints.Array, ) -> "SE2": - assert translation.shape == (2,) + assert translation.shape[-1:] == (2,) + rotation, translation = broadcast_leading_axes((rotation, translation)) return SE2( - unit_complex_xy=onp.concatenate([rotation.unit_complex, translation]) + unit_complex_xy=onp.concatenate( + [rotation.unit_complex, translation], axis=-1 + ) ) - @override - @classmethod - def from_translation(cls, translation: onpt.NDArray[onp.floating]) -> "SE2": - return SE2.from_rotation_and_translation(SO2.identity(), translation) - @override def rotation(self) -> SO2: return SO2(unit_complex=self.unit_complex_xy[..., :2]) @@ -78,17 +75,21 @@ def translation(self) -> onpt.NDArray[onp.floating]: @classmethod @override - def identity(cls) -> "SE2": - return SE2(unit_complex_xy=onp.array([1.0, 0.0, 0.0, 0.0])) + def identity(cls, batch_axes: Tuple[int, ...] = ()) -> "SE2": + return SE2( + unit_complex_xy=onp.broadcast_to( + onp.array([1.0, 0.0, 0.0, 0.0]), (*batch_axes, 4) + ) + ) @classmethod @override def from_matrix(cls, matrix: hints.Array) -> "SE2": - assert matrix.shape == (3, 3) + assert matrix.shape[-2:] == (3, 3) # Currently assumes bottom row is [0, 0, 1]. return SE2.from_rotation_and_translation( - rotation=SO2.from_matrix(matrix[:2, :2]), - translation=matrix[:2, 2], + rotation=SO2.from_matrix(matrix[..., :2, :2]), + translation=matrix[..., :2, 2], ) # Accessors. @@ -99,14 +100,22 @@ def parameters(self) -> onpt.NDArray[onp.floating]: @override def as_matrix(self) -> onpt.NDArray[onp.floating]: - cos, sin, x, y = self.unit_complex_xy - return onp.array( + cos, sin, x, y = onp.moveaxis(self.unit_complex_xy, -1, 0) + out = onp.stack( [ - [cos, -sin, x], - [sin, cos, y], - [0.0, 0.0, 1.0], - ] - ) + cos, + -sin, + x, + sin, + cos, + y, + onp.zeros_like(x), + onp.zeros_like(x), + onp.ones_like(x), + ], + axis=-1, + ).reshape((*self.get_batch_axes(), 3, 3)) + return out # Operations. @@ -118,25 +127,25 @@ def exp(cls, tangent: hints.Array) -> "SE2": # Also see: # > http://ethaneade.com/lie.pdf - assert tangent.shape == (3,) + assert tangent.shape[-1:] == (3,) - theta = tangent[2] + theta = tangent[..., 2] use_taylor = onp.abs(theta) < get_epsilon(tangent.dtype) # Shim to avoid NaNs in onp.where branches, which cause failures for - # reverse-mode AD. (note: this is needed in JAX, but not in numpy) + # reverse-mode AD. safe_theta = cast( - onpt.NDArray[onp.floating], + onp.ndarray, onp.where( use_taylor, - 1.0, # Any non-zero value should do here. + onp.ones_like(theta), # Any non-zero value should do here. theta, ), ) theta_sq = theta**2 sin_over_theta = cast( - onpt.NDArray[onp.floating], + onp.ndarray, onp.where( use_taylor, 1.0 - theta_sq / 6.0, @@ -144,7 +153,7 @@ def exp(cls, tangent: hints.Array) -> "SE2": ), ) one_minus_cos_over_theta = cast( - onpt.NDArray[onp.floating], + onp.ndarray, onp.where( use_taylor, 0.5 * theta - theta * theta_sq / 24.0, @@ -152,15 +161,18 @@ def exp(cls, tangent: hints.Array) -> "SE2": ), ) - V = onp.array( + V = onp.stack( [ - [sin_over_theta, -one_minus_cos_over_theta], - [one_minus_cos_over_theta, sin_over_theta], - ] - ) + sin_over_theta, + -one_minus_cos_over_theta, + one_minus_cos_over_theta, + sin_over_theta, + ], + axis=-1, + ).reshape((*tangent.shape[:-1], 2, 2)) return SE2.from_rotation_and_translation( rotation=SO2.from_radians(theta), - translation=V @ tangent[:2], + translation=onp.einsum("...ij,...j->...i", V, tangent[..., :2]), ) @override @@ -170,7 +182,7 @@ def log(self) -> onpt.NDArray[onp.floating]: # Also see: # > http://ethaneade.com/lie.pdf - theta = self.rotation().log()[0] + theta = self.rotation().log()[..., 0] cos = onp.cos(theta) cos_minus_one = cos - 1.0 @@ -178,10 +190,10 @@ def log(self) -> onpt.NDArray[onp.floating]: use_taylor = onp.abs(cos_minus_one) < get_epsilon(theta.dtype) # Shim to avoid NaNs in onp.where branches, which cause failures for - # reverse-mode AD. (note: this is needed in JAX, but not in numpy) + # reverse-mode AD. safe_cos_minus_one = onp.where( use_taylor, - 1.0, # Any non-zero value should do here. + onp.ones_like(cos_minus_one), # Any non-zero value should do here. cos_minus_one, ) @@ -193,34 +205,58 @@ def log(self) -> onpt.NDArray[onp.floating]: -(half_theta * onp.sin(theta)) / safe_cos_minus_one, ) - V_inv = onp.array( + V_inv = onp.stack( + [ + half_theta_over_tan_half_theta, + half_theta, + -half_theta, + half_theta_over_tan_half_theta, + ], + axis=-1, + ).reshape((*theta.shape, 2, 2)) + + tangent = onp.concatenate( [ - [half_theta_over_tan_half_theta, half_theta], - [-half_theta, half_theta_over_tan_half_theta], - ] + onp.einsum("...ij,...j->...i", V_inv, self.translation()), + theta[..., None], + ], + axis=-1, ) - - tangent = onp.concatenate([V_inv @ self.translation(), theta[None]]) return tangent @override def adjoint(self: "SE2") -> onpt.NDArray[onp.floating]: - cos, sin, x, y = self.unit_complex_xy - return onp.array( + cos, sin, x, y = onp.moveaxis(self.unit_complex_xy, -1, 0) + return onp.stack( [ - [cos, -sin, y], - [sin, cos, -x], - [0.0, 0.0, 1.0], - ] - ) - - # @staticmethod - # @override - # def sample_uniform(key: hints.KeyArray) -> "SE2": - # key0, key1 = jax.random.split(key) - # return SE2.from_rotation_and_translation( - # rotation=SO2.sample_uniform(key0), - # translation=jax.random.uniform( - # key=key1, shape=(2,), minval=-1.0, maxval=1.0 - # ), - # ) + cos, + -sin, + y, + sin, + cos, + -x, + onp.zeros_like(x), + onp.zeros_like(x), + onp.ones_like(x), + ], + axis=-1, + ).reshape((*self.get_batch_axes(), 3, 3)) + + # @classmethod + # @override + # def sample_uniform( + # cls, key: onp.ndarray, batch_axes: jdc.Static[Tuple[int, ...]] = () + # ) -> "SE2": + # key0, key1 = jax.random.split(key) + # return SE2.from_rotation_and_translation( + # rotation=SO2.sample_uniform(key0, batch_axes=batch_axes), + # translation=jax.random.uniform( + # key=key1, + # shape=( + # *batch_axes, + # 2, + # ), + # minval=-1.0, + # maxval=1.0, + # ), + # ) diff --git a/src/viser/transforms/_se3.py b/src/viser/transforms/_se3.py index 46690a081..406c2a0ff 100644 --- a/src/viser/transforms/_se3.py +++ b/src/viser/transforms/_se3.py @@ -1,28 +1,26 @@ from __future__ import annotations import dataclasses -from typing import cast +from typing import Tuple, cast import numpy as onp import numpy.typing as onpt from typing_extensions import override -from . import _base +from . import _base, hints from ._so3 import SO3 -from .utils import get_epsilon, register_lie_group +from .utils import broadcast_leading_axes, get_epsilon, register_lie_group -def _skew(omega: onpt.NDArray[onp.floating]) -> onpt.NDArray[onp.floating]: +def _skew(omega: hints.Array) -> onpt.NDArray[onp.floating]: """Returns the skew-symmetric form of a length-3 vector.""" - wx, wy, wz = omega - return onp.array( - [ # type: ignore - [0.0, -wz, wy], - [wz, 0.0, -wx], - [-wy, wx, 0.0], - ] - ) + wx, wy, wz = onp.moveaxis(omega, -1, 0) + zeros = onp.zeros_like(wx) + return onp.stack( + [zeros, -wz, wy, wz, zeros, -wx, -wy, wx, zeros], + axis=-1, + ).reshape((*omega.shape[:-1], 3, 3)) @register_lie_group( @@ -31,11 +29,10 @@ def _skew(omega: onpt.NDArray[onp.floating]) -> onpt.NDArray[onp.floating]: tangent_dim=6, space_dim=3, ) -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class SE3(_base.SEBase[SO3]): - """Special Euclidean group for proper rigid transforms in 3D. - - Ported to numpy from `jaxlie.SE3`. + """Special Euclidean group for proper rigid transforms in 3D. Broadcasting + rules are the same as for numpy. Internal parameterization is `(qw, qx, qy, qz, x, y, z)`. Tangent parameterization is `(vx, vy, vz, omega_x, omega_y, omega_z)`. @@ -43,8 +40,8 @@ class SE3(_base.SEBase[SO3]): # SE3-specific. - wxyz_xyz: onpt.NDArray[onp.floating] - """Internal parameters. wxyz quaternion followed by xyz translation.""" + wxyz_xyz: onp.ndarray + """Internal parameters. wxyz quaternion followed by xyz translation. Shape should be `(*, 7)`.""" @override def __repr__(self) -> str: @@ -59,15 +56,11 @@ def __repr__(self) -> str: def from_rotation_and_translation( cls, rotation: SO3, - translation: onpt.NDArray[onp.floating], + translation: hints.Array, ) -> SE3: - assert translation.shape == (3,) - return SE3(wxyz_xyz=onp.concatenate([rotation.wxyz, translation])) - - @override - @classmethod - def from_translation(cls, translation: onpt.NDArray[onp.floating]) -> "SE3": - return SE3.from_rotation_and_translation(SO3.identity(), translation) + assert translation.shape[-1:] == (3,) + rotation, translation = broadcast_leading_axes((rotation, translation)) + return SE3(wxyz_xyz=onp.concatenate([rotation.wxyz, translation], axis=-1)) @override def rotation(self) -> SO3: @@ -81,34 +74,32 @@ def translation(self) -> onpt.NDArray[onp.floating]: @classmethod @override - def identity(cls) -> SE3: - return SE3(wxyz_xyz=onp.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])) + def identity(cls, batch_axes: Tuple[int, ...] = ()) -> SE3: + return SE3( + wxyz_xyz=onp.broadcast_to( + onp.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), (*batch_axes, 7) + ) + ) @classmethod @override - def from_matrix(cls, matrix: onpt.NDArray[onp.floating]) -> SE3: - assert matrix.shape == (4, 4) + def from_matrix(cls, matrix: hints.Array) -> SE3: + assert matrix.shape[-2:] == (4, 4) # Currently assumes bottom row is [0, 0, 0, 1]. return SE3.from_rotation_and_translation( - rotation=SO3.from_matrix(matrix[:3, :3]), - translation=matrix[:3, 3], + rotation=SO3.from_matrix(matrix[..., :3, :3]), + translation=matrix[..., :3, 3], ) # Accessors. @override def as_matrix(self) -> onpt.NDArray[onp.floating]: - out = onp.eye(4) - out[:3, :3] = self.rotation().as_matrix() - out[:3, 3] = self.translation() + out = onp.zeros((*self.get_batch_axes(), 4, 4)) + out[..., :3, :3] = self.rotation().as_matrix() + out[..., :3, 3] = set(self.translation()) + out[..., 3, 3] = 1.0 return out - # return ( - # onp.eye(4) - # .at[:3, :3] - # .set(self.rotation().as_matrix()) - # .at[:3, 3] - # .set(self.translation()) - # ) @override def parameters(self) -> onpt.NDArray[onp.floating]: @@ -118,47 +109,50 @@ def parameters(self) -> onpt.NDArray[onp.floating]: @classmethod @override - def exp(cls, tangent: onpt.NDArray[onp.floating]) -> SE3: + def exp(cls, tangent: hints.Array) -> SE3: # Reference: # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se3.hpp#L761 # (x, y, z, omega_x, omega_y, omega_z) - assert tangent.shape == (6,) + assert tangent.shape[-1:] == (6,) - rotation = SO3.exp(tangent[3:]) + rotation = SO3.exp(tangent[..., 3:]) - theta_squared = tangent[3:] @ tangent[3:] + theta_squared = onp.sum(onp.square(tangent[..., 3:]), axis=-1) use_taylor = theta_squared < get_epsilon(theta_squared.dtype) # Shim to avoid NaNs in onp.where branches, which cause failures for - # reverse-mode AD. (note: this is needed in JAX, but not in numpy) + # reverse-mode AD. theta_squared_safe = cast( - onpt.NDArray[onp.floating], + onp.ndarray, onp.where( use_taylor, - 1.0, # Any non-zero value should do here. + onp.ones_like(theta_squared), # Any non-zero value should do here. theta_squared, ), ) del theta_squared theta_safe = onp.sqrt(theta_squared_safe) - skew_omega = _skew(tangent[3:]) + skew_omega = _skew(tangent[..., 3:]) V = onp.where( - use_taylor, + use_taylor[..., None, None], rotation.as_matrix(), ( onp.eye(3) - + (1.0 - onp.cos(theta_safe)) / (theta_squared_safe) * skew_omega - + (theta_safe - onp.sin(theta_safe)) - / (theta_squared_safe * theta_safe) - * (skew_omega @ skew_omega) + + ((1.0 - onp.cos(theta_safe)) / (theta_squared_safe))[..., None, None] + * skew_omega + + ( + (theta_safe - onp.sin(theta_safe)) + / (theta_squared_safe * theta_safe) + )[..., None, None] + * onp.einsum("...ij,...jk->...ik", skew_omega, skew_omega) ), ) return SE3.from_rotation_and_translation( rotation=rotation, - translation=V @ tangent[:3], + translation=onp.einsum("...ij,...j->...i", V, tangent[..., :3]), ) @override @@ -166,16 +160,16 @@ def log(self) -> onpt.NDArray[onp.floating]: # Reference: # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se3.hpp#L223 omega = self.rotation().log() - theta_squared = omega @ omega + theta_squared = onp.sum(onp.square(omega), axis=-1) use_taylor = theta_squared < get_epsilon(theta_squared.dtype) skew_omega = _skew(omega) # Shim to avoid NaNs in onp.where branches, which cause failures for - # reverse-mode AD. (note: this is needed in JAX, but not in numpy) + # reverse-mode AD. theta_squared_safe = onp.where( use_taylor, - 1.0, # Any non-zero value should do here. + onp.ones_like(theta_squared), # Any non-zero value should do here. theta_squared, ) del theta_squared @@ -183,40 +177,54 @@ def log(self) -> onpt.NDArray[onp.floating]: half_theta_safe = theta_safe / 2.0 V_inv = onp.where( - use_taylor, - onp.eye(3) - 0.5 * skew_omega + (skew_omega @ skew_omega) / 12.0, + use_taylor[..., None, None], + onp.eye(3) + - 0.5 * skew_omega + + onp.einsum("...ij,...jk->...ik", skew_omega, skew_omega) / 12.0, ( onp.eye(3) - 0.5 * skew_omega + ( - 1.0 - - theta_safe - * onp.cos(half_theta_safe) - / (2.0 * onp.sin(half_theta_safe)) - ) - / theta_squared_safe - * (skew_omega @ skew_omega) + ( + 1.0 + - theta_safe + * onp.cos(half_theta_safe) + / (2.0 * onp.sin(half_theta_safe)) + ) + / theta_squared_safe + )[..., None, None] + * onp.einsum("...ij,...jk->...ik", skew_omega, skew_omega) ), ) - return onp.concatenate([V_inv @ self.translation(), omega]) + return onp.concatenate( + [onp.einsum("...ij,...j->...i", V_inv, self.translation()), omega], axis=-1 + ) @override def adjoint(self) -> onpt.NDArray[onp.floating]: R = self.rotation().as_matrix() - return onp.block( + return onp.concatenate( [ - [R, _skew(self.translation()) @ R], - [onp.zeros((3, 3)), R], - ] + onp.concatenate( + [R, onp.einsum("...ij,...jk->...ik", _skew(self.translation()), R)], + axis=-1, + ), + onp.concatenate( + [onp.zeros((*self.get_batch_axes(), 3, 3)), R], axis=-1 + ), + ], + axis=-2, ) - # @staticmethod + # @classmethod # @override - # def sample_uniform(key: hints.KeyArray) -> SE3: + # def sample_uniform( + # cls, key: onp.ndarray, batch_axes: jdc.Static[Tuple[int, ...]] = () + # ) -> SE3: # key0, key1 = jax.random.split(key) # return SE3.from_rotation_and_translation( - # rotation=SO3.sample_uniform(key0), + # rotation=SO3.sample_uniform(key0, batch_axes=batch_axes), # translation=jax.random.uniform( - # key=key1, shape=(3,), minval=-1.0, maxval=1.0 + # key=key1, shape=(*batch_axes, 3), minval=-1.0, maxval=1.0 # ), # ) diff --git a/src/viser/transforms/_so2.py b/src/viser/transforms/_so2.py index 1b9a5caa5..b9189c2fe 100644 --- a/src/viser/transforms/_so2.py +++ b/src/viser/transforms/_so2.py @@ -1,13 +1,14 @@ from __future__ import annotations import dataclasses +from typing import Tuple import numpy as onp import numpy.typing as onpt from typing_extensions import override from . import _base, hints -from .utils import register_lie_group +from .utils import broadcast_leading_axes, register_lie_group @register_lie_group( @@ -16,19 +17,18 @@ tangent_dim=1, space_dim=2, ) -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class SO2(_base.SOBase): - """Special orthogonal group for 2D rotations. - - Ported to numpy from `jaxlie.SO2`. + """Special orthogonal group for 2D rotations. Broadcasting rules are the + same as for `numpy`. Internal parameterization is `(cos, sin)`. Tangent parameterization is `(omega,)`. """ # SO2-specific. - unit_complex: onpt.NDArray[onp.floating] - """Internal parameters. `(cos, sin)`.""" + unit_complex: onp.ndarray + """Internal parameters. `(cos, sin)`. Shape should be `(*, 2)`.""" @override def __repr__(self) -> str: @@ -40,7 +40,7 @@ def from_radians(theta: hints.Scalar) -> SO2: """Construct a rotation object from a scalar angle.""" cos = onp.cos(theta) sin = onp.sin(theta) - return SO2(unit_complex=onp.array([cos, sin])) + return SO2(unit_complex=onp.stack([cos, sin], axis=-1)) def as_radians(self) -> onpt.NDArray[onp.floating]: """Compute a scalar angle from a rotation object.""" @@ -51,30 +51,35 @@ def as_radians(self) -> onpt.NDArray[onp.floating]: @classmethod @override - def identity(cls) -> SO2: - return SO2(unit_complex=onp.array([1.0, 0.0])) + def identity(cls, batch_axes: Tuple[int, ...] = ()) -> SO2: + return SO2( + unit_complex=onp.stack( + [onp.ones(batch_axes), onp.zeros(batch_axes)], axis=-1 + ) + ) @classmethod @override - def from_matrix(cls, matrix: onpt.NDArray[onp.floating]) -> SO2: - assert matrix.shape == (2, 2) - return SO2(unit_complex=onp.asarray(matrix[:, 0])) + def from_matrix(cls, matrix: hints.Array) -> SO2: + assert matrix.shape[-2:] == (2, 2) + return SO2(unit_complex=onp.asarray(matrix[..., :, 0])) # Accessors. @override def as_matrix(self) -> onpt.NDArray[onp.floating]: cos_sin = self.unit_complex - out = onp.array( + out = onp.stack( [ # [cos, -sin], cos_sin * onp.array([1, -1]), # [sin, cos], - cos_sin[::-1], - ] + cos_sin[..., ::-1], + ], + axis=-2, ) - assert out.shape == (2, 2) - return out + assert out.shape == (*self.get_batch_axes(), 2, 2) + return out # type: ignore @override def parameters(self) -> onpt.NDArray[onp.floating]: @@ -83,21 +88,26 @@ def parameters(self) -> onpt.NDArray[onp.floating]: # Operations. @override - def apply(self, target: onpt.NDArray[onp.floating]) -> onpt.NDArray[onp.floating]: - assert target.shape == (2,) - return self.as_matrix() @ target # type: ignore + def apply(self, target: hints.Array) -> onpt.NDArray[onp.floating]: + assert target.shape[-1:] == (2,) + self, target = broadcast_leading_axes((self, target)) + return onp.einsum("...ij,...j->...i", self.as_matrix(), target) @override def multiply(self, other: SO2) -> SO2: - return SO2(unit_complex=self.as_matrix() @ other.unit_complex) + return SO2( + unit_complex=onp.einsum( + "...ij,...j->...i", self.as_matrix(), other.unit_complex + ) + ) @classmethod @override - def exp(cls, tangent: onpt.NDArray[onp.floating]) -> SO2: - (theta,) = tangent - cos = onp.cos(theta) - sin = onp.sin(theta) - return SO2(unit_complex=onp.array([cos, sin])) + def exp(cls, tangent: hints.Array) -> SO2: + assert tangent.shape[-1] == 1 + cos = onp.cos(tangent) + sin = onp.sin(tangent) + return SO2(unit_complex=onp.concatenate([cos, sin], axis=-1)) @override def log(self) -> onpt.NDArray[onp.floating]: @@ -107,7 +117,7 @@ def log(self) -> onpt.NDArray[onp.floating]: @override def adjoint(self) -> onpt.NDArray[onp.floating]: - return onp.eye(1) + return onp.ones((*self.get_batch_axes(), 1, 1)) @override def inverse(self) -> SO2: @@ -115,11 +125,19 @@ def inverse(self) -> SO2: @override def normalize(self) -> SO2: - return SO2(unit_complex=self.unit_complex / onp.linalg.norm(self.unit_complex)) - - # @staticmethod - # @override - # def sample_uniform(key: hints.KeyArray) -> SO2: - # return SO2.from_radians( - # jax.random.uniform(key=key, minval=0.0, maxval=2.0 * onp.pi) - # ) + return SO2( + unit_complex=self.unit_complex + / onp.linalg.norm(self.unit_complex, axis=-1, keepdims=True) + ) + + # @classmethod + # @override + # def sample_uniform( + # cls, key: onp.ndarray, batch_axes: jdc.Static[Tuple[int, ...]] = () + # ) -> SO2: + # out = SO2.from_radians( + # jax.random.uniform( + # key=key, shape=batch_axes, minval=0.0, maxval=2.0 * onp.pi) + # ) + # assert out.get_batch_axes() == batch_axes + # return out diff --git a/src/viser/transforms/_so3.py b/src/viser/transforms/_so3.py index b7575326f..ac6397ceb 100644 --- a/src/viser/transforms/_so3.py +++ b/src/viser/transforms/_so3.py @@ -1,13 +1,14 @@ from __future__ import annotations import dataclasses +from typing import Tuple import numpy as onp import numpy.typing as onpt from typing_extensions import override from . import _base, hints -from .utils import get_epsilon, register_lie_group +from .utils import broadcast_leading_axes, get_epsilon, register_lie_group @register_lie_group( @@ -16,20 +17,17 @@ tangent_dim=3, space_dim=3, ) -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class SO3(_base.SOBase): - """Special orthogonal group for 3D rotations. - - Ported to numpy from `jaxlie.SO3`. + """Special orthogonal group for 3D rotations. Broadcasting rules are the same as + for numpy. Internal parameterization is `(qw, qx, qy, qz)`. Tangent parameterization is `(omega_x, omega_y, omega_z)`. """ - # SO3-specific. - - wxyz: onpt.NDArray[onp.floating] - """Internal parameters. `(w, x, y, z)` quaternion.""" + wxyz: onp.ndarray + """Internal parameters. `(w, x, y, z)` quaternion. Shape should be `(*, 4)`.""" @override def __repr__(self) -> str: @@ -46,7 +44,8 @@ def from_x_radians(theta: hints.Scalar) -> SO3: Returns: Output. """ - return SO3.exp(onp.array([theta, 0.0, 0.0])) + zeros = onp.zeros_like(theta) + return SO3.exp(onp.stack([theta, zeros, zeros], axis=-1)) @staticmethod def from_y_radians(theta: hints.Scalar) -> SO3: @@ -58,7 +57,8 @@ def from_y_radians(theta: hints.Scalar) -> SO3: Returns: Output. """ - return SO3.exp(onp.array([0.0, theta, 0.0])) + zeros = onp.zeros_like(theta) + return SO3.exp(onp.stack([zeros, theta, zeros], axis=-1)) @staticmethod def from_z_radians(theta: hints.Scalar) -> SO3: @@ -70,7 +70,8 @@ def from_z_radians(theta: hints.Scalar) -> SO3: Returns: Output. """ - return SO3.exp(onp.array([0.0, 0.0, theta])) + zeros = onp.zeros_like(theta) + return SO3.exp(onp.stack([zeros, zeros, theta], axis=-1)) @staticmethod def from_rpy_radians( @@ -96,24 +97,24 @@ def from_rpy_radians( ) @staticmethod - def from_quaternion_xyzw(xyzw: onpt.NDArray[onp.floating]) -> SO3: + def from_quaternion_xyzw(xyzw: hints.Array) -> SO3: """Construct a rotation from an `xyzw` quaternion. Note that `wxyz` quaternions can be constructed using the default dataclass constructor. Args: - xyzw: xyzw quaternion. Shape should be (4,). + xyzw: xyzw quaternion. Shape should be (*, 4). Returns: Output. """ - assert xyzw.shape == (4,) - return SO3(onp.roll(xyzw, shift=1)) + assert xyzw.shape[-1:] == (4,) + return SO3(onp.roll(xyzw, axis=-1, shift=1)) def as_quaternion_xyzw(self) -> onpt.NDArray[onp.floating]: """Grab parameters as xyzw quaternion.""" - return onp.roll(self.wxyz, shift=-1) + return onp.roll(self.wxyz, axis=-1, shift=-1) def as_rpy_radians(self) -> hints.RollPitchYaw: """Computes roll, pitch, and yaw angles. Uses the ZYX mobile robot convention. @@ -134,7 +135,7 @@ def compute_roll_radians(self) -> onpt.NDArray[onp.floating]: Euler angle in radians. """ # https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles#Quaternion_to_Euler_angles_conversion - q0, q1, q2, q3 = self.wxyz + q0, q1, q2, q3 = onp.moveaxis(self.wxyz, -1, 0) return onp.arctan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1**2 + q2**2)) def compute_pitch_radians(self) -> onpt.NDArray[onp.floating]: @@ -144,7 +145,7 @@ def compute_pitch_radians(self) -> onpt.NDArray[onp.floating]: Euler angle in radians. """ # https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles#Quaternion_to_Euler_angles_conversion - q0, q1, q2, q3 = self.wxyz + q0, q1, q2, q3 = onp.moveaxis(self.wxyz, -1, 0) return onp.arcsin(2 * (q0 * q2 - q3 * q1)) def compute_yaw_radians(self) -> onpt.NDArray[onp.floating]: @@ -154,70 +155,76 @@ def compute_yaw_radians(self) -> onpt.NDArray[onp.floating]: Euler angle in radians. """ # https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles#Quaternion_to_Euler_angles_conversion - q0, q1, q2, q3 = self.wxyz + q0, q1, q2, q3 = onp.moveaxis(self.wxyz, -1, 0) return onp.arctan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2**2 + q3**2)) # Factory. - @override @classmethod - def identity(cls) -> SO3: - return SO3(wxyz=onp.array([1.0, 0.0, 0.0, 0.0])) - @override + def identity(cls, batch_axes: Tuple[int, ...] = ()) -> SO3: + return SO3( + wxyz=onp.broadcast_to(onp.array([1.0, 0.0, 0.0, 0.0]), (*batch_axes, 4)) + ) + @classmethod - def from_matrix(cls, matrix: onpt.NDArray[onp.floating]) -> SO3: - assert matrix.shape == (3, 3) + @override + def from_matrix(cls, matrix: hints.Array) -> SO3: + assert matrix.shape[-2:] == (3, 3) # Modified from: # > "Converting a Rotation Matrix to a Quaternion" from Mike Day # > https://d3cw3dd2w32x2b.cloudfront.net/wp-content/uploads/2015/01/matrix-to-quat.pdf def case0(m): - t = 1 + m[0, 0] - m[1, 1] - m[2, 2] - q = onp.array( + t = 1 + m[..., 0, 0] - m[..., 1, 1] - m[..., 2, 2] + q = onp.stack( [ - m[2, 1] - m[1, 2], + m[..., 2, 1] - m[..., 1, 2], t, - m[1, 0] + m[0, 1], - m[0, 2] + m[2, 0], - ] + m[..., 1, 0] + m[..., 0, 1], + m[..., 0, 2] + m[..., 2, 0], + ], + axis=-1, ) return t, q def case1(m): - t = 1 - m[0, 0] + m[1, 1] - m[2, 2] - q = onp.array( + t = 1 - m[..., 0, 0] + m[..., 1, 1] - m[..., 2, 2] + q = onp.stack( [ - m[0, 2] - m[2, 0], - m[1, 0] + m[0, 1], + m[..., 0, 2] - m[..., 2, 0], + m[..., 1, 0] + m[..., 0, 1], t, - m[2, 1] + m[1, 2], - ] + m[..., 2, 1] + m[..., 1, 2], + ], + axis=-1, ) return t, q def case2(m): - t = 1 - m[0, 0] - m[1, 1] + m[2, 2] - q = onp.array( + t = 1 - m[..., 0, 0] - m[..., 1, 1] + m[..., 2, 2] + q = onp.stack( [ - m[1, 0] - m[0, 1], - m[0, 2] + m[2, 0], - m[2, 1] + m[1, 2], + m[..., 1, 0] - m[..., 0, 1], + m[..., 0, 2] + m[..., 2, 0], + m[..., 2, 1] + m[..., 1, 2], t, - ] + ], + axis=-1, ) return t, q def case3(m): - t = 1 + m[0, 0] + m[1, 1] + m[2, 2] - q = onp.array( + t = 1 + m[..., 0, 0] + m[..., 1, 1] + m[..., 2, 2] + q = onp.stack( [ t, - m[2, 1] - m[1, 2], - m[0, 2] - m[2, 0], - m[1, 0] - m[0, 1], - ] + m[..., 2, 1] - m[..., 1, 2], + m[..., 0, 2] - m[..., 2, 0], + m[..., 1, 0] - m[..., 0, 1], + ], + axis=-1, ) return t, q @@ -228,9 +235,9 @@ def case3(m): case2_t, case2_q = case2(matrix) case3_t, case3_q = case3(matrix) - cond0 = matrix[2, 2] < 0 - cond1 = matrix[0, 0] > matrix[1, 1] - cond2 = matrix[0, 0] < -matrix[1, 1] + cond0 = matrix[..., 2, 2] < 0 + cond1 = matrix[..., 0, 0] > matrix[..., 1, 1] + cond2 = matrix[..., 0, 0] < -matrix[..., 1, 1] t = onp.where( cond0, @@ -238,9 +245,9 @@ def case3(m): onp.where(cond2, case2_t, case3_t), ) q = onp.where( - cond0, - onp.where(cond1, case0_q, case1_q), - onp.where(cond2, case2_q, case3_q), + cond0[..., None], + onp.where(cond1[..., None], case0_q, case1_q), + onp.where(cond2[..., None], case2_q, case3_q), ) # We can also choose to branch, but this is slower. @@ -261,22 +268,29 @@ def case3(m): # operand=matrix, # ) - return SO3(wxyz=q * 0.5 / onp.sqrt(t)) + return SO3(wxyz=q * 0.5 / onp.sqrt(t[..., None])) # Accessors. @override def as_matrix(self) -> onpt.NDArray[onp.floating]: - norm = self.wxyz @ self.wxyz - q = self.wxyz * onp.sqrt(2.0 / norm) - q = onp.outer(q, q) - return onp.array( + norm_sq = onp.sum(onp.square(self.wxyz), axis=-1, keepdims=True) + q = self.wxyz * onp.sqrt(2.0 / norm_sq) # (*, 4) + q_outer = onp.einsum("...i,...j->...ij", q, q) # (*, 4, 4) + return onp.stack( [ - [1.0 - q[2, 2] - q[3, 3], q[1, 2] - q[3, 0], q[1, 3] + q[2, 0]], - [q[1, 2] + q[3, 0], 1.0 - q[1, 1] - q[3, 3], q[2, 3] - q[1, 0]], - [q[1, 3] - q[2, 0], q[2, 3] + q[1, 0], 1.0 - q[1, 1] - q[2, 2]], - ] - ) + 1.0 - q_outer[..., 2, 2] - q_outer[..., 3, 3], + q_outer[..., 1, 2] - q_outer[..., 3, 0], + q_outer[..., 1, 3] + q_outer[..., 2, 0], + q_outer[..., 1, 2] + q_outer[..., 3, 0], + 1.0 - q_outer[..., 1, 1] - q_outer[..., 3, 3], + q_outer[..., 2, 3] - q_outer[..., 1, 0], + q_outer[..., 1, 3] - q_outer[..., 2, 0], + q_outer[..., 2, 3] + q_outer[..., 1, 0], + 1.0 - q_outer[..., 1, 1] - q_outer[..., 2, 2], + ], + axis=-1, + ).reshape(*q.shape[:-1], 3, 3) @override def parameters(self) -> onpt.NDArray[onp.floating]: @@ -285,45 +299,50 @@ def parameters(self) -> onpt.NDArray[onp.floating]: # Operations. @override - def apply(self, target: onpt.NDArray[onp.floating]) -> onpt.NDArray[onp.floating]: - assert target.shape == (3,) + def apply(self, target: hints.Array) -> onpt.NDArray[onp.floating]: + assert target.shape[-1:] == (3,) + self, target = broadcast_leading_axes((self, target)) # Compute using quaternion multiplys. - padded_target = onp.concatenate([onp.zeros(1), target]) - return (self @ SO3(wxyz=padded_target) @ self.inverse()).wxyz[1:] + padded_target = onp.concatenate( + [onp.zeros((*self.get_batch_axes(), 1)), target], axis=-1 + ) + return (self @ SO3(wxyz=padded_target) @ self.inverse()).wxyz[..., 1:] @override def multiply(self, other: SO3) -> SO3: - w0, x0, y0, z0 = self.wxyz - w1, x1, y1, z1 = other.wxyz + w0, x0, y0, z0 = onp.moveaxis(self.wxyz, -1, 0) + w1, x1, y1, z1 = onp.moveaxis(other.wxyz, -1, 0) return SO3( - wxyz=onp.array( + wxyz=onp.stack( [ -x0 * x1 - y0 * y1 - z0 * z1 + w0 * w1, x0 * w1 + y0 * z1 - z0 * y1 + w0 * x1, -x0 * z1 + y0 * w1 + z0 * x1 + w0 * y1, x0 * y1 - y0 * x1 + z0 * w1 + w0 * z1, - ] + ], + axis=-1, ) ) @classmethod @override - def exp(cls, tangent: onpt.NDArray[onp.floating]) -> SO3: + def exp(cls, tangent: hints.Array) -> SO3: # Reference: # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/so3.hpp#L583 - assert tangent.shape == (3,) - theta_squared = tangent @ tangent + assert tangent.shape[-1:] == (3,) + + theta_squared = onp.sum(onp.square(tangent), axis=-1) theta_pow_4 = theta_squared * theta_squared use_taylor = theta_squared < get_epsilon(tangent.dtype) # Shim to avoid NaNs in onp.where branches, which cause failures for - # reverse-mode AD. (note: this is needed in JAX, but not in numpy) + # reverse-mode AD. safe_theta = onp.sqrt( onp.where( use_taylor, - 1.0, # Any constant value should do here. + onp.ones_like(theta_squared), # Any constant value should do here. theta_squared, ) ) @@ -344,9 +363,10 @@ def exp(cls, tangent: onpt.NDArray[onp.floating]) -> SO3: return SO3( wxyz=onp.concatenate( [ - real_factor[None], - imaginary_factor * tangent, - ] + real_factor[..., None], + imaginary_factor[..., None] * tangent, + ], + axis=-1, ) ) @@ -356,11 +376,11 @@ def log(self) -> onpt.NDArray[onp.floating]: # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/so3.hpp#L247 w = self.wxyz[..., 0] - norm_sq = self.wxyz[..., 1:] @ self.wxyz[..., 1:] + norm_sq = onp.sum(onp.square(self.wxyz[..., 1:]), axis=-1) use_taylor = norm_sq < get_epsilon(norm_sq.dtype) # Shim to avoid NaNs in onp.where branches, which cause failures for - # reverse-mode AD. (note: this is needed in JAX, but not in numpy) + # reverse-mode AD. norm_safe = onp.sqrt( onp.where( use_taylor, @@ -383,7 +403,7 @@ def log(self) -> onpt.NDArray[onp.floating]: ), ) - return atan_factor * self.wxyz[1:] + return atan_factor[..., None] * self.wxyz[..., 1:] # type: ignore @override def adjoint(self) -> onpt.NDArray[onp.floating]: @@ -396,29 +416,36 @@ def inverse(self) -> SO3: @override def normalize(self) -> SO3: - return SO3(wxyz=self.wxyz / onp.linalg.norm(self.wxyz)) - - # @staticmethod - # @override - # def sample_uniform(key: hints.KeyArray) -> SO3: - # # Uniformly sample over S^3. - # # > Reference: http://planning.cs.uiuc.edu/node198.html - # u1, u2, u3 = jax.random.uniform( - # key=key, - # shape=(3,), - # minval=onp.zeros(3), - # maxval=onp.array([1.0, 2.0 * onp.pi, 2.0 * onp.pi]), - # ) - # a = onp.sqrt(1.0 - u1) - # b = onp.sqrt(u1) + return SO3(wxyz=self.wxyz / onp.linalg.norm(self.wxyz, axis=-1, keepdims=True)) + + # @classmethod + # @override + # def sample_uniform( + # cls, key: onp.ndarray, batch_axes: jdc.Static[Tuple[int, ...]] = () + # ) -> SO3: + # # Uniformly sample over S^3. + # # > Reference: http://planning.cs.uiuc.edu/node198.html + # u1, u2, u3 = onp.moveaxis( + # jax.random.uniform( + # key=key, + # shape=(*batch_axes, 3), + # minval=onp.zeros(3), + # maxval=onp.array([1.0, 2.0 * onp.pi, 2.0 * onp.pi]), + # ), + # -1, + # 0, + # ) + # a = onp.sqrt(1.0 - u1) + # b = onp.sqrt(u1) # - # return SO3( - # wxyz=onp.array( - # [ - # a * onp.sin(u2), - # a * onp.cos(u2), - # b * onp.sin(u3), - # b * onp.cos(u3), - # ] - # ) - # ) + # return SO3( + # wxyz=onp.stack( + # [ + # a * onp.sin(u2), + # a * onp.cos(u2), + # b * onp.sin(u3), + # b * onp.cos(u3), + # ], + # axis=-1, + # ) + # ) diff --git a/src/viser/transforms/hints/__init__.py b/src/viser/transforms/hints/__init__.py index 201a32f9b..ad619f26a 100644 --- a/src/viser/transforms/hints/__init__.py +++ b/src/viser/transforms/hints/__init__.py @@ -2,10 +2,10 @@ import numpy as onp -# Type aliases for JAX/Numpy arrays; primarily for function inputs. +# Type aliases Numpy arrays; primarily for function inputs. Array = onp.ndarray -"""Type alias for `onp.ndarray`.""" +"""Type alias for onp.ndarray.""" Scalar = Union[float, Array] """Type alias for `Union[float, Array]`.""" diff --git a/src/viser/transforms/utils/__init__.py b/src/viser/transforms/utils/__init__.py index 02371f657..11980074b 100644 --- a/src/viser/transforms/utils/__init__.py +++ b/src/viser/transforms/utils/__init__.py @@ -1,3 +1,3 @@ -from ._utils import get_epsilon, register_lie_group +from ._utils import broadcast_leading_axes, get_epsilon, register_lie_group -__all__ = ["get_epsilon", "register_lie_group"] +__all__ = ["get_epsilon", "register_lie_group", "broadcast_leading_axes"] diff --git a/src/viser/transforms/utils/_utils.py b/src/viser/transforms/utils/_utils.py index dc7728e9e..d2a3ea380 100644 --- a/src/viser/transforms/utils/_utils.py +++ b/src/viser/transforms/utils/_utils.py @@ -1,6 +1,7 @@ -from typing import TYPE_CHECKING, Any, Callable, Type, TypeVar +from typing import TYPE_CHECKING, Callable, Tuple, Type, TypeVar, Union, cast import numpy as onp +from jaxlie.hints import Array if TYPE_CHECKING: from .._base import MatrixLieGroup @@ -9,7 +10,7 @@ T = TypeVar("T", bound="MatrixLieGroup") -def get_epsilon(dtype: Any) -> float: +def get_epsilon(dtype: onp.dtype) -> float: """Helper for grabbing type-specific precision constants. Args: @@ -18,12 +19,10 @@ def get_epsilon(dtype: Any) -> float: Returns: Output float. """ - if dtype == onp.float32: - return 1e-5 - elif dtype == onp.float64: - return 1e-10 - else: - assert False + return { + onp.dtype("float32"): 1e-5, + onp.dtype("float64"): 1e-10, + }[dtype] def register_lie_group( @@ -35,7 +34,7 @@ def register_lie_group( ) -> Callable[[Type[T]], Type[T]]: """Decorator for registering Lie group dataclasses. - Sets dimensionality class variables, and (formerly in the JAX version) marks all methods for JIT compilation. + Sets dimensionality class variables, and marks all methods for JIT compilation. """ def _wrap(cls: Type[T]) -> Type[T]: @@ -44,7 +43,45 @@ def _wrap(cls: Type[T]) -> Type[T]: cls.parameters_dim = parameters_dim cls.tangent_dim = tangent_dim cls.space_dim = space_dim - return cls return _wrap + + +TupleOfBroadcastable = TypeVar( + "TupleOfBroadcastable", + bound="Tuple[Union[MatrixLieGroup, Array], ...]", +) + + +def broadcast_leading_axes(inputs: TupleOfBroadcastable) -> TupleOfBroadcastable: + """Broadcast leading axes of arrays. Takes tuples of either: + - an array, which we assume has shape (*, D). + - a Lie group object.""" + + from .._base import MatrixLieGroup + + array_inputs = [ + ( + (x.parameters(), (x.parameters_dim,)) + if isinstance(x, MatrixLieGroup) + else (x, x.shape[-1:]) + ) + for x in inputs + ] + for array, shape_suffix in array_inputs: + assert array.shape[-len(shape_suffix) :] == shape_suffix + batch_axes = onp.broadcast_shapes( + *[array.shape[: -len(suffix)] for array, suffix in array_inputs] + ) + broadcasted_arrays = tuple( + onp.broadcast_to(array, batch_axes + shape_suffix) + for (array, shape_suffix) in array_inputs + ) + return cast( + TupleOfBroadcastable, + tuple( + array if not isinstance(inp, MatrixLieGroup) else type(inp)(array) + for array, inp in zip(broadcasted_arrays, inputs) + ), + )