From 727db09e5e3155e9077706ff4013a7d7765e91df Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Sun, 5 May 2024 23:37:53 -0700 Subject: [PATCH] Nits --- src/viser/transforms/_se2.py | 2 +- src/viser/transforms/_se3.py | 2 +- src/viser/transforms/_so2.py | 2 +- src/viser/transforms/_so3.py | 2 +- src/viser/transforms/utils/_utils.py | 10 ++++++---- 5 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/viser/transforms/_se2.py b/src/viser/transforms/_se2.py index 45dfd7b77..ec83c7d7e 100644 --- a/src/viser/transforms/_se2.py +++ b/src/viser/transforms/_se2.py @@ -27,7 +27,7 @@ class SE2(_base.SEBase[SO2]): # SE2-specific. - unit_complex_xy: onp.ndarray + unit_complex_xy: onpt.NDArray[onp.floating] """Internal parameters. `(cos, sin, x, y)`. Shape should be `(*, 3)`.""" @override diff --git a/src/viser/transforms/_se3.py b/src/viser/transforms/_se3.py index 406c2a0ff..652bc7fec 100644 --- a/src/viser/transforms/_se3.py +++ b/src/viser/transforms/_se3.py @@ -40,7 +40,7 @@ class SE3(_base.SEBase[SO3]): # SE3-specific. - wxyz_xyz: onp.ndarray + wxyz_xyz: onpt.NDArray[onp.floating] """Internal parameters. wxyz quaternion followed by xyz translation. Shape should be `(*, 7)`.""" @override diff --git a/src/viser/transforms/_so2.py b/src/viser/transforms/_so2.py index b9189c2fe..c0104074e 100644 --- a/src/viser/transforms/_so2.py +++ b/src/viser/transforms/_so2.py @@ -27,7 +27,7 @@ class SO2(_base.SOBase): # SO2-specific. - unit_complex: onp.ndarray + unit_complex: onpt.NDArray[onp.floating] """Internal parameters. `(cos, sin)`. Shape should be `(*, 2)`.""" @override diff --git a/src/viser/transforms/_so3.py b/src/viser/transforms/_so3.py index ac6397ceb..af0f0ace5 100644 --- a/src/viser/transforms/_so3.py +++ b/src/viser/transforms/_so3.py @@ -26,7 +26,7 @@ class SO3(_base.SOBase): `(omega_x, omega_y, omega_z)`. """ - wxyz: onp.ndarray + wxyz: onpt.NDArray[onp.floating] """Internal parameters. `(w, x, y, z)` quaternion. Shape should be `(*, 4)`.""" @override diff --git a/src/viser/transforms/utils/_utils.py b/src/viser/transforms/utils/_utils.py index d2a3ea380..765e51219 100644 --- a/src/viser/transforms/utils/_utils.py +++ b/src/viser/transforms/utils/_utils.py @@ -19,10 +19,12 @@ def get_epsilon(dtype: onp.dtype) -> float: Returns: Output float. """ - return { - onp.dtype("float32"): 1e-5, - onp.dtype("float64"): 1e-10, - }[dtype] + if dtype == onp.float32: + return 1e-5 + elif dtype == onp.float64: + return 1e-10 + else: + assert False def register_lie_group(