Skip to content

Commit

Permalink
getunit had an argument dim which was used both to check the dime…
Browse files Browse the repository at this point in the history
…nsionality of the argument, and also `dim=0` enforced a scalar rather than ndarray return value. Have now added a `vector=True` parameter which can be set to false to force a scalar return. All instances of passing `dim=0` have been found, documentation and tests updated.

Some black formatting changes are also included in here which is annoying.
  • Loading branch information
petercorke committed Jan 12, 2025
1 parent 6a9904f commit cfa84cc
Show file tree
Hide file tree
Showing 7 changed files with 198 additions and 150 deletions.
62 changes: 38 additions & 24 deletions spatialmath/base/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,7 @@ def getvector(
dim: Optional[Union[int, None]] = None,
out: str = "array",
dtype: DTypeLike = np.float64,
) -> NDArray:
...
) -> NDArray: ...


@overload
Expand All @@ -293,8 +292,7 @@ def getvector(
dim: Optional[Union[int, None]] = None,
out: str = "list",
dtype: DTypeLike = np.float64,
) -> List[float]:
...
) -> List[float]: ...


@overload
Expand All @@ -303,8 +301,7 @@ def getvector(
dim: Optional[Union[int, None]] = None,
out: str = "sequence",
dtype: DTypeLike = np.float64,
) -> Tuple[float, ...]:
...
) -> Tuple[float, ...]: ...


@overload
Expand All @@ -313,8 +310,7 @@ def getvector(
dim: Optional[Union[int, None]] = None,
out: str = "sequence",
dtype: DTypeLike = np.float64,
) -> List[float]:
...
) -> List[float]: ...


