Skip to content

Commit

Permalink
Various typing tweaks, runtime shape checks
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Jan 3, 2022
1 parent 857e552 commit c1ccd41
Show file tree
Hide file tree
Showing 13 changed files with 181 additions and 174 deletions.
34 changes: 21 additions & 13 deletions jaxlie/_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import abc
from typing import ClassVar, Generic, Type, TypeVar, overload
from typing import ClassVar, Generic, Tuple, Type, TypeVar, overload

import jax
import numpy as onp
Expand Down Expand Up @@ -37,7 +37,7 @@ def __init__(
# - This method is implicitly overriden by the dataclass decorator and
# should _not_ be marked abstract.
self,
parameters: hints.Vector,
parameters: jnp.ndarray,
):
"""Construct a group object from its underlying parameters."""
raise NotImplementedError()
Expand All @@ -49,7 +49,7 @@ def __matmul__(self: GroupType, other: GroupType) -> GroupType:
...

@overload
def __matmul__(self: GroupType, other: hints.Vector) -> hints.VectorJax:
def __matmul__(self: GroupType, other: hints.Array) -> jnp.ndarray:
...

def __matmul__(self, other):
Expand Down Expand Up @@ -79,7 +79,7 @@ def identity(cls: Type[GroupType]) -> GroupType:

@classmethod
@abc.abstractmethod
def from_matrix(cls: Type[GroupType], matrix: hints.Matrix) -> GroupType:
def from_matrix(cls: Type[GroupType], matrix: hints.Array) -> GroupType:
"""Get group member from matrix representation.
Args:
Expand All @@ -92,17 +92,17 @@ def from_matrix(cls: Type[GroupType], matrix: hints.Matrix) -> GroupType:
# Accessors.

@abc.abstractmethod
def as_matrix(self) -> hints.MatrixJax:
def as_matrix(self) -> jnp.ndarray:
"""Get transformation as a matrix. Homogeneous for SE groups."""

@abc.abstractmethod
def parameters(self) -> hints.Vector:
def parameters(self) -> jnp.ndarray:
"""Get underlying representation."""

# Operations.

@abc.abstractmethod
def apply(self: GroupType, target: hints.Vector) -> hints.VectorJax:
def apply(self: GroupType, target: hints.Array) -> jnp.ndarray:
"""Applies group action to a point.
Args:
Expand All @@ -122,7 +122,7 @@ def multiply(self: GroupType, other: GroupType) -> GroupType:

@classmethod
@abc.abstractmethod
def exp(cls: Type[GroupType], tangent: hints.TangentVector) -> GroupType:
def exp(cls: Type[GroupType], tangent: hints.Array) -> GroupType:
"""Computes `expm(wedge(tangent))`.
Args:
Expand All @@ -133,15 +133,15 @@ def exp(cls: Type[GroupType], tangent: hints.TangentVector) -> GroupType:
"""

@abc.abstractmethod
def log(self: GroupType) -> hints.TangentVectorJax:
def log(self: GroupType) -> jnp.ndarray:
"""Computes `vee(logm(transformation matrix))`.
Returns:
Output. Shape should be `(tangent_dim,)`.
"""

@abc.abstractmethod
def adjoint(self: GroupType) -> hints.MatrixJax:
def adjoint(self: GroupType) -> jnp.ndarray:
"""Computes the adjoint, which transforms tangent vectors between tangent
spaces.
Expand Down Expand Up @@ -186,6 +186,14 @@ def sample_uniform(cls: Type[GroupType], key: jax.random.KeyArray) -> GroupType:
Sampled group member.
"""

@abc.abstractmethod
def get_batch_axes(self) -> Tuple[int, ...]:
"""Return any leading batch axes in contained parameters. If an array of shape
`(100, 4)` is placed in the wxyz field of an SO3 object, for example, this will
return `(100,)`.
This should generally be implemented by `jdc.EnforcedAnnotationsMixin`."""


class SOBase(MatrixLieGroup):
"""Base class for special orthogonal groups."""
Expand All @@ -208,7 +216,7 @@ class SEBase(Generic[ContainedSOType], MatrixLieGroup):
def from_rotation_and_translation(
cls: Type[SEGroupType],
rotation: ContainedSOType,
translation: hints.Vector,
translation: hints.Array,
) -> SEGroupType:
"""Construct a rigid transform from a rotation and a translation.
Expand All @@ -233,14 +241,14 @@ def rotation(self) -> ContainedSOType:
"""Returns a transform's rotation term."""

