Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for issue #144. #155

Merged
merged 1 commit into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 34 additions & 16 deletions spatialmath/base/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,16 +522,20 @@ def isvector(v: Any, dim: Optional[int] = None) -> bool:
return False


def getunit(v: ArrayLike, unit: str = "rad", dim=None) -> Union[float, NDArray]:
def getunit(
v: ArrayLike, unit: str = "rad", dim: Optional[int] = None, vector: bool = True
) -> Union[float, NDArray]:
"""
Convert values according to angular units

:param v: the value in radians or degrees
:type v: array_like(m)
:param unit: the angular unit, "rad" or "deg"
:type unit: str
:param dim: expected dimension of input, defaults to None
:param dim: expected dimension of input, defaults to don't check (None)
:type dim: int, optional
:param vector: return a scalar as a 1d vector, defaults to True
:type vector: bool, optional
:return: the converted value in radians
:rtype: ndarray(m) or float
:raises ValueError: argument is not a valid angular unit
Expand All @@ -543,30 +547,44 @@ def getunit(v: ArrayLike, unit: str = "rad", dim=None) -> Union[float, NDArray]:
>>> from spatialmath.base import getunit
>>> import numpy as np
>>> getunit(1.5, 'rad')
>>> getunit(1.5, 'rad', dim=0)
>>> # getunit([1.5], 'rad', dim=0) --> ValueError
>>> getunit(90, 'deg')
>>> getunit(90, 'deg', vector=False) # force a scalar output
>>> getunit(1.5, 'rad', dim=0) # check argument is scalar
>>> getunit(1.5, 'rad', dim=3) # check argument is a 3-vector
>>> getunit([1.5], 'rad', dim=1) # check argument is a 1-vector
>>> getunit([1.5], 'rad', dim=3) # check argument is a 3-vector
>>> getunit([90, 180], 'deg')
>>> getunit(np.r_[0.5, 1], 'rad')
>>> getunit(np.r_[90, 180], 'deg')
>>> getunit(np.r_[90, 180], 'deg', dim=2)
>>> # getunit([90, 180], 'deg', dim=3) --> ValueError
>>> getunit(np.r_[90, 180], 'deg', dim=2) # check argument is a 2-vector
>>> getunit([90, 180], 'deg', dim=3) # check argument is a 3-vector

:note:
- the input value is processed by :func:`getvector` and the argument ``dim`` can
be used to check that ``v`` is the desired length.
- the output is always an ndarray except if the input is a scalar and ``dim=0``.
be used to check that ``v`` is the desired length. Note that 0 means a scalar,
whereas 1 means a 1-element array.
- the output is always an ndarray except if the input is a scalar and ``vector=False``.

:seealso: :func:`getvector`
"""
if not isinstance(v, Iterable) and dim == 0:
# scalar in, scalar out
if unit == "rad":
return v
elif unit == "deg":
return np.deg2rad(v)
if not isinstance(v, Iterable):
# scalar input
if dim is not None and dim != 0:
raise ValueError("for dim==0 input must be a scalar")
if vector:
# scalar in, vector out
if unit == "deg":
v = np.deg2rad(v)
elif unit != "rad":
raise ValueError("invalid angular units")
return np.array([v])
else:
raise ValueError("invalid angular units")
# scalar in, scalar out
if unit == "rad":
return v
elif unit == "deg":
return np.deg2rad(v)
else:
raise ValueError("invalid angular units")

else:
# scalar or iterable in, ndarray out
Expand Down
2 changes: 1 addition & 1 deletion spatialmath/base/transforms2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def rot2(theta: float, unit: str = "rad") -> SO2Array:
>>> rot2(0.3)
>>> rot2(45, 'deg')
"""
theta = smb.getunit(theta, unit, dim=0)
theta = smb.getunit(theta, unit, vector=False)
ct = smb.sym.cos(theta)
st = smb.sym.sin(theta)
# fmt: off
Expand Down
35 changes: 21 additions & 14 deletions spatialmath/base/transforms3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def rotx(theta: float, unit: str = "rad") -> SO3Array:
:SymPy: supported
"""

theta = getunit(theta, unit, dim=0)
theta = getunit(theta, unit, vector=False)
ct = sym.cos(theta)
st = sym.sin(theta)
# fmt: off
Expand Down Expand Up @@ -118,7 +118,7 @@ def roty(theta: float, unit: str = "rad") -> SO3Array:
:SymPy: supported
"""

theta = getunit(theta, unit, dim=0)
theta = getunit(theta, unit, vector=False)
ct = sym.cos(theta)
st = sym.sin(theta)
# fmt: off
Expand Down Expand Up @@ -152,7 +152,7 @@ def rotz(theta: float, unit: str = "rad") -> SO3Array:
:seealso: :func:`~trotz`
:SymPy: supported
"""
theta = getunit(theta, unit, dim=0)
theta = getunit(theta, unit, vector=False)
ct = sym.cos(theta)
st = sym.sin(theta)
# fmt: off
Expand Down Expand Up @@ -2709,7 +2709,7 @@ def tr2adjoint(T):

