Skip to content

Commit

Permalink
Add from_rotation() for SE groups, vmap example
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Dec 5, 2021
1 parent 3e4b100 commit 857e552
Show file tree
Hide file tree
Showing 19 changed files with 227 additions and 103 deletions.
20 changes: 11 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,16 @@ Where each group supports:
(<code>jaxlie.<strong>manifold</strong></code>).
- (Un)flattening as pytree nodes.
- Serialization using [flax](https://github.com/google/flax).
- Compatibility with standard JAX function transformations. (we've included some
[examples](./scripts/vmap_example.py) for use with `jax.vmap`)

We also implement various common utilities for things like uniform random
sampling (**`sample_uniform()`**) and converting from/to Euler angles (in the
`SO3` class).

---

##### Install (Python >=3.7)
#### Install (Python >=3.7)

```bash
# Python 3.6 releases also exist, but are no longer being updated.
Expand All @@ -72,19 +74,19 @@ pip install jaxlie

---

##### Example usage for SE(3)
#### Example usage for SE(3)

```python
import numpy as onp

from jaxlie import SE3

#############################
# (1) Constructing transforms
# (1) Constructing transforms.
#############################

# We can compute a w<-b transform by integrating over an se(3) screw, equivalent
# to `SE3.from_matrix(expm(wedge(twist)))`
# to `SE3.from_matrix(expm(wedge(twist)))`.
twist = onp.array([1.0, 0.0, 0.2, 0.0, 0.5, 0.0])
T_w_b = SE3.exp(twist)

Expand All @@ -95,8 +97,8 @@ print(T_w_b.rotation())
print(T_w_b.translation())

# Or the underlying parameters; this is a length-7 (quaternion, translation) array:
print(T_w_b.wxyz_xyz) # SE3-specific field
print(T_w_b.parameters()) # Helper shared by all groups
print(T_w_b.wxyz_xyz) # SE3-specific field.
print(T_w_b.parameters()) # Helper shared by all groups.

# There are also other helpers to generate transforms, eg from matrices:
T_w_b = SE3.from_matrix(T_w_b.as_matrix())
Expand All @@ -112,7 +114,7 @@ T_w_b = SE3(wxyz_xyz=T_w_b.wxyz_xyz)


#############################
# (2) Applying transforms
# (2) Applying transforms.
#############################

# Transform points with the `@` operator:
Expand All @@ -130,7 +132,7 @@ print(p_w)


#############################
# (3) Composing transforms
# (3) Composing transforms.
#############################

# Compose transforms with the `@` operator:
Expand All @@ -144,7 +146,7 @@ print(T_w_a)


#############################
# (4) Misc
# (4) Misc.
#############################

# Compute inverses:
Expand Down
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ Source code on `Github <https://github.com/brentyi/jaxlie>`_.
:caption: Example usage

se3_overview
vmap_usage


.. |build| image:: https://github.com/brentyi/jaxlie/workflows/build/badge.svg
Expand Down
7 changes: 7 additions & 0 deletions docs/source/vmap_usage.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
`jax.vmap` Usage
==========================================


.. literalinclude:: ../../scripts/vmap_example.py
:language: python

42 changes: 28 additions & 14 deletions jaxlie/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
class MatrixLieGroup(abc.ABC, EnforceOverrides):
"""Interface definition for matrix Lie groups."""

# Class properties
# > These will be set in `_utils.register_lie_group()`
# Class properties.
# > These will be set in `_utils.register_lie_group()`.

matrix_dim: ClassVar[int]
"""Dimension of square matrix output from `.as_matrix()`."""
Expand All @@ -30,14 +30,19 @@ class MatrixLieGroup(abc.ABC, EnforceOverrides):
space_dim: ClassVar[int]
"""Dimension of coordinates that can be transformed."""

def __init__(self, parameters: hints.Vector):
def __init__(
# Notes:
# - For the constructor signature to be consistent with subclasses, `parameters`
# should be marked as positional-only. But this isn't possible in Python 3.7.
# - This method is implicitly overriden by the dataclass decorator and
# should _not_ be marked abstract.
self,
parameters: hints.Vector,
):
"""Construct a group object from its underlying parameters."""

# Note that this method is implicitly overriden by the dataclass decorator and
# should _not_ be marked abstract.
raise NotImplementedError()

# Shared implementations
# Shared implementations.

@overload
def __matmul__(self: GroupType, other: GroupType) -> GroupType:
Expand All @@ -55,12 +60,13 @@ def __matmul__(self, other):
"""
if isinstance(other, (onp.ndarray, jnp.ndarray)):
return self.apply(target=other)
if isinstance(other, MatrixLieGroup):
elif isinstance(other, MatrixLieGroup):
assert self.space_dim == other.space_dim
return self.multiply(other=other)
else:
assert False, "Invalid argument"
assert False, f"Invalid argument type for `@` operator: {type(other)}"

# Factory
# Factory.

@classmethod
@abc.abstractmethod
Expand All @@ -83,7 +89,7 @@ def from_matrix(cls: Type[GroupType], matrix: hints.Matrix) -> GroupType:
Group member.
"""

# Accessors
# Accessors.

@abc.abstractmethod
def as_matrix(self) -> hints.MatrixJax:
Expand All @@ -93,7 +99,7 @@ def as_matrix(self) -> hints.MatrixJax:
def parameters(self) -> hints.Vector:
"""Get underlying representation."""

# Operations
# Operations.

@abc.abstractmethod
def apply(self: GroupType, target: hints.Vector) -> hints.VectorJax:
Expand Down Expand Up @@ -195,7 +201,7 @@ class SEBase(Generic[ContainedSOType], MatrixLieGroup):
translation vector.
"""

# SE-specific interface
# SE-specific interface.

@classmethod
@abc.abstractmethod
Expand All @@ -214,6 +220,14 @@ def from_rotation_and_translation(
Constructed transformation.
"""

@final
@classmethod
def from_rotation(cls: Type[SEGroupType], rotation: ContainedSOType) -> SEGroupType:
return cls.from_rotation_and_translation(
rotation=rotation,
translation=onp.zeros(cls.space_dim, dtype=rotation.parameters().dtype),
)

@abc.abstractmethod
def rotation(self) -> ContainedSOType:
"""Returns a transform's rotation term."""
Expand All @@ -222,7 +236,7 @@ def rotation(self) -> ContainedSOType:
def translation(self) -> hints.Vector:
"""Returns a transform's translation term."""

# Overrides
# Overrides.

@final
@overrides
Expand Down
24 changes: 12 additions & 12 deletions jaxlie/_se2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class SE2(_base.SEBase[SO2]):
vy, omega)`.
"""

# SE2-specific
# SE2-specific.

unit_complex_xy: hints.Vector
"""Internal parameters. `(cos, sin, x, y)`."""
Expand All @@ -44,7 +44,7 @@ def from_xy_theta(x: hints.Scalar, y: hints.Scalar, theta: hints.Scalar) -> "SE2
sin = jnp.sin(theta)
return SE2(unit_complex_xy=jnp.array([cos, sin, x, y]))

# SE-specific
# SE-specific.

@staticmethod
@overrides
Expand All @@ -65,7 +65,7 @@ def rotation(self) -> SO2:
def translation(self) -> hints.Vector:
return self.unit_complex_xy[..., 2:]

# Factory
# Factory.

@staticmethod
@overrides
Expand All @@ -76,13 +76,13 @@ def identity() -> "SE2":
@overrides
def from_matrix(matrix: hints.Matrix) -> "SE2":
assert matrix.shape == (3, 3)
# Currently assumes bottom row is [0, 0, 1]
# Currently assumes bottom row is [0, 0, 1].
return SE2.from_rotation_and_translation(
rotation=SO2.from_matrix(matrix[:2, :2]),
translation=matrix[:2, 2],
)

# Accessors
# Accessors.

@overrides
def parameters(self) -> hints.Vector:
Expand All @@ -99,7 +99,7 @@ def as_matrix(self) -> hints.MatrixJax:
]
)

# Operations
# Operations.

@staticmethod
@overrides
Expand All @@ -115,10 +115,10 @@ def exp(tangent: hints.TangentVector) -> "SE2":
use_taylor = jnp.abs(theta) < get_epsilon(tangent.dtype)

# Shim to avoid NaNs in jnp.where branches, which cause failures for
# reverse-mode AD
# reverse-mode AD.
safe_theta = jnp.where(
use_taylor,
1.0, # Any non-zero value should do here
1.0, # Any non-zero value should do here.
theta,
)

Expand Down Expand Up @@ -160,18 +160,18 @@ def log(self: "SE2") -> hints.TangentVectorJax:
use_taylor = jnp.abs(cos_minus_one) < get_epsilon(theta.dtype)

# Shim to avoid NaNs in jnp.where branches, which cause failures for
# reverse-mode AD
# reverse-mode AD.
safe_cos_minus_one = jnp.where(
use_taylor,
1.0, # Any non-zero value should do here
1.0, # Any non-zero value should do here.
cos_minus_one,
)

half_theta_over_tan_half_theta = jnp.where(
use_taylor,
# Taylor approximation
# Taylor approximation.
1.0 - (theta ** 2) / 12.0,
# Default
# Default.
-(half_theta * jnp.sin(theta)) / safe_cos_minus_one,
)

Expand Down
20 changes: 10 additions & 10 deletions jaxlie/_se3.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class SE3(_base.SEBase[SO3]):
is `(vx, vy, vz, omega_x, omega_y, omega_z)`.
"""

# SE3-specific
# SE3-specific.

wxyz_xyz: hints.Vector
"""Internal parameters. wxyz quaternion followed by xyz translation."""
Expand All @@ -47,7 +47,7 @@ def __repr__(self) -> str:
trans = jnp.round(self.wxyz_xyz[..., 4:], 5)
return f"{self.__class__.__name__}(wxyz={quat}, xyz={trans})"

# SE-specific
# SE-specific.

@staticmethod
@overrides
Expand All @@ -66,7 +66,7 @@ def rotation(self) -> SO3:
def translation(self) -> hints.Vector:
return self.wxyz_xyz[..., 4:]

# Factory
# Factory.

@staticmethod
@overrides
Expand All @@ -77,13 +77,13 @@ def identity() -> "SE3":
@overrides
def from_matrix(matrix: hints.Matrix) -> "SE3":
assert matrix.shape == (4, 4)
# Currently assumes bottom row is [0, 0, 0, 1]
# Currently assumes bottom row is [0, 0, 0, 1].
return SE3.from_rotation_and_translation(
rotation=SO3.from_matrix(matrix[:3, :3]),
translation=matrix[:3, 3],
)

# Accessors
# Accessors.

@overrides
def as_matrix(self) -> hints.MatrixJax:
Expand All @@ -99,7 +99,7 @@ def as_matrix(self) -> hints.MatrixJax:
def parameters(self) -> hints.Vector:
return self.wxyz_xyz

# Operations
# Operations.

@staticmethod
@overrides
Expand All @@ -116,10 +116,10 @@ def exp(tangent: hints.TangentVector) -> "SE3":
use_taylor = theta_squared < get_epsilon(theta_squared.dtype)

# Shim to avoid NaNs in jnp.where branches, which cause failures for
# reverse-mode AD
# reverse-mode AD.
theta_squared_safe = jnp.where(
use_taylor,
1.0, # Any non-zero value should do here
1.0, # Any non-zero value should do here.
theta_squared,
)
del theta_squared
Expand Down Expand Up @@ -154,10 +154,10 @@ def log(self: "SE3") -> hints.TangentVectorJax:
skew_omega = _skew(omega)

# Shim to avoid NaNs in jnp.where branches, which cause failures for
# reverse-mode AD
# reverse-mode AD.
theta_squared_safe = jnp.where(
use_taylor,
1.0, # Any non-zero value should do here
1.0, # Any non-zero value should do here.
theta_squared,
)
del theta_squared
Expand Down
Loading

0 comments on commit 857e552

Please sign in to comment.