@abc.abstractmethod
def translation(self) -> hints.Vector:
def translation(self) -> jnp.ndarray:
"""Returns a transform's translation term."""

# Overrides.

@final
@overrides
def apply(self, target: hints.Vector) -> hints.VectorJax:
def apply(self, target: hints.Array) -> jnp.ndarray:
return self.rotation() @ target + self.translation() # type: ignore

@final
Expand Down
32 changes: 18 additions & 14 deletions jaxlie/_se2.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import jax
import jax_dataclasses
import numpy as onp
import jax_dataclasses as jdc
from jax import numpy as jnp
from overrides import overrides
from typing_extensions import Annotated

from . import _base, hints
from ._so2 import SO2
Expand All @@ -15,8 +15,8 @@
tangent_dim=3,
space_dim=2,
)
@jax_dataclasses.pytree_dataclass
class SE2(_base.SEBase[SO2]):
@jdc.pytree_dataclass
class SE2(jdc.EnforcedAnnotationsMixin, _base.SEBase[SO2]):
"""Special Euclidean group for proper rigid transforms in 2D.
Internal parameterization is `(cos, sin, x, y)`. Tangent parameterization is `(vx,
Expand All @@ -25,7 +25,11 @@ class SE2(_base.SEBase[SO2]):

# SE2-specific.

unit_complex_xy: hints.Vector
unit_complex_xy: Annotated[
jnp.ndarray,
(4,), # Shape.
jnp.floating, # Data-type.
]
"""Internal parameters. `(cos, sin, x, y)`."""

@overrides
Expand All @@ -50,7 +54,7 @@ def from_xy_theta(x: hints.Scalar, y: hints.Scalar, theta: hints.Scalar) -> "SE2
@overrides
def from_rotation_and_translation(
rotation: SO2,
translation: hints.Vector,
translation: hints.Array,
) -> "SE2":
assert translation.shape == (2,)
return SE2(
Expand All @@ -62,19 +66,19 @@ def rotation(self) -> SO2:
return SO2(unit_complex=self.unit_complex_xy[..., :2])

@overrides
def translation(self) -> hints.Vector:
def translation(self) -> jnp.ndarray:
return self.unit_complex_xy[..., 2:]

# Factory.

@staticmethod
@overrides
def identity() -> "SE2":
return SE2(unit_complex_xy=onp.array([1.0, 0.0, 0.0, 0.0]))
return SE2(unit_complex_xy=jnp.array([1.0, 0.0, 0.0, 0.0]))

@staticmethod
@overrides
def from_matrix(matrix: hints.Matrix) -> "SE2":
def from_matrix(matrix: hints.Array) -> "SE2":
assert matrix.shape == (3, 3)
# Currently assumes bottom row is [0, 0, 1].
return SE2.from_rotation_and_translation(
Expand All @@ -85,11 +89,11 @@ def from_matrix(matrix: hints.Matrix) -> "SE2":
# Accessors.

@overrides
def parameters(self) -> hints.Vector:
def parameters(self) -> jnp.ndarray:
return self.unit_complex_xy

@overrides
def as_matrix(self) -> hints.MatrixJax:
def as_matrix(self) -> jnp.ndarray:
cos, sin, x, y = self.unit_complex_xy
return jnp.array(
[
Expand All @@ -103,7 +107,7 @@ def as_matrix(self) -> hints.MatrixJax:

@staticmethod
@overrides
def exp(tangent: hints.TangentVector) -> "SE2":
def exp(tangent: hints.Array) -> "SE2":
# Reference:
# > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se2.hpp#L558
# Also see:
Expand Down Expand Up @@ -146,7 +150,7 @@ def exp(tangent: hints.TangentVector) -> "SE2":
)

@overrides
def log(self: "SE2") -> hints.TangentVectorJax:
def log(self: "SE2") -> jnp.ndarray:
# Reference:
# > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se2.hpp#L160
# Also see:
Expand Down Expand Up @@ -186,7 +190,7 @@ def log(self: "SE2") -> hints.TangentVectorJax:
return tangent

@overrides
def adjoint(self: "SE2") -> hints.MatrixJax:
def adjoint(self: "SE2") -> jnp.ndarray:
cos, sin, x, y = self.unit_complex_xy
return jnp.array(
[
Expand Down
34 changes: 19 additions & 15 deletions jaxlie/_se3.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import jax
import jax_dataclasses
import numpy as onp
import jax_dataclasses as jdc
from jax import numpy as jnp
from overrides import overrides
from typing_extensions import Annotated

from . import _base, hints
from ._so3 import SO3
from .utils import get_epsilon, register_lie_group


def _skew(omega: hints.Vector) -> hints.MatrixJax:
def _skew(omega: hints.Array) -> jnp.ndarray:
"""Returns the skew-symmetric form of a length-3 vector."""

wx, wy, wz = omega
Expand All @@ -28,8 +28,8 @@ def _skew(omega: hints.Vector) -> hints.MatrixJax:
tangent_dim=6,
space_dim=3,
)
@jax_dataclasses.pytree_dataclass
class SE3(_base.SEBase[SO3]):
@jdc.pytree_dataclass
class SE3(jdc.EnforcedAnnotationsMixin, _base.SEBase[SO3]):
"""Special Euclidean group for proper rigid transforms in 3D.
Internal parameterization is `(qw, qx, qy, qz, x, y, z)`. Tangent parameterization
Expand All @@ -38,7 +38,11 @@ class SE3(_base.SEBase[SO3]):