:Reference:
- Robotics, Vision & Control for Python, Section 3, P. Corke, Springer 2023.
- `Lie groups for 2D and 3D Transformations <http://ethaneade.com/lie.pdf>_
- `Lie groups for 2D and 3D Transformations <http://ethaneade.com/lie.pdf>`_

:SymPy: supported
"""
Expand Down Expand Up @@ -3002,29 +3002,36 @@ def trplot(
- ``width`` of line
- ``length`` of line
- ``style`` which is one of:

- ``'arrow'`` [default], draw line with arrow head in ``color``
- ``'line'``, draw line with no arrow head in ``color``
- ``'rgb'``, frame axes are lines with no arrow head and red for X, green
for Y, blue for Z; no origin dot
for Y, blue for Z; no origin dot
- ``'rviz'``, frame axes are thick lines with no arrow head and red for X,
green for Y, blue for Z; no origin dot
green for Y, blue for Z; no origin dot

- coordinate axis labels depend on:

- ``axislabel`` if True [default] label the axis, default labels are X, Y, Z
- ``labels`` 3-list of alternative axis labels
- ``textcolor`` which defaults to ``color``
- ``axissubscript`` if True [default] add the frame label ``frame`` as a subscript
for each axis label
for each axis label

- coordinate frame label depends on:

- `frame` the label placed inside {} near the origin of the frame

- a dot at the origin

- ``originsize`` size of the dot, if zero no dot
- ``origincolor`` color of the dot, defaults to ``color``

Examples::

trplot(T, frame='A')
trplot(T, frame='A', color='green')
trplot(T1, 'labels', 'UVW');
trplot(T, frame='A')
trplot(T, frame='A', color='green')
trplot(T1, 'labels', 'UVW');

.. plot::

Expand Down Expand Up @@ -3383,12 +3390,12 @@ def tranimate(T: Union[SO3Array, SE3Array], **kwargs) -> str:
:param **kwargs: arguments passed to ``trplot``

- ``tranimate(T)`` where ``T`` is an SO(3) or SE(3) matrix, animates a 3D
coordinate frame moving from the world frame to the frame ``T`` in
``nsteps``.
coordinate frame moving from the world frame to the frame ``T`` in
``nsteps``.

- ``tranimate(I)`` where ``I`` is an iterable or generator, animates a 3D
coordinate frame representing the pose of each element in the sequence of
SO(3) or SE(3) matrices.
coordinate frame representing the pose of each element in the sequence of
SO(3) or SE(3) matrices.

Examples:

Expand Down
18 changes: 12 additions & 6 deletions spatialmath/base/vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,14 +530,15 @@ def wrap_0_pi(theta: ArrayLike) -> Union[float, NDArray]:
:param theta: input angle
:type theta: scalar or ndarray
:return: angle wrapped into range :math:`[0, \pi)`
:rtype: scalar or ndarray

This is used to fold angles of colatitude. If zero is the angle of the
north pole, colatitude increases to :math:`\pi` at the south pole then
decreases to :math:`0` as we head back to the north pole.

:seealso: :func:`wrap_mpi2_pi2` :func:`wrap_0_2pi` :func:`wrap_mpi_pi` :func:`angle_wrap`
"""
theta = np.abs(theta)
theta = np.abs(getvector(theta))
n = theta / np.pi
if isinstance(n, np.ndarray):
n = n.astype(int)
Expand All @@ -546,7 +547,7 @@ def wrap_0_pi(theta: ArrayLike) -> Union[float, NDArray]:

y = np.where(np.bitwise_and(n, 1) == 0, theta - n * np.pi, (n + 1) * np.pi - theta)
if isinstance(y, np.ndarray) and y.size == 1:
return float(y)
return float(y[0])
else:
return y

Expand All @@ -558,6 +559,7 @@ def wrap_mpi2_pi2(theta: ArrayLike) -> Union[float, NDArray]:
:param theta: input angle
:type theta: scalar or ndarray
:return: angle wrapped into range :math:`[-\pi/2, \pi/2]`
:rtype: scalar or ndarray

This is used to fold angles of latitude.

Expand All @@ -573,7 +575,7 @@ def wrap_mpi2_pi2(theta: ArrayLike) -> Union[float, NDArray]:

y = np.where(np.bitwise_and(n, 1) == 0, theta - n * np.pi, n * np.pi - theta)
if isinstance(y, np.ndarray) and len(y) == 1:
return float(y)
return float(y[0])
else:
return y

Expand All @@ -585,13 +587,14 @@ def wrap_0_2pi(theta: ArrayLike) -> Union[float, NDArray]:
:param theta: input angle
:type theta: scalar or ndarray
:return: angle wrapped into range :math:`[0, 2\pi)`
:rtype: scalar or ndarray