def getvector(
Expand Down Expand Up @@ -522,16 +518,20 @@ def isvector(v: Any, dim: Optional[int] = None) -> bool:
return False


def getunit(v: ArrayLike, unit: str = "rad", dim=None) -> Union[float, NDArray]:
def getunit(
v: ArrayLike, unit: str = "rad", dim: Optional[int] = None, vector: bool = True
) -> Union[float, NDArray]:
"""
Convert values according to angular units
:param v: the value in radians or degrees
:type v: array_like(m)
:param unit: the angular unit, "rad" or "deg"
:type unit: str
:param dim: expected dimension of input, defaults to None
:param dim: expected dimension of input, defaults to don't check (None)
:type dim: int, optional
:param vector: return a scalar as a 1d vector, defaults to True
:type vector: bool, optional
:return: the converted value in radians
:rtype: ndarray(m) or float
:raises ValueError: argument is not a valid angular unit
Expand All @@ -543,30 +543,44 @@ def getunit(v: ArrayLike, unit: str = "rad", dim=None) -> Union[float, NDArray]:
>>> from spatialmath.base import getunit
>>> import numpy as np
>>> getunit(1.5, 'rad')
>>> getunit(1.5, 'rad', dim=0)
>>> # getunit([1.5], 'rad', dim=0) --> ValueError
>>> getunit(90, 'deg')
>>> getunit(90, 'deg', vector=False) # force a scalar output
>>> getunit(1.5, 'rad', dim=0) # check argument is scalar
>>> getunit(1.5, 'rad', dim=3) # check argument is a 3-vector
>>> getunit([1.5], 'rad', dim=1) # check argument is a 1-vector
>>> getunit([1.5], 'rad', dim=3) # check argument is a 3-vector
>>> getunit([90, 180], 'deg')
>>> getunit(np.r_[0.5, 1], 'rad')
>>> getunit(np.r_[90, 180], 'deg')
>>> getunit(np.r_[90, 180], 'deg', dim=2)
>>> # getunit([90, 180], 'deg', dim=3) --> ValueError
>>> getunit(np.r_[90, 180], 'deg', dim=2) # check argument is a 2-vector
>>> getunit([90, 180], 'deg', dim=3) # check argument is a 3-vector
:note:
- the input value is processed by :func:`getvector` and the argument ``dim`` can
be used to check that ``v`` is the desired length.
- the output is always an ndarray except if the input is a scalar and ``dim=0``.
be used to check that ``v`` is the desired length. Note that 0 means a scalar,
whereas 1 means a 1-element array.
- the output is always an ndarray except if the input is a scalar and ``vector=False``.
:seealso: :func:`getvector`
"""
if not isinstance(v, Iterable) and dim in (0, None):
# scalar in, scalar out
if unit == "rad":
return v
elif unit == "deg":
return np.deg2rad(v)
if not isinstance(v, Iterable):
# scalar input
if dim is not None and dim != 0:
raise ValueError("for dim==0 input must be a scalar")
if vector:
# scalar in, vector out
if unit == "deg":
v = np.deg2rad(v)
elif unit != "rad":
raise ValueError("invalid angular units")
return np.array([v])
else:
raise ValueError("invalid angular units")
# scalar in, scalar out
if unit == "rad":
return v
elif unit == "deg":
return np.deg2rad(v)
else:
raise ValueError("invalid angular units")

else:
# scalar or iterable in, ndarray out
Expand Down
52 changes: 23 additions & 29 deletions spatialmath/base/transforms2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def rot2(theta: float, unit: str = "rad") -> SO2Array:
>>> rot2(0.3)
>>> rot2(45, 'deg')
"""
theta = smb.getunit(theta, unit, dim=0)
theta = smb.getunit(theta, unit, vector=False)
ct = smb.sym.cos(theta)
st = smb.sym.sin(theta)
# fmt: off
Expand Down Expand Up @@ -172,18 +172,15 @@ def tr2xyt(T: SE2Array, unit: str = "rad") -> R3:

# ---------------------------------------------------------------------------------------#
@overload # pragma: no cover
def transl2(x: float, y: float) -> SE2Array:
...
def transl2(x: float, y: float) -> SE2Array: ...


@overload # pragma: no cover
def transl2(x: ArrayLike2) -> SE2Array:
...
def transl2(x: ArrayLike2) -> SE2Array: ...


@overload # pragma: no cover
def transl2(x: SE2Array) -> R2:
...
def transl2(x: SE2Array) -> R2: ...


def transl2(x, y=None):
Expand Down Expand Up @@ -446,8 +443,7 @@ def trlog2(
twist: bool = False,
check: bool = True,
tol: float = 20,
) -> so2Array:
...
) -> so2Array: ...


@overload # pragma: no cover
Expand All @@ -456,8 +452,7 @@ def trlog2(
twist: bool = False,
check: bool = True,
tol: float = 20,
) -> se2Array:
...
) -> se2Array: ...


@overload # pragma: no cover
Expand All @@ -466,8 +461,7 @@ def trlog2(
twist: bool = True,
check: bool = True,
tol: float = 20,
) -> float:
...
) -> float: ...


@overload # pragma: no cover
Expand All @@ -476,8 +470,7 @@ def trlog2(
twist: bool = True,
check: bool = True,
tol: float = 20,
) -> R3:
...
) -> R3: ...


def trlog2(
Expand Down Expand Up @@ -563,13 +556,15 @@ def trlog2(

# ---------------------------------------------------------------------------------------#
@overload # pragma: no cover
def trexp2(S: so2Array, theta: Optional[float] = None, check: bool = True) -> SO2Array:
...
def trexp2(
S: so2Array, theta: Optional[float] = None, check: bool = True
) -> SO2Array: ...


@overload # pragma: no cover
def trexp2(S: se2Array, theta: Optional[float] = None, check: bool = True) -> SE2Array:
...
def trexp2(
S: se2Array, theta: Optional[float] = None, check: bool = True
) -> SE2Array: ...


def trexp2(
Expand Down Expand Up @@ -692,8 +687,7 @@ def trexp2(


@overload # pragma: no cover
def trnorm2(R: SO2Array) -> SO2Array:
...
def trnorm2(R: SO2Array) -> SO2Array: ...


def trnorm2(T: SE2Array) -> SE2Array:
Expand Down Expand Up @@ -758,13 +752,11 @@ def trnorm2(T: SE2Array) -> SE2Array:


@overload # pragma: no cover
def tradjoint2(T: SO2Array) -> R1x1:
...
def tradjoint2(T: SO2Array) -> R1x1: ...


@overload # pragma: no cover
def tradjoint2(T: SE2Array) -> R3x3:
...
def tradjoint2(T: SE2Array) -> R3x3: ...


def tradjoint2(T):
Expand Down Expand Up @@ -853,13 +845,15 @@ def tr2jac2(T: SE2Array) -> R3x3:


@overload
def trinterp2(start: Optional[SO2Array], end: SO2Array, s: float, shortest: bool = True) -> SO2Array:
...
def trinterp2(
start: Optional[SO2Array], end: SO2Array, s: float, shortest: bool = True
) -> SO2Array: ...


@overload
def trinterp2(start: Optional[SE2Array], end: SE2Array, s: float, shortest: bool = True) -> SE2Array:
...
def trinterp2(
start: Optional[SE2Array], end: SE2Array, s: float, shortest: bool = True
) -> SE2Array: ...


def trinterp2(start, end, s, shortest: bool = True):
Expand Down
Loading

0 comments on commit cfa84cc

Please sign in to comment.