# SE3-specific.

wxyz_xyz: hints.Vector
wxyz_xyz: Annotated[
jnp.ndarray,
(7,), # Shape.
jnp.floating, # Data-type.
]
"""Internal parameters. wxyz quaternion followed by xyz translation."""

@overrides
Expand All @@ -53,7 +57,7 @@ def __repr__(self) -> str:
@overrides
def from_rotation_and_translation(
rotation: SO3,
translation: hints.Vector,
translation: hints.Array,
) -> "SE3":
assert translation.shape == (3,)
return SE3(wxyz_xyz=jnp.concatenate([rotation.wxyz, translation]))
Expand All @@ -63,19 +67,19 @@ def rotation(self) -> SO3:
return SO3(wxyz=self.wxyz_xyz[..., :4])

@overrides
def translation(self) -> hints.Vector:
def translation(self) -> jnp.ndarray:
return self.wxyz_xyz[..., 4:]

# Factory.

@staticmethod
@overrides
def identity() -> "SE3":
return SE3(wxyz_xyz=onp.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]))
return SE3(wxyz_xyz=jnp.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]))

@staticmethod
@overrides
def from_matrix(matrix: hints.Matrix) -> "SE3":
def from_matrix(matrix: hints.Array) -> "SE3":
assert matrix.shape == (4, 4)
# Currently assumes bottom row is [0, 0, 0, 1].
return SE3.from_rotation_and_translation(
Expand All @@ -86,7 +90,7 @@ def from_matrix(matrix: hints.Matrix) -> "SE3":
# Accessors.

@overrides
def as_matrix(self) -> hints.MatrixJax:
def as_matrix(self) -> jnp.ndarray:
return (
jnp.eye(4)
.at[:3, :3]
Expand All @@ -96,14 +100,14 @@ def as_matrix(self) -> hints.MatrixJax:
)

@overrides
def parameters(self) -> hints.Vector:
def parameters(self) -> jnp.ndarray:
return self.wxyz_xyz

# Operations.

@staticmethod
@overrides
def exp(tangent: hints.TangentVector) -> "SE3":
def exp(tangent: hints.Array) -> "SE3":
# Reference:
# > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se3.hpp#L761

Expand Down Expand Up @@ -144,7 +148,7 @@ def exp(tangent: hints.TangentVector) -> "SE3":
)

@overrides
def log(self: "SE3") -> hints.TangentVectorJax:
def log(self: "SE3") -> jnp.ndarray:
# Reference:
# > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se3.hpp#L223
omega = self.rotation().log()
Expand Down Expand Up @@ -183,7 +187,7 @@ def log(self: "SE3") -> hints.TangentVectorJax:
return jnp.concatenate([V_inv @ self.translation(), omega])

@overrides
def adjoint(self: "SE3") -> hints.MatrixJax:
def adjoint(self: "SE3") -> jnp.ndarray:
R = self.rotation().as_matrix()
return jnp.block(
[
Expand Down
Loading

0 comments on commit c1ccd41

Please sign in to comment.