Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/jlog #22

Draft
wants to merge 9 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions jaxlie/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,17 @@ def normalize(self) -> Self:
Normalized group member.
"""

@abc.abstractmethod
def jlog(self) -> jax.Array:
"""
Computes the Jacobian of the logarithm of the group element.

This is equivalent to the inverse of the right Jacobian.

Returns:
The Jacobian of the logarithm, having the dimensions `(tangent_dim, tangent_dim,)` or batch of these Jacobians.
"""

@classmethod
@abc.abstractmethod
def sample_uniform(cls, key: jax.Array, batch_axes: Tuple[int, ...] = ()) -> Self:
Expand Down Expand Up @@ -244,7 +255,8 @@ def from_translation(cls, translation: hints.Array) -> Self:
# Extract rotation class from type parameter.
assert len(cls.__orig_bases__) == 1 # type: ignore
return cls.from_rotation_and_translation(
rotation=get_args(cls.__orig_bases__[0])[0].identity(), # type: ignore
rotation=get_args(cls.__orig_bases__[0])[
0].identity(), # type: ignore
translation=translation,
)

Expand All @@ -268,7 +280,8 @@ def apply(self, target: hints.Array) -> jax.Array:
def multiply(self, other: Self) -> Self:
return type(self).from_rotation_and_translation(
rotation=self.rotation() @ other.rotation(),
translation=(self.rotation() @ other.translation()) + self.translation(),
translation=(self.rotation() @ other.translation()) +
self.translation(),
)

@final
Expand Down
219 changes: 149 additions & 70 deletions jaxlie/_se2.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,110 @@
from .utils import broadcast_leading_axes, get_epsilon, register_lie_group


def _V(theta: jax.Array) -> jax.Array:
"""
Compute the V map for the given theta and rotation matrix.

This function calculates the V map, which is used in various geometric transformations.
It handles both small and large theta values using different computation methods.

Args:
theta (jax.Array): The input angle(s) in axis-angle representation.
rotation_matrix (jax.Array): The corresponding rotation matrix.

Returns:
jax.Array: A 3x3 matrix (or batch of 3x3 matrices) representing the V map.
"""
use_taylor = jnp.abs(theta) < get_epsilon(theta.dtype)

# Shim to avoid NaNs in jnp.where branches, which cause failures for
# reverse-mode AD.
safe_theta = cast(
jax.Array,
jnp.where(
use_taylor,
jnp.ones_like(theta), # Any non-zero value should do here.
theta,
),
)

theta_sq = theta**2
sin_over_theta = cast(
jax.Array,
jnp.where(
use_taylor,
1.0 - theta_sq / 6.0,
jnp.sin(safe_theta) / safe_theta,
),
)
one_minus_cos_over_theta = cast(
jax.Array,
jnp.where(
use_taylor,
0.5 * theta - theta * theta_sq / 24.0,
(1.0 - jnp.cos(safe_theta)) / safe_theta,
),
)

V = jnp.stack(
[
sin_over_theta,
-one_minus_cos_over_theta,
one_minus_cos_over_theta,
sin_over_theta,
],
axis=-1,
).reshape((*theta.shape, 2, 2))
return V


def _V_inv(theta: jax.Array) -> jax.Array:
"""
Compute the inverse of the V map for the given theta.

This function calculates the inverse of the V map, which is used in various
geometric transformations. It handles both small and large theta values
using different computation methods.

Args:
theta (jax.Array): The input angle(s) in axis-angle representation.

