Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/jlog #22

Draft
wants to merge 9 commits into
base: master
Choose a base branch
from
Draft

Feature/jlog #22

wants to merge 9 commits into from

Conversation

simeon-ned
Copy link

Reduce the size of following text, specifically do not use separate bullet for groups, just make list of them:"Thank you for providing that information. Based on the content you shared, I can summarize the key points:

This pull request adds analytical Jacobians for the logarithm (log) operation for all groups presented in jaxlie. These Jacobians are also known as right inverse Jacobians. They represent the partial derivative of the log operation with respect to the group element:

$$\text{Jlog}(T) = \frac{\partial \log_3(T)}{\partial T}$$

The implementation are based on derivations presented in the "micro Lie theory" paper, specifically equations 41c, 79, 126, 144, 163, and 179. Similar functionality is also implemented in established robotics libraries like Pinocchio, indicating the practical importance of these Jacobians in robotics applications.

Copy link
Owner

@brentyi brentyi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @simeon-ned! This is a great first pass.

I left some comments. It seems like the main things that need addressing are:

  • Handling leading batch axes.
  • Making sure we handle all divide by zero cases, including preventing revese-mode AD-specific NaNs1 for the jlog methods.

For tests:

  • Move tests to pytest.
  • Tests for batch axes.
  • Tests for autodiff through jlog. (perhaps using the jacnumerical or _assert_jacobians_close helpers, see test_autodiff.py)

Does that sound right? Am I missing anything?

Footnotes

  1. https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where

jaxlie/_se2.py Outdated Show resolved Hide resolved
jaxlie/_so2.py Outdated Show resolved Hide resolved
jaxlie/_so3.py Outdated Show resolved Hide resolved
tests/test_jlog.py Outdated Show resolved Hide resolved
tests/test_jlog.py Outdated Show resolved Hide resolved
jaxlie/_se3.py Outdated Show resolved Hide resolved
@brentyi
Copy link
Owner

brentyi commented Jul 8, 2024

Actually, for SE(2) and SE(3), are you able to reuse/refactor the code for the "V" and "V_inv" matrices?

jaxlie/jaxlie/_se2.py

Lines 164 to 175 in 84babf5

V = jnp.stack(
[
sin_over_theta,
-one_minus_cos_over_theta,
one_minus_cos_over_theta,
sin_over_theta,
],
axis=-1,
).reshape((*tangent.shape[:-1], 2, 2))
return SE2.from_rotation_and_translation(
rotation=SO2.from_radians(theta),
translation=jnp.einsum("...ij,...j->...i", V, tangent[..., :2]),

jaxlie/jaxlie/_se3.py

Lines 142 to 159 in 84babf5

V = jnp.where(
use_taylor[..., None, None],
rotation.as_matrix(),
(
jnp.eye(3)
+ ((1.0 - jnp.cos(theta_safe)) / (theta_squared_safe))[..., None, None]
* skew_omega
+ (
(theta_safe - jnp.sin(theta_safe))
/ (theta_squared_safe * theta_safe)
)[..., None, None]
* jnp.einsum("...ij,...jk->...ik", skew_omega, skew_omega)
),
)
return SE3.from_rotation_and_translation(
rotation=rotation,
translation=jnp.einsum("...ij,...j->...i", V, tangent[..., :3]),

jaxlie/jaxlie/_se2.py

Lines 208 to 221 in 84babf5

V_inv = jnp.stack(
[
half_theta_over_tan_half_theta,
half_theta,
-half_theta,
half_theta_over_tan_half_theta,
],
axis=-1,
).reshape((*theta.shape, 2, 2))
tangent = jnp.concatenate(
[
jnp.einsum("...ij,...j->...i", V_inv, self.translation()),
theta[..., None],

jaxlie/jaxlie/_se3.py

Lines 183 to 205 in 84babf5

V_inv = jnp.where(
use_taylor[..., None, None],
jnp.eye(3)
- 0.5 * skew_omega
+ jnp.einsum("...ij,...jk->...ik", skew_omega, skew_omega) / 12.0,
(
jnp.eye(3)
- 0.5 * skew_omega
+ (
(
1.0
- theta_safe
* jnp.cos(half_theta_safe)
/ (2.0 * jnp.sin(half_theta_safe))
)
/ theta_squared_safe
)[..., None, None]
* jnp.einsum("...ij,...jk->...ik", skew_omega, skew_omega)
),
)
return jnp.concatenate(
[jnp.einsum("...ij,...j->...i", V_inv, self.translation()), omega], axis=-1
)

These already seem to describe the parts of the left/right Jacobians that are the trickiest to compute.

@simeon-ned
Copy link
Author

Actually, for SE(2) and SE(3), are you able to reuse/refactor the code for the "V" and "V_inv" matrices?

jaxlie/jaxlie/_se2.py

Lines 164 to 175 in 84babf5

V = jnp.stack(
[
sin_over_theta,
-one_minus_cos_over_theta,
one_minus_cos_over_theta,
sin_over_theta,
],
axis=-1,
).reshape((*tangent.shape[:-1], 2, 2))
return SE2.from_rotation_and_translation(
rotation=SO2.from_radians(theta),
translation=jnp.einsum("...ij,...j->...i", V, tangent[..., :2]),

jaxlie/jaxlie/_se3.py

Lines 142 to 159 in 84babf5

V = jnp.where(
use_taylor[..., None, None],
rotation.as_matrix(),
(
jnp.eye(3)
+ ((1.0 - jnp.cos(theta_safe)) / (theta_squared_safe))[..., None, None]
* skew_omega
+ (
(theta_safe - jnp.sin(theta_safe))
/ (theta_squared_safe * theta_safe)
)[..., None, None]
* jnp.einsum("...ij,...jk->...ik", skew_omega, skew_omega)
),
)
return SE3.from_rotation_and_translation(
rotation=rotation,
translation=jnp.einsum("...ij,...j->...i", V, tangent[..., :3]),

jaxlie/jaxlie/_se2.py

Lines 208 to 221 in 84babf5

V_inv = jnp.stack(
[
half_theta_over_tan_half_theta,
half_theta,
-half_theta,
half_theta_over_tan_half_theta,
],
axis=-1,
).reshape((*theta.shape, 2, 2))
tangent = jnp.concatenate(
[
jnp.einsum("...ij,...j->...i", V_inv, self.translation()),
theta[..., None],

jaxlie/jaxlie/_se3.py

Lines 183 to 205 in 84babf5

V_inv = jnp.where(
use_taylor[..., None, None],
jnp.eye(3)
- 0.5 * skew_omega
+ jnp.einsum("...ij,...jk->...ik", skew_omega, skew_omega) / 12.0,
(
jnp.eye(3)
- 0.5 * skew_omega
+ (
(
1.0
- theta_safe
* jnp.cos(half_theta_safe)
/ (2.0 * jnp.sin(half_theta_safe))
)
/ theta_squared_safe
)[..., None, None]
* jnp.einsum("...ij,...jk->...ik", skew_omega, skew_omega)
),
)
return jnp.concatenate(
[jnp.einsum("...ij,...j->...i", V_inv, self.translation()), omega], axis=-1
)

These already seem to describe the parts of the left/right Jacobians that are the trickiest to compute.

Great suggestion, should I encupsulate V, V_inv in to separate method or create standalone function similar to _skew in SE3?

@brentyi
Copy link
Owner

brentyi commented Jul 8, 2024

I'd vote for a helper function for now so it's not exposed in the API, it's easier to add things to the API later than to remove them!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants