Skip to content

Commit

Permalink
Rename annotations => hints
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Apr 6, 2021
1 parent 6985c73 commit e989f76
Show file tree
Hide file tree
Showing 11 changed files with 80 additions and 82 deletions.
4 changes: 2 additions & 2 deletions jaxlie/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from . import annotations, manifold, utils
from . import hints, manifold, utils
from ._base import MatrixLieGroup, SEBase, SOBase
from ._se2 import SE2
from ._se3 import SE3
from ._so2 import SO2
from ._so3 import SO3

__all__ = [
"annotations",
"hints",
"manifold",
"utils",
"MatrixLieGroup",
Expand Down
40 changes: 20 additions & 20 deletions jaxlie/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from jax import numpy as jnp
from overrides import EnforceOverrides, final, overrides

from . import annotations
from . import hints

GroupType = TypeVar("GroupType", bound="MatrixLieGroup")
SEGroupType = TypeVar("SEGroupType", bound="SEBase")
Expand Down Expand Up @@ -44,7 +44,7 @@ def __matmul__(self: GroupType, other: GroupType) -> GroupType:
...

@overload
def __matmul__(self: GroupType, other: annotations.Vector) -> annotations.Vector:
def __matmul__(self: GroupType, other: hints.Vector) -> hints.Vector:
...

def __matmul__(self, other):
Expand All @@ -68,16 +68,16 @@ def identity(cls: Type[GroupType]) -> GroupType:
"""Returns identity element.
Returns:
annotations.Matrix: Identity.
hints.Matrix: Identity.
"""

@classmethod
@abc.abstractmethod
def from_matrix(cls: Type[GroupType], matrix: annotations.Matrix) -> GroupType:
def from_matrix(cls: Type[GroupType], matrix: hints.Matrix) -> GroupType:
"""Get group member from matrix representation.
Args:
matrix (jnp.ndarray): annotations.Matrix representaiton.
matrix (jnp.ndarray): hints.Matrix representaiton.
Returns:
GroupType: Group member.
Expand All @@ -86,24 +86,24 @@ def from_matrix(cls: Type[GroupType], matrix: annotations.Matrix) -> GroupType:
# Accessors

@abc.abstractmethod
def as_matrix(self) -> annotations.Matrix:
def as_matrix(self) -> hints.Matrix:
"""Get transformation as a matrix. Homogeneous for SE groups."""

@abc.abstractmethod
def parameters(self) -> annotations.Vector:
def parameters(self) -> hints.Vector:
"""Get underlying representation."""

# Operations

@abc.abstractmethod
def apply(self: GroupType, target: annotations.Vector) -> annotations.Vector:
def apply(self: GroupType, target: hints.Vector) -> hints.Vector:
"""Applies the group action.
Args:
target (annotations.Vector): annotations.Vector to transform.
target (hints.Vector): hints.Vector to transform.
Returns:
annotations.Vector: Transformed vector.
hints.Vector: Transformed vector.
"""

@abc.abstractmethod
Expand All @@ -119,26 +119,26 @@ def multiply(self: GroupType, other: GroupType) -> GroupType:

@classmethod
@abc.abstractmethod
def exp(cls: Type[GroupType], tangent: annotations.TangentVector) -> GroupType:
def exp(cls: Type[GroupType], tangent: hints.TangentVector) -> GroupType:
"""Computes `expm(wedge(tangent))`.
Args:
tangent (annotations.TangentVector): Input.
tangent (hints.TangentVector): Input.
Returns:
MatrixLieGroup: Output.
"""

@abc.abstractmethod
def log(self: GroupType) -> annotations.TangentVector:
def log(self: GroupType) -> hints.TangentVector:
"""Computes `vee(logm(transformation matrix))`.
Returns:
annotations.TangentVector: Output. Shape should be `(tangent_dim,)`.
hints.TangentVector: Output. Shape should be `(tangent_dim,)`.
"""

@abc.abstractmethod
def adjoint(self: GroupType) -> annotations.Matrix:
def adjoint(self: GroupType) -> hints.Matrix:
"""Computes the adjoint, which transforms tangent vectors between tangent spaces.
More precisely, for a transform `GroupType`:
Expand All @@ -150,15 +150,15 @@ def adjoint(self: GroupType) -> annotations.Matrix:
between our spatial and body representations.
Returns:
annotations.Matrix: Output. Shape should be `(tangent_dim, tangent_dim)`.
hints.Matrix: Output. Shape should be `(tangent_dim, tangent_dim)`.
"""

@abc.abstractmethod
def inverse(self: GroupType) -> GroupType:
"""Computes the inverse of our transform.
Returns:
annotations.Matrix: Output.
hints.Matrix: Output.
"""

@abc.abstractmethod
Expand Down Expand Up @@ -195,7 +195,7 @@ class SEBase(MatrixLieGroup):
@abc.abstractmethod
def from_rotation_and_translation(
rotation: SOBase,
translation: annotations.Vector,
translation: hints.Vector,
) -> SEGroupType:
"""Construct a rigid transform from a rotation and a translation."""

Expand All @@ -204,14 +204,14 @@ def rotation(self) -> SOBase:
"""Returns a transform's rotation term."""

@abc.abstractmethod
def translation(self) -> annotations.Vector:
def translation(self) -> hints.Vector:
"""Returns a transform's translation term."""

# Overrides

@final
@overrides
def apply(self, target: annotations.Vector) -> annotations.Vector:
def apply(self, target: hints.Vector) -> hints.Vector:
return self.rotation() @ target + self.translation()

@final
Expand Down
24 changes: 11 additions & 13 deletions jaxlie/_se2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from jax import numpy as jnp
from overrides import overrides

from . import _base, annotations
from . import _base, hints
from ._so2 import SO2
from .utils import get_epsilon, register_lie_group

Expand All @@ -21,7 +21,7 @@ class SE2(_base.SEBase):

# SE2-specific

unit_complex_xy: annotations.Vector
unit_complex_xy: hints.Vector
"""Internal parameters. `(cos, sin, x, y)`."""

@overrides
Expand All @@ -31,9 +31,7 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}(unit_complex={unit_complex}, xy={xy})"