:seealso: :func:`wrap_mpi_pi` :func:`wrap_0_pi` :func:`wrap_mpi2_pi2` :func:`angle_wrap`
"""
theta = getvector(theta)
y = theta - 2.0 * math.pi * np.floor(theta / 2.0 / np.pi)
if isinstance(y, np.ndarray) and len(y) == 1:
return float(y)
return float(y[0])
else:
return y

Expand All @@ -603,13 +606,14 @@ def wrap_mpi_pi(theta: ArrayLike) -> Union[float, NDArray]:
:param theta: input angle
:type theta: scalar or ndarray
:return: angle wrapped into range :math:`[-\pi, \pi)`
:rtype: scalar or ndarray

:seealso: :func:`wrap_0_2pi` :func:`wrap_0_pi` :func:`wrap_mpi2_pi2` :func:`angle_wrap`
"""
theta = getvector(theta)
y = np.mod(theta + math.pi, 2 * math.pi) - np.pi
if isinstance(y, np.ndarray) and len(y) == 1:
return float(y)
return float(y[0])
else:
return y

Expand Down Expand Up @@ -643,6 +647,7 @@ def angdiff(a, b=None):
- ``angdiff(a, b)`` is the difference ``a - b`` wrapped to the range
:math:`[-\pi, \pi)`. This is the operator :math:`a \circleddash b` used
in the RVC book

- If ``a`` and ``b`` are both scalars, the result is scalar
- If ``a`` is array_like, the result is a NumPy array ``a[i]-b``
- If ``a`` is array_like, the result is a NumPy array ``a-b[i]``
Expand All @@ -651,6 +656,7 @@ def angdiff(a, b=None):

- ``angdiff(a)`` is the angle or vector of angles ``a`` wrapped to the range
:math:`[-\pi, \pi)`.

- If ``a`` is a scalar, the result is scalar
- If ``a`` is array_like, the result is a NumPy array

Expand All @@ -671,7 +677,7 @@ def angdiff(a, b=None):

y = np.mod(a + math.pi, 2 * math.pi) - math.pi
if isinstance(y, np.ndarray) and len(y) == 1:
return float(y)
return float(y[0])
else:
return y

Expand Down
2 changes: 1 addition & 1 deletion spatialmath/quaternion.py
Original file line number Diff line number Diff line change
Expand Up @@ -1411,7 +1411,7 @@ def AngVec(
:seealso: :meth:`UnitQuaternion.angvec` :meth:`UnitQuaternion.exp` :func:`~spatialmath.base.transforms3d.angvec2r`
"""
v = smb.getvector(v, 3)
theta = smb.getunit(theta, unit, dim=0)
theta = smb.getunit(theta, unit, vector=False)
return cls(
s=math.cos(theta / 2), v=math.sin(theta / 2) * v, norm=False, check=False
)
Expand Down
42 changes: 40 additions & 2 deletions tests/base/test_argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,49 @@ def test_verifymatrix(self):
verifymatrix(a, (3, 4))

def test_unit(self):
self.assertIsInstance(getunit(1), np.ndarray)
# scalar -> vector
self.assertEqual(getunit(1), np.array([1]))
self.assertEqual(getunit(1, dim=0), np.array([1]))
with self.assertRaises(ValueError):
self.assertEqual(getunit(1, dim=1), np.array([1]))

self.assertEqual(getunit(1, unit="deg"), np.array([1 * math.pi / 180.0]))
self.assertEqual(getunit(1, dim=0, unit="deg"), np.array([1 * math.pi / 180.0]))
with self.assertRaises(ValueError):
self.assertEqual(
getunit(1, dim=1, unit="deg"), np.array([1 * math.pi / 180.0])
)

# scalar -> scalar
self.assertEqual(getunit(1, vector=False), 1)
self.assertEqual(getunit(1, dim=0, vector=False), 1)
with self.assertRaises(ValueError):
self.assertEqual(getunit(1, dim=1, vector=False), 1)

self.assertIsInstance(getunit(1.0, vector=False), float)
self.assertIsInstance(getunit(1, vector=False), int)

self.assertEqual(getunit(1, vector=False, unit="deg"), 1 * math.pi / 180.0)
self.assertEqual(
getunit(1, dim=0, vector=False, unit="deg"), 1 * math.pi / 180.0
)
with self.assertRaises(ValueError):
self.assertEqual(
getunit(1, dim=1, vector=False, unit="deg"), 1 * math.pi / 180.0
)

self.assertIsInstance(getunit(1.0, vector=False, unit="deg"), float)
self.assertIsInstance(getunit(1, vector=False, unit="deg"), float)

# vector -> vector
self.assertEqual(getunit([1]), np.array([1]))
self.assertEqual(getunit([1], dim=1), np.array([1]))
with self.assertRaises(ValueError):
getunit([1], dim=0)

self.assertIsInstance(getunit([1, 2]), np.ndarray)
self.assertIsInstance(getunit((1, 2)), np.ndarray)
self.assertIsInstance(getunit(np.r_[1, 2]), np.ndarray)
self.assertIsInstance(getunit(1.0, dim=0), float)

nt.assert_equal(getunit(5, "rad"), 5)
nt.assert_equal(getunit(5, "deg"), 5 * math.pi / 180.0)
Expand Down
Loading