Returns:
jax.Array: A 3x3 matrix (or batch of 3x3 matrices) representing the inverse V map.
"""
cos = jnp.cos(theta)
cos_minus_one = cos - 1.0
half_theta = theta / 2.0
use_taylor = jnp.abs(cos_minus_one) < get_epsilon(theta.dtype)

# Shim to avoid NaNs in jnp.where branches, which cause failures for
# reverse-mode AD.
safe_cos_minus_one = jnp.where(
use_taylor,
jnp.ones_like(cos_minus_one), # Any non-zero value should do here.
cos_minus_one,
)

half_theta_over_tan_half_theta = jnp.where(
use_taylor,
# Taylor approximation.
1.0 - theta**2 / 12.0,
# Default.
-(half_theta * jnp.sin(theta)) / safe_cos_minus_one,
)

V_inv = jnp.stack(
[
half_theta_over_tan_half_theta,
half_theta,
-half_theta,
half_theta_over_tan_half_theta,
],
axis=-1,
).reshape((*theta.shape, 2, 2))
return V_inv


@register_lie_group(
matrix_dim=3,
parameters_dim=4,
Expand Down Expand Up @@ -130,46 +234,7 @@ def exp(cls, tangent: hints.Array) -> "SE2":
assert tangent.shape[-1:] == (3,)

theta = tangent[..., 2]
use_taylor = jnp.abs(theta) < get_epsilon(tangent.dtype)

# Shim to avoid NaNs in jnp.where branches, which cause failures for
# reverse-mode AD.
safe_theta = cast(
jax.Array,
jnp.where(
use_taylor,
jnp.ones_like(theta), # Any non-zero value should do here.
theta,
),
)

theta_sq = theta**2
sin_over_theta = cast(
jax.Array,
jnp.where(
use_taylor,
1.0 - theta_sq / 6.0,
jnp.sin(safe_theta) / safe_theta,
),
)
one_minus_cos_over_theta = cast(
jax.Array,
jnp.where(
use_taylor,
0.5 * theta - theta * theta_sq / 24.0,
(1.0 - jnp.cos(safe_theta)) / safe_theta,
),
)

V = jnp.stack(
[
sin_over_theta,
-one_minus_cos_over_theta,
one_minus_cos_over_theta,
sin_over_theta,
],
axis=-1,
).reshape((*tangent.shape[:-1], 2, 2))
V = _V(theta)
return SE2.from_rotation_and_translation(
rotation=SO2.from_radians(theta),
translation=jnp.einsum("...ij,...j->...i", V, tangent[..., :2]),
Expand All @@ -184,36 +249,7 @@ def log(self) -> jax.Array:

theta = self.rotation().log()[..., 0]

cos = jnp.cos(theta)
cos_minus_one = cos - 1.0
half_theta = theta / 2.0
use_taylor = jnp.abs(cos_minus_one) < get_epsilon(theta.dtype)

# Shim to avoid NaNs in jnp.where branches, which cause failures for
# reverse-mode AD.
safe_cos_minus_one = jnp.where(
use_taylor,
jnp.ones_like(cos_minus_one), # Any non-zero value should do here.
cos_minus_one,
)

half_theta_over_tan_half_theta = jnp.where(
use_taylor,
# Taylor approximation.
1.0 - theta**2 / 12.0,
# Default.
-(half_theta * jnp.sin(theta)) / safe_cos_minus_one,
)

V_inv = jnp.stack(
[
half_theta_over_tan_half_theta,
half_theta,
-half_theta,
half_theta_over_tan_half_theta,
],
axis=-1,
).reshape((*theta.shape, 2, 2))
V_inv = _V_inv(theta)

tangent = jnp.concatenate(
[
Expand Down Expand Up @@ -242,6 +278,49 @@ def adjoint(self: "SE2") -> jax.Array:
axis=-1,
).reshape((*self.get_batch_axes(), 3, 3))

@override
def jlog(self) -> jax.Array:
# Reference:
# This is inverse of matrix (163) from Micro-Lie theory:
# > https://arxiv.org/pdf/1812.01537

log = self.log()
# rho1, rho2, theta =
rho1, rho2, theta = broadcast_leading_axes((log[..., 0], log[..., 1], log[..., 2]))

# Handle the case where theta is small to avoid division by zero
use_taylor = jnp.abs(theta) < get_epsilon(theta.dtype)

# Shim to avoid NaNs in jnp.where branches, which cause failures for reverse-mode AD.
safe_theta = jnp.where(use_taylor, jnp.ones_like(theta), theta)

V_inv_theta = _V_inv(safe_theta)
V_inv_theta_T = jnp.swapaxes(V_inv_theta, -2, -1) # Transpose the last two dimensions

# Calculate r, handling the small theta case separately
batch_shape = theta.shape
eye_2 = jnp.eye(2).reshape((1,) * len(batch_shape) + (2, 2))

A = jnp.where(
use_taylor[..., None, None],
jnp.stack([
jnp.stack([theta/12., jnp.full_like(theta, 0.5)], axis=-1),
jnp.stack([jnp.full_like(theta, -0.5), theta/12.], axis=-1)
], axis=-2),
(eye_2 - V_inv_theta_T) / safe_theta[..., None, None]
)

rho = jnp.stack([rho1, rho2], axis=-1)[..., None]
r = jnp.squeeze(A @ rho, axis=-1)

# Create the jlog matrix
jlog = jnp.zeros((*batch_shape, 3, 3))
jlog = jlog.at[..., :2, :2].set(V_inv_theta_T)
jlog = jlog.at[..., :2, 2].set(r)
jlog = jlog.at[..., 2, 2].set(1) # Set the bottom right element to 1

return jlog

@classmethod
@override
def sample_uniform(
Expand Down
Loading