Skip to content

Commit

Permalink
Support batch axes + broadcasting for viser.transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed May 6, 2024
1 parent 1fe8437 commit d01a108
Show file tree
Hide file tree
Showing 9 changed files with 499 additions and 353 deletions.
11 changes: 4 additions & 7 deletions examples/08_smpl_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,10 @@ def main(model_path: Path) -> None:
# Compute SMPL outputs.
smpl_outputs = model.get_outputs(
betas=np.array([x.value for x in gui_elements.gui_betas]),
joint_rotmats=np.stack(
[
tf.SO3.exp(np.array(x.value)).as_matrix()
for x in gui_elements.gui_joints
],
axis=0,
),
joint_rotmats=tf.SO3.exp(
# (num_joints, 3)
np.array([x.value for x in gui_elements.gui_joints])
).as_matrix(),
)
server.add_mesh_simple(
"/human",
Expand Down
105 changes: 64 additions & 41 deletions src/viser/transforms/_base.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
import abc
from typing import ClassVar, Generic, Type, TypeVar, Union, overload
from typing import ClassVar, Generic, Tuple, TypeVar, Union, overload

import numpy as onp
import numpy.typing as onpt
from typing_extensions import Self, final, override
from typing_extensions import Self, final, get_args, override

from . import hints

GroupType = TypeVar("GroupType", bound="MatrixLieGroup")
SEGroupType = TypeVar("SEGroupType", bound="SEBase")


class MatrixLieGroup(abc.ABC):
"""Interface definition for matrix Lie groups."""
Expand All @@ -29,27 +26,35 @@ class MatrixLieGroup(abc.ABC):
space_dim: ClassVar[int]
"""Dimension of coordinates that can be transformed."""

def __init__(self, parameters: onpt.NDArray[onp.floating], /):
def __init__(
# Notes:
# - For the constructor signature to be consistent with subclasses, `parameters`
# should be marked as positional-only. But this isn't possible in Python 3.7.
# - This method is implicitly overriden by the dataclass decorator and
# should _not_ be marked abstract.
self,
parameters: onp.ndarray,
):
"""Construct a group object from its underlying parameters."""
raise NotImplementedError()

# Shared implementations.

@overload
def __matmul__(self, other: hints.Array) -> onpt.NDArray[onp.floating]: ...
def __matmul__(self, other: Self) -> Self: ...

@overload
def __matmul__(self: GroupType, other: GroupType) -> GroupType: ...
def __matmul__(self, other: hints.Array) -> onpt.NDArray[onp.floating]: ...

def __matmul__(
self: GroupType, other: Union[GroupType, hints.Array]
) -> Union[GroupType, onpt.NDArray[onp.floating]]:
self, other: Union[Self, hints.Array]
) -> Union[Self, onpt.NDArray[onp.floating]]:
"""Overload for the `@` operator.
Switches between the group action (`.apply()`) and multiplication
(`.multiply()`) based on the type of `other`.
"""
if isinstance(other, onp.ndarray):
if isinstance(other, (onp.ndarray, onp.ndarray)):
return self.apply(target=other)
elif isinstance(other, MatrixLieGroup):
assert self.space_dim == other.space_dim
Expand All @@ -61,16 +66,19 @@ def __matmul__(

@classmethod
@abc.abstractmethod
def identity(cls: Type[GroupType]) -> GroupType:
def identity(cls, batch_axes: Tuple[int, ...] = ()) -> Self:
"""Returns identity element.
Args:
batch_axes: Any leading batch axes for the output transform.
Returns:
Identity element.
"""

@classmethod
@abc.abstractmethod
def from_matrix(cls: Type[GroupType], matrix: hints.Array) -> GroupType:
def from_matrix(cls, matrix: hints.Array) -> Self:
"""Get group member from matrix representation.
Args:
Expand Down Expand Up @@ -104,7 +112,7 @@ def apply(self, target: hints.Array) -> onpt.NDArray[onp.floating]:
"""

@abc.abstractmethod
def multiply(self: Self, other: Self) -> Self:
def multiply(self, other: Self) -> Self:
"""Composes this transformation with another.
Returns:
Expand All @@ -113,7 +121,7 @@ def multiply(self: Self, other: Self) -> Self:

@classmethod
@abc.abstractmethod
def exp(cls: Type[GroupType], tangent: hints.Array) -> GroupType:
def exp(cls, tangent: hints.Array) -> Self:
"""Computes `expm(wedge(tangent))`.
Args:
Expand Down Expand Up @@ -149,33 +157,42 @@ def adjoint(self) -> onpt.NDArray[onp.floating]:
"""

@abc.abstractmethod
def inverse(self: Self) -> Self:
def inverse(self) -> Self:
"""Computes the inverse of our transform.
Returns:
Output.
"""

@abc.abstractmethod
def normalize(self: Self) -> Self:
def normalize(self) -> Self:
"""Normalize/projects values and returns.
Returns:
GroupType: Normalized group member.
Normalized group member.
"""

# @classmethod
# @abc.abstractmethod
# def sample_uniform(cls: Type[GroupType], key: hints.KeyArray) -> GroupType:
# """Draw a uniform sample from the group. Translations (if applicable) are in the
# range [-1, 1].
# @classmethod
# @abc.abstractmethod
# def sample_uniform(cls, key: onp.ndarray, batch_axes: Tuple[int, ...] = ()) -> Self:
# """Draw a uniform sample from the group. Translations (if applicable) are in the
# range [-1, 1].
#
# Args:
# key: PRNG key, as returned by `jax.random.PRNGKey()`.
# Args:
# key: PRNG key, as returned by `jax.random.PRNGKey()`.
# batch_axes: Any leading batch axes for the output transforms. Each
# sampled transform will be different.
#
# Returns:
# Sampled group member.
# """
# Returns:
# Sampled group member.
# """

@final
def get_batch_axes(self) -> Tuple[int, ...]:
"""Return any leading batch axes in contained parameters. If an array of shape
`(100, 4)` is placed in the wxyz field of an SO3 object, for example, this will
return `(100,)`."""
return self.parameters().shape[:-1]


class SOBase(MatrixLieGroup):
Expand All @@ -197,10 +214,10 @@ class SEBase(Generic[ContainedSOType], MatrixLieGroup):
@classmethod
@abc.abstractmethod
def from_rotation_and_translation(
cls: Type[SEGroupType],
cls,
rotation: ContainedSOType,
translation: hints.Array,
) -> SEGroupType:
) -> Self:
"""Construct a rigid transform from a rotation and a translation.
Args:
Expand All @@ -213,18 +230,24 @@ def from_rotation_and_translation(

@final
@classmethod
def from_rotation(cls: Type[SEGroupType], rotation: ContainedSOType) -> SEGroupType:
def from_rotation(cls, rotation: ContainedSOType) -> Self:
return cls.from_rotation_and_translation(
rotation=rotation,
translation=onp.zeros(cls.space_dim, dtype=rotation.parameters().dtype),
translation=onp.zeros(
(*rotation.get_batch_axes(), cls.space_dim),
dtype=rotation.parameters().dtype,
),
)

@final
@classmethod
@abc.abstractmethod
def from_translation(
cls: Type[SEGroupType], translation: onpt.NDArray[onp.floating]
) -> SEGroupType:
"""Construct a transform from a translation term."""
def from_translation(cls, translation: hints.Array) -> Self:
# Extract rotation class from type parameter.
assert len(cls.__orig_bases__) == 1 # type: ignore
return cls.from_rotation_and_translation(
rotation=get_args(cls.__orig_bases__[0])[0].identity(), # type: ignore
translation=translation,
)

@abc.abstractmethod
def rotation(self) -> ContainedSOType:
Expand All @@ -241,17 +264,17 @@ def translation(self) -> onpt.NDArray[onp.floating]:
def apply(self, target: hints.Array) -> onpt.NDArray[onp.floating]:
return self.rotation() @ target + self.translation() # type: ignore

@override
@final
def multiply(self: SEGroupType, other: SEGroupType) -> SEGroupType:
@override
def multiply(self, other: Self) -> Self:
return type(self).from_rotation_and_translation(
rotation=self.rotation() @ other.rotation(),
translation=(self.rotation() @ other.translation()) + self.translation(),
)

@final
@override
def inverse(self: SEGroupType) -> SEGroupType:
def inverse(self) -> Self:
R_inv = self.rotation().inverse()
return type(self).from_rotation_and_translation(
rotation=R_inv,
Expand All @@ -260,7 +283,7 @@ def inverse(self: SEGroupType) -> SEGroupType:

@final
@override
def normalize(self: SEGroupType) -> SEGroupType:
def normalize(self) -> Self:
return type(self).from_rotation_and_translation(
rotation=self.rotation().normalize(),
translation=self.translation(),
Expand Down
Loading

0 comments on commit d01a108

Please sign in to comment.