Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/jlog #22

Draft
wants to merge 9 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions jaxlie/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,17 @@ def normalize(self) -> Self:
Normalized group member.
"""

@abc.abstractmethod
def jlog(self) -> jax.Array:
"""
Computes the Jacobian of the logarithm of the group element.

This is equivalent to the inverse of the right Jacobian.

Returns:
The Jacobian of the logarithm, having the dimensions `(tangent_dim, tangent_dim,)`.
"""

@classmethod
@abc.abstractmethod
def sample_uniform(cls, key: jax.Array, batch_axes: Tuple[int, ...] = ()) -> Self:
Expand Down Expand Up @@ -244,7 +255,8 @@ def from_translation(cls, translation: hints.Array) -> Self:
# Extract rotation class from type parameter.
assert len(cls.__orig_bases__) == 1 # type: ignore
return cls.from_rotation_and_translation(
rotation=get_args(cls.__orig_bases__[0])[0].identity(), # type: ignore
rotation=get_args(cls.__orig_bases__[0])[
0].identity(), # type: ignore
translation=translation,
)

Expand All @@ -268,7 +280,8 @@ def apply(self, target: hints.Array) -> jax.Array:
def multiply(self, other: Self) -> Self:
return type(self).from_rotation_and_translation(
rotation=self.rotation() @ other.rotation(),
translation=(self.rotation() @ other.translation()) + self.translation(),
translation=(self.rotation() @ other.translation()) +
self.translation(),
)

@final
Expand Down
40 changes: 40 additions & 0 deletions jaxlie/_se2.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,46 @@ def adjoint(self: "SE2") -> jax.Array:
axis=-1,
).reshape((*self.get_batch_axes(), 3, 3))

@override
def jlog(self) -> jax.Array:
# Reference:
# This is inverse of matrix (163) from Micro-Lie theory:
# > https://arxiv.org/pdf/1812.01537

log = self.log()
rho1, rho2, theta = log[..., 0], log[..., 1], log[..., 2]

# Safe division function
def safe_div(x, y): return jnp.where(jnp.abs(y) < 1e-10, 0., x / y)
simeon-ned marked this conversation as resolved.
Show resolved Hide resolved

sin_theta = jnp.sin(theta)
cos_theta = jnp.cos(theta)

# Common terms
denom = 2 - 2 * cos_theta
theta_sin_theta_term = safe_div(theta * sin_theta, denom)
one_minus_cos_term = cos_theta - 1

# Matrix elements
a11 = theta_sin_theta_term
a12 = -theta / 2
a21 = theta / 2
a22 = a11 # Same as a11

a13_num = (rho1 * theta * sin_theta / 2 + rho1 * cos_theta - rho1 +
rho2 * theta * cos_theta / 2 - rho2 * theta / 2)
a13 = safe_div(a13_num, theta * one_minus_cos_term)

a23_num = (-rho1 * theta * cos_theta / 2 + rho1 * theta / 2 +
rho2 * theta * sin_theta / 2 + rho2 * cos_theta - rho2)
a23 = safe_div(a23_num, theta * one_minus_cos_term)
jlog = jnp.array([
[a11, a12, a13],
[a21, a22, a23],
[0., 0., 1.]
])
return jlog

@classmethod
@override
def sample_uniform(
Expand Down
71 changes: 55 additions & 16 deletions jaxlie/_se3.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,10 @@
from typing_extensions import override

from . import _base, hints
from ._so3 import SO3
from ._so3 import SO3, _skew
from .utils import broadcast_leading_axes, get_epsilon, register_lie_group


def _skew(omega: hints.Array) -> jax.Array:
"""Returns the skew-symmetric form of a length-3 vector."""

wx, wy, wz = jnp.moveaxis(omega, -1, 0)
zeros = jnp.zeros_like(wx)
return jnp.stack(
[zeros, -wz, wy, wz, zeros, -wx, -wy, wx, zeros],
axis=-1,
).reshape((*omega.shape[:-1], 3, 3))


@register_lie_group(
matrix_dim=4,
parameters_dim=7,
Expand Down Expand Up @@ -77,7 +66,8 @@ def translation(self) -> jax.Array:
def identity(cls, batch_axes: jdc.Static[Tuple[int, ...]] = ()) -> SE3:
return SE3(
wxyz_xyz=jnp.broadcast_to(
jnp.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), (*batch_axes, 7)
jnp.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
), (*batch_axes, 7)
)
)

Expand Down Expand Up @@ -131,7 +121,8 @@ def exp(cls, tangent: hints.Array) -> SE3:
jax.Array,
jnp.where(
use_taylor,
jnp.ones_like(theta_squared), # Any non-zero value should do here.
# Any non-zero value should do here.
jnp.ones_like(theta_squared),
theta_squared,
),
)
Expand All @@ -144,7 +135,8 @@ def exp(cls, tangent: hints.Array) -> SE3:
rotation.as_matrix(),
(
jnp.eye(3)
+ ((1.0 - jnp.cos(theta_safe)) / (theta_squared_safe))[..., None, None]
+ ((1.0 - jnp.cos(theta_safe)) /
(theta_squared_safe))[..., None, None]
* skew_omega
+ (
(theta_safe - jnp.sin(theta_safe))
Expand Down Expand Up @@ -210,7 +202,8 @@ def adjoint(self) -> jax.Array:
return jnp.concatenate(
[
jnp.concatenate(
[R, jnp.einsum("...ij,...jk->...ik", _skew(self.translation()), R)],
[R, jnp.einsum("...ij,...jk->...ik",
_skew(self.translation()), R)],
axis=-1,
),
jnp.concatenate(
Expand All @@ -220,6 +213,52 @@ def adjoint(self) -> jax.Array:
axis=-2,
)

@override
def jlog(self) -> jax.Array:
# Reference:
# Equations (179a, 179b, 180) from Micro-Lie theory:
# > https://arxiv.org/pdf/1812.01537
# and the Jlog6 implementation in Pinocchio:
# > https://gepettoweb.laas.fr/doc/stack-of-tasks/pinocchio/master/doxygen-html/namespacepinocchio.html#a82e7cb47ae721d4161bbb143590096c5

rotation = self.rotation()
translation = self.translation()

jlog_so3 = rotation.jlog()

w = rotation.log()
theta = jnp.linalg.norm(w)

t2 = theta * theta
tinv = 1 / theta
t2inv = tinv * tinv
st, ct = jnp.sin(theta), jnp.cos(theta)
inv_2_2ct = 1 / (2 * (1 - ct))

beta = jnp.where(theta < 1e-6,
1 / 12 + t2 / 720,
t2inv - st * tinv * inv_2_2ct)

beta_dot_over_theta = jnp.where(theta < 1e-6,
1 / 360,
-2 * t2inv * t2inv + (1 + st * tinv) * t2inv * inv_2_2ct)
simeon-ned marked this conversation as resolved.
Show resolved Hide resolved

wTp = w @ translation
v3_tmp = (beta_dot_over_theta * wTp) * w - (theta**2 *
beta_dot_over_theta + 2 * beta) * translation
C = jnp.outer(v3_tmp, w) + beta * jnp.outer(w,
translation) + wTp * beta * jnp.eye(3)
C = C + 0.5 * _skew(translation)

B = C @ jlog_so3

jlog = jnp.zeros((6, 6))
jlog = jlog.at[:3, :3].set(jlog_so3)
jlog = jlog.at[3:, 3:].set(jlog_so3)
jlog = jlog.at[:3, 3:].set(B)

return jlog

@classmethod
@override
def sample_uniform(
Expand Down
10 changes: 10 additions & 0 deletions jaxlie/_so2.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,16 @@ def normalize(self) -> SO2:
/ jnp.linalg.norm(self.unit_complex, axis=-1, keepdims=True)
)

@override
def jlog(self) -> jax.Array:
# Reference:
# For SO2 the jlog and right jacobians are trivially 1,
# equation (126) from Micro-Lie theory:
# > https://arxiv.org/pdf/1812.01537

return jnp.array([1])
simeon-ned marked this conversation as resolved.
Show resolved Hide resolved


@classmethod
@override
def sample_uniform(
Expand Down
34 changes: 32 additions & 2 deletions jaxlie/_so3.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,17 @@
from .utils import broadcast_leading_axes, get_epsilon, register_lie_group


def _skew(omega: hints.Array) -> jax.Array:
"""Returns the skew-symmetric form of a length-3 vector."""

wx, wy, wz = jnp.moveaxis(omega, -1, 0)
zeros = jnp.zeros_like(wx)
return jnp.stack(
[zeros, -wz, wy, wz, zeros, -wx, -wy, wx, zeros],
axis=-1,
).reshape((*omega.shape[:-1], 3, 3))


@register_lie_group(
matrix_dim=3,
parameters_dim=4,
Expand Down Expand Up @@ -164,7 +175,8 @@ def compute_yaw_radians(self) -> jax.Array:
@override
def identity(cls, batch_axes: jdc.Static[Tuple[int, ...]] = ()) -> SO3:
return SO3(
wxyz=jnp.broadcast_to(jnp.array([1.0, 0.0, 0.0, 0.0]), (*batch_axes, 4))
wxyz=jnp.broadcast_to(
jnp.array([1.0, 0.0, 0.0, 0.0]), (*batch_axes, 4))
)

@classmethod
Expand Down Expand Up @@ -342,7 +354,8 @@ def exp(cls, tangent: hints.Array) -> SO3:
safe_theta = jnp.sqrt(
jnp.where(
use_taylor,
jnp.ones_like(theta_squared), # Any constant value should do here.
# Any constant value should do here.
jnp.ones_like(theta_squared),
theta_squared,
)
)
Expand Down Expand Up @@ -418,6 +431,23 @@ def inverse(self) -> SO3:
def normalize(self) -> SO3:
return SO3(wxyz=self.wxyz / jnp.linalg.norm(self.wxyz, axis=-1, keepdims=True))

@override
def jlog(self) -> jax.Array:
# Reference:
# Equation (144) from Micro-Lie theory:
# > https://arxiv.org/pdf/1812.01537

log = self.log()
theta = jnp.linalg.norm(log)
st, ct = jnp.sin(theta), jnp.cos(theta)
factor1 = (theta * st) / (2 * (1 - ct))
simeon-ned marked this conversation as resolved.
Show resolved Hide resolved
factor2 = 1 / (theta ** 2) - st / (2 * theta * (1 - ct))

jlog = factor1 * jnp.eye(3)
jlog = jlog.at[:, :].add(0.5 * _skew(log))
jlog = jlog.at[:, :].add(factor2 * jnp.outer(log, log))
return jlog

@classmethod
@override
def sample_uniform(
Expand Down
Loading