@staticmethod
def from_xy_theta(
x: annotations.Scalar, y: annotations.Scalar, theta: annotations.Scalar
) -> "SE2":
def from_xy_theta(x: hints.Scalar, y: hints.Scalar, theta: hints.Scalar) -> "SE2":
cos = jnp.cos(theta)
sin = jnp.sin(theta)
return SE2(unit_complex_xy=jnp.array([cos, sin, x, y]))
Expand All @@ -44,7 +42,7 @@ def from_xy_theta(
@overrides
def from_rotation_and_translation(
rotation: SO2,
translation: annotations.Vector,
translation: hints.Vector,
) -> "SE2":
assert translation.shape == (2,)
return SE2(
Expand All @@ -56,7 +54,7 @@ def rotation(self) -> SO2:
return SO2(unit_complex=self.unit_complex_xy[..., :2])

@overrides
def translation(self) -> annotations.Vector:
def translation(self) -> hints.Vector:
return self.unit_complex_xy[..., 2:]

# Factory
Expand All @@ -68,7 +66,7 @@ def identity() -> "SE2":

@staticmethod
@overrides
def from_matrix(matrix: annotations.Matrix) -> "SE2":
def from_matrix(matrix: hints.Matrix) -> "SE2":
assert matrix.shape == (3, 3)
# Currently assumes bottom row is [0, 0, 1]
return SE2.from_rotation_and_translation(
Expand All @@ -79,11 +77,11 @@ def from_matrix(matrix: annotations.Matrix) -> "SE2":
# Accessors

@overrides
def parameters(self) -> annotations.Vector:
def parameters(self) -> hints.Vector:
return self.unit_complex_xy

@overrides
def as_matrix(self) -> annotations.Matrix:
def as_matrix(self) -> hints.Matrix:
cos, sin, x, y = self.unit_complex_xy
return jnp.array(
[
Expand All @@ -97,7 +95,7 @@ def as_matrix(self) -> annotations.Matrix:

@staticmethod
@overrides
def exp(tangent: annotations.TangentVector) -> "SE2":
def exp(tangent: hints.TangentVector) -> "SE2":
# Reference:
# > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se2.hpp#L558
# Also see:
Expand Down Expand Up @@ -131,7 +129,7 @@ def exp(tangent: annotations.TangentVector) -> "SE2":
)

@overrides
def log(self: "SE2") -> annotations.TangentVector:
def log(self: "SE2") -> hints.TangentVector:
# Reference:
# > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se2.hpp#L160
# Also see:
Expand Down Expand Up @@ -162,7 +160,7 @@ def log(self: "SE2") -> annotations.TangentVector:
return tangent

@overrides
def adjoint(self: "SE2") -> annotations.Matrix:
def adjoint(self: "SE2") -> hints.Matrix:
cos, sin, x, y = self.unit_complex_xy
return jnp.array(
[
Expand Down
22 changes: 11 additions & 11 deletions jaxlie/_se3.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from jax import numpy as jnp
from overrides import overrides

from . import _base, annotations
from . import _base, hints
from ._so3 import SO3
from .utils import get_epsilon, register_lie_group


def _skew(omega: annotations.Vector) -> annotations.Matrix:
def _skew(omega: hints.Vector) -> hints.Matrix:
"""Returns the skew-symmetric form of a length-3 vector. """

wx, wy, wz = omega
Expand All @@ -34,7 +34,7 @@ class SE3(_base.SEBase):

# SE3-specific

wxyz_xyz: annotations.Vector
wxyz_xyz: hints.Vector
"""Internal parameters. wxyz quaternion followed by xyz translation."""

@overrides
Expand All @@ -49,7 +49,7 @@ def __repr__(self) -> str:
@overrides
def from_rotation_and_translation(
rotation: SO3,
translation: annotations.Vector,
translation: hints.Vector,
) -> "SE3":
assert translation.shape == (3,)
return SE3(wxyz_xyz=jnp.concatenate([rotation.wxyz, translation]))
Expand All @@ -59,7 +59,7 @@ def rotation(self) -> SO3:
return SO3(wxyz=self.wxyz_xyz[..., :4])

@overrides
def translation(self) -> annotations.Vector:
def translation(self) -> hints.Vector:
return self.wxyz_xyz[..., 4:]

# Factory
Expand All @@ -71,7 +71,7 @@ def identity() -> "SE3":

@staticmethod
@overrides
def from_matrix(matrix: annotations.Matrix) -> "SE3":
def from_matrix(matrix: hints.Matrix) -> "SE3":
assert matrix.shape == (4, 4)
# Currently assumes bottom row is [0, 0, 0, 1]
return SE3.from_rotation_and_translation(
Expand All @@ -82,7 +82,7 @@ def from_matrix(matrix: annotations.Matrix) -> "SE3":
# Accessors

@overrides
def as_matrix(self) -> annotations.Matrix:
def as_matrix(self) -> hints.Matrix:
return (
jnp.eye(4)
.at[:3, :3]
Expand All @@ -92,14 +92,14 @@ def as_matrix(self) -> annotations.Matrix:
)

@overrides
def parameters(self) -> annotations.Vector:
def parameters(self) -> hints.Vector:
return self.wxyz_xyz

# Operations

@staticmethod
@overrides
def exp(tangent: annotations.TangentVector) -> "SE3":
def exp(tangent: hints.TangentVector) -> "SE3":
# Reference:
# > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se3.hpp#L761

Expand Down Expand Up @@ -130,7 +130,7 @@ def exp(tangent: annotations.TangentVector) -> "SE3":
)

@overrides
def log(self: "SE3") -> annotations.TangentVector:
def log(self: "SE3") -> hints.TangentVector:
# Reference:
# > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se3.hpp#L223
omega = self.rotation().log()
Expand All @@ -153,7 +153,7 @@ def log(self: "SE3") -> annotations.TangentVector:
return jnp.concatenate([V_inv @ self.translation(), omega])

@overrides
def adjoint(self: "SE3") -> annotations.Matrix:
def adjoint(self: "SE3") -> hints.Matrix:
R = self.rotation().as_matrix()
return jnp.block(
[
Expand Down
Loading

0 comments on commit e989f76

Please sign in to comment.