Skip to content

Commit

Permalink
Housekeeping!
Browse files Browse the repository at this point in the history
- Minor numerical stability improvmeents
- Typing improvements (relies on JAX version bump)
- Dependency cleanup
- Test infra tweaks
- Docs, README nits
  • Loading branch information
brentyi committed Nov 30, 2021
1 parent 4dbe16f commit 3e4b100
Show file tree
Hide file tree
Showing 15 changed files with 103 additions and 101 deletions.
36 changes: 20 additions & 16 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
**[ [API reference](https://brentyi.github.io/jaxlie) ]** **[
[PyPI](https://pypi.org/project/jaxlie/) ]**

`jaxlie` is a Lie theory library for rigid body transformations and optimization
in JAX.
`jaxlie` is a library containing implementations of Lie groups commonly used for
rigid body transformations, targeted at computer vision & robotics
applications written in JAX. Heavily inspired by the C++ library
[Sophus](https://github.com/strasdat/Sophus).

Implements Lie groups as high-level (data)classes:
We implement Lie groups as high-level (data)classes:

<table>
<thead>
Expand Down Expand Up @@ -45,24 +47,26 @@ Implements Lie groups as high-level (data)classes:
</tbody>
</table>

Each group supports:
Where each group supports:

- Forward- and reverse-mode AD-friendly **`exp()`**, **`log()`**,
**`adjoint()`**, **`apply`**, **`multiply()`**, **`inverse()`**, and
**`identity()`** operations
- Helpers + analytical Jacobians for on-manifold optimization
(<code>jaxlie.<strong>manifold</strong></code>)
- (Un)flattening as pytree nodes
- Serialization using [flax](https://github.com/google/flax)

Heavily inspired by (and some operations ported from) the C++ library
[Sophus](https://github.com/strasdat/Sophus).
**`adjoint()`**, **`apply()`**, **`multiply()`**, **`inverse()`**,
**`identity()`**, **`from_matrix()`**, and **`as_matrix()`** operations.
- Helpers + analytical Jacobians for manifold optimization
(<code>jaxlie.<strong>manifold</strong></code>).
- (Un)flattening as pytree nodes.
- Serialization using [flax](https://github.com/google/flax).

We also implement various common utilities for things like uniform random
sampling (**`sample_uniform()`**) and converting from/to Euler angles (in the
`SO3` class).

---

##### Install (Python >=3.6)
##### Install (Python >=3.7)

```bash
# Python 3.6 releases also exist, but are no longer being updated.
pip install jaxlie
```

Expand Down Expand Up @@ -153,6 +157,6 @@ adjoint_T_w_b = T_w_b.adjoint()
print(adjoint_T_w_b)

# Recover our twist, equivalent to `vee(logm(T_w_b.as_matrix()))`:
twist = T_w_b.log()
print(twist)
twist_recovered = T_w_b.log()
print(twist_recovered)
```
17 changes: 10 additions & 7 deletions jaxlie/_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import abc
from typing import ClassVar, Generic, Type, TypeVar, overload

import jax_dataclasses
import jax
import numpy as onp
from jax import numpy as jnp
from overrides import EnforceOverrides, final, overrides
Expand Down Expand Up @@ -136,15 +136,16 @@ def log(self: GroupType) -> hints.TangentVectorJax:

@abc.abstractmethod
def adjoint(self: GroupType) -> hints.MatrixJax:
"""Computes the adjoint, which transforms tangent vectors between tangent spaces.
"""Computes the adjoint, which transforms tangent vectors between tangent
spaces.
More precisely, for a transform `GroupType`:
```
GroupType @ exp(omega) = exp(Adj_T @ omega) @ GroupType
```
In robotics, typically used for converting twists, wrenches, and Jacobians
between our spatial and body representations.
In robotics, typically used for transforming twists, wrenches, and Jacobians
across different reference frames.
Returns:
Output. Shape should be `(tangent_dim, tangent_dim)`.
Expand All @@ -168,8 +169,9 @@ def normalize(self: GroupType) -> GroupType:

@classmethod
@abc.abstractmethod
def sample_uniform(cls: Type[GroupType], key: jnp.ndarray) -> GroupType:
"""Draw a uniform sample from the group. Translations are in the range [-1, 1].
def sample_uniform(cls: Type[GroupType], key: jax.random.KeyArray) -> GroupType:
"""Draw a uniform sample from the group. Translations (if applicable) are in the
range [-1, 1].
Args:
key: PRNG key, as returned by `jax.random.PRNGKey()`.
Expand All @@ -190,7 +192,8 @@ class SEBase(Generic[ContainedSOType], MatrixLieGroup):
"""Base class for special Euclidean groups.
Each SE(N) group member contains an SO(N) rotation, as well as an N-dimensional
translation vector."""
translation vector.
"""

# SE-specific interface

Expand Down
12 changes: 7 additions & 5 deletions jaxlie/_se2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
class SE2(_base.SEBase[SO2]):
"""Special Euclidean group for proper rigid transforms in 2D.
Internal parameterization is `(cos, sin, x, y)`.
Tangent parameterization is `(vx, vy, omega)`.
Internal parameterization is `(cos, sin, x, y)`. Tangent parameterization is `(vx,
vy, omega)`.
"""

# SE2-specific
Expand All @@ -36,8 +36,10 @@ def __repr__(self) -> str:

@staticmethod
def from_xy_theta(x: hints.Scalar, y: hints.Scalar, theta: hints.Scalar) -> "SE2":
"""Construct a transformation from standard 2D pose parameters. Note that this
is not the same as integrating over a length-3 twist."""
"""Construct a transformation from standard 2D pose parameters.
Note that this is not the same as integrating over a length-3 twist.
"""
cos = jnp.cos(theta)
sin = jnp.sin(theta)
return SE2(unit_complex_xy=jnp.array([cos, sin, x, y]))
Expand Down Expand Up @@ -196,7 +198,7 @@ def adjoint(self: "SE2") -> hints.MatrixJax:

@staticmethod
@overrides
def sample_uniform(key: jnp.ndarray) -> "SE2":
def sample_uniform(key: jax.random.KeyArray) -> "SE2":
key0, key1 = jax.random.split(key)
return SE2.from_rotation_and_translation(
rotation=SO2.sample_uniform(key0),
Expand Down
6 changes: 3 additions & 3 deletions jaxlie/_se3.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ def _skew(omega: hints.Vector) -> hints.MatrixJax:
class SE3(_base.SEBase[SO3]):
"""Special Euclidean group for proper rigid transforms in 3D.
Internal parameterization is `(qw, qx, qy, qz, x, y, z)`.
Tangent parameterization is `(vx, vy, vz, omega_x, omega_y, omega_z)`.
Internal parameterization is `(qw, qx, qy, qz, x, y, z)`. Tangent parameterization
is `(vx, vy, vz, omega_x, omega_y, omega_z)`.
"""

# SE3-specific
Expand Down Expand Up @@ -194,7 +194,7 @@ def adjoint(self: "SE3") -> hints.MatrixJax:

@staticmethod
@overrides
def sample_uniform(key: jnp.ndarray) -> "SE3":
def sample_uniform(key: jax.random.KeyArray) -> "SE3":
key0, key1 = jax.random.split(key)
return SE3.from_rotation_and_translation(
rotation=SO3.sample_uniform(key0),
Expand Down
5 changes: 2 additions & 3 deletions jaxlie/_so2.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
class SO2(_base.SOBase):
"""Special orthogonal group for 2D rotations.
Internal parameterization is `(cos, sin)`.
Tangent parameterization is `(omega,)`.
Internal parameterization is `(cos, sin)`. Tangent parameterization is `(omega,)`.
"""

# SO2-specific
Expand Down Expand Up @@ -112,7 +111,7 @@ def normalize(self: "SO2") -> "SO2":

@staticmethod
@overrides
def sample_uniform(key: jnp.ndarray) -> "SO2":
def sample_uniform(key: jax.random.KeyArray) -> "SO2":
return SO2.from_radians(
jax.random.uniform(key=key, minval=0.0, maxval=2.0 * jnp.pi)
)
28 changes: 14 additions & 14 deletions jaxlie/_so3.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
class SO3(_base.SOBase):
"""Special orthogonal group for 3D rotations.
Internal parameterization is `(qw, qx, qy, qz)`.
Tangent parameterization is `(omega_x, omega_y, omega_z)`.
Internal parameterization is `(qw, qx, qy, qz)`. Tangent parameterization is
`(omega_x, omega_y, omega_z)`.
"""

# SO3-specific
Expand Down Expand Up @@ -74,8 +74,8 @@ def from_rpy_radians(
pitch: hints.Scalar,
yaw: hints.Scalar,
) -> "SO3":
"""Generates a transform from a set of Euler angles.
Uses the ZYX mobile robot convention.
"""Generates a transform from a set of Euler angles. Uses the ZYX mobile robot
convention.
Args:
roll: X rotation, in radians. Applied first.
Expand Down Expand Up @@ -112,8 +112,7 @@ def as_quaternion_xyzw(self) -> hints.VectorJax:
return jnp.roll(self.wxyz, shift=-1)

def as_rpy_radians(self) -> hints.RollPitchYaw:
"""Computes roll, pitch, and yaw angles.
Uses the ZYX mobile robot convention.
"""Computes roll, pitch, and yaw angles. Uses the ZYX mobile robot convention.
Returns:
Named tuple containing Euler angles in radians.
Expand All @@ -125,8 +124,7 @@ def as_rpy_radians(self) -> hints.RollPitchYaw:
)

def compute_roll_radians(self) -> hints.ScalarJax:
"""Compute roll angle.
Uses the ZYX mobile robot convention.
"""Compute roll angle. Uses the ZYX mobile robot convention.
Returns:
Euler angle in radians.
Expand All @@ -136,8 +134,7 @@ def compute_roll_radians(self) -> hints.ScalarJax:
return jnp.arctan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 ** 2 + q2 ** 2))

def compute_pitch_radians(self) -> hints.ScalarJax:
"""Compute pitch angle.
Uses the ZYX mobile robot convention.
"""Compute pitch angle. Uses the ZYX mobile robot convention.
Returns:
Euler angle in radians.
Expand All @@ -147,8 +144,7 @@ def compute_pitch_radians(self) -> hints.ScalarJax:
return jnp.arcsin(2 * (q0 * q2 - q3 * q1))

def compute_yaw_radians(self) -> hints.ScalarJax:
"""Compute yaw angle.
Uses the ZYX mobile robot convention.
"""Compute yaw angle. Uses the ZYX mobile robot convention.
Returns:
Euler angle in radians.
Expand Down Expand Up @@ -342,13 +338,17 @@ def log(self: "SO3") -> hints.TangentVectorJax:
)
)

atan_n_over_w = jnp.arctan2(
jnp.where(w < 0, -norm_safe, norm_safe),
jnp.abs(w),
)
atan_factor = jnp.where(
use_taylor,
2.0 / w - 2.0 / 3.0 * norm_sq / (w ** 3),
jnp.where(
jnp.abs(w) < get_epsilon(w.dtype),
jnp.where(w > 0, 1.0, -1.0) * jnp.pi / norm_safe,
2.0 * jnp.arctan(norm_safe / w) / norm_safe,
2.0 * atan_n_over_w / norm_safe,
),
)

Expand All @@ -369,7 +369,7 @@ def normalize(self: "SO3") -> "SO3":

@staticmethod
@overrides
def sample_uniform(key: jnp.ndarray) -> "SO3":
def sample_uniform(key: jax.random.KeyArray) -> "SO3":
# Uniformly sample over S^4
# > Reference: http://planning.cs.uiuc.edu/node198.html
u1, u2, u3 = jax.random.uniform(
Expand Down
15 changes: 5 additions & 10 deletions jaxlie/hints/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,26 @@
# Type aliases for JAX/Numpy arrays; primarily for function inputs

Array = Union[onp.ndarray, jnp.ndarray]
"""Type alias for `Union[jnp.ndarray, onp.ndarray]`.
"""
"""Type alias for `Union[jnp.ndarray, onp.ndarray]`."""

Scalar = Union[float, Array]
"""Type alias for `Union[float, Array]`.
"""
"""Type alias for `Union[float, Array]`."""

Matrix = Array
"""Type alias for `Array`. Should not be instantiated.
Refers to a square matrix, typically with shape `(Group.matrix_dim, Group.matrix_dim)`.
For adjoints, shape should be `(Group.tangent_dim, Group.tangent_dim)`.
"""
For adjoints, shape should be `(Group.tangent_dim, Group.tangent_dim)`."""

Vector = Array
"""Type alias for `Array`. Should not be instantiated.
Refers to a general 1D array.
"""
Refers to a general 1D array."""

TangentVector = Array
"""Type alias for `Array`. Should not be instantiated.
Refers to a 1D array with shape `(Group.tangent_dim,)`.
"""
Refers to a 1D array with shape `(Group.tangent_dim,)`."""

# Type aliases for JAX arrays; primarily for function outputs

Expand Down
1 change: 1 addition & 0 deletions jaxlie/utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def register_lie_group(
space_dim: int,
) -> Callable[[Type[T]], Type[T]]:
"""Decorator for registering Lie group dataclasses.
- Sets static dimensionality attributes
- Marks all methods for JIT compilation
"""
Expand Down
4 changes: 2 additions & 2 deletions scripts/se3_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,5 +80,5 @@
print(adjoint_T_w_b)

# Recover our twist, equivalent to `vee(logm(T_w_b.as_matrix()))`:
twist = T_w_b.log()
print(twist)
recovered_twist = T_w_b.log()
print(recovered_twist)
18 changes: 7 additions & 11 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name="jaxlie",
version="1.2.6",
version="1.2.7",
description="Matrix Lie groups in Jax",
long_description=long_description,
long_description_content_type="text/markdown",
Expand All @@ -17,23 +17,19 @@
package_data={"jaxlie": ["py.typed"]},
python_requires=">=3.7",
install_requires=[
"flax",
"jax",
"jaxlib",
"numpy",
"jax>=0.2.20",
"jaxlib>=0.1.71",
"jax_dataclasses>=1.0.0",
# `overrides` should not be updated until the following issues are resolved:
# > https://github.com/mkorpela/overrides/issues/65
# > https://github.com/mkorpela/overrides/issues/63
# > https://github.com/mkorpela/overrides/issues/61
"numpy",
"overrides!=4",
],
extras_require={
"testing": [
"pytest",
"pytest-cov",
"flax",
"hypothesis",
"hypothesis[numpy]",
"pytest",
"pytest-cov",
]
},
classifiers=[
Expand Down
Loading

0 comments on commit 3e4b100

Please sign in to comment.