diff --git a/brainunit/_base.py b/brainunit/_base.py index 95df305..87e70de 100644 --- a/brainunit/_base.py +++ b/brainunit/_base.py @@ -36,7 +36,7 @@ 'DIMENSIONLESS', 'DimensionMismatchError', 'get_or_create_dimension', - 'get_unit', + 'get_dim', 'get_basic_unit', 'is_unitless', 'have_same_unit', @@ -119,7 +119,7 @@ def get_unit_for_display(d): if (isinstance(d, int) and d == 1) or d is DIMENSIONLESS: return "1" else: - return str(get_unit(d)) + return str(get_dim(d)) # SI dimensions (see table at the top of the file) and various descriptions, @@ -497,7 +497,7 @@ def __str__(self): return s -def get_unit(obj) -> Dimension: +def get_dim(obj) -> Dimension: """ Return the unit of any object that has them. @@ -551,8 +551,8 @@ def have_same_unit(obj1, obj2) -> bool: # should only add a small amount of unnecessary computation for cases in # which this function returns False which very likely leads to a # DimensionMismatchError anyway. - dim1 = get_unit(obj1) - dim2 = get_unit(obj2) + dim1 = get_dim(obj1) + dim2 = get_dim(obj2) return (dim1 is dim2) or (dim1 == dim2) or dim1 is None or dim2 is None @@ -598,11 +598,11 @@ def fail_for_dimension_mismatch( if not _unit_checking: return None, None - dim1 = get_unit(obj1) + dim1 = get_dim(obj1) if obj2 is None: dim2 = DIMENSIONLESS else: - dim2 = get_unit(obj2) + dim2 = get_dim(obj2) if dim1 is not dim2 and not (dim1 is None or dim2 is None): # Special treatment for "0": @@ -779,7 +779,7 @@ def is_unitless(obj) -> bool: dimensionless : `bool` ``True`` if `obj` is dimensionless. """ - return get_unit(obj) is DIMENSIONLESS + return get_dim(obj) is DIMENSIONLESS def is_scalar_type(obj) -> bool: @@ -1105,8 +1105,8 @@ def has_same_unit(self, other): """ if not _unit_checking: return True - other_unit = get_unit(other.dim) - return (get_unit(self.dim) is other_unit) or (get_unit(self.dim) == other_unit) + other_unit = get_dim(other.dim) + return (get_dim(self.dim) is other_unit) or (get_dim(self.dim) == other_unit) def get_best_unit(self, *regs) -> 'Quantity': """ @@ -1475,7 +1475,7 @@ def _binary_operation( _, other_dim = fail_for_dimension_mismatch(self, other, message, value1=self, value2=other) if other_dim is None: - other_dim = get_unit(other) + other_dim = get_dim(other) new_dim = unit_operation(self.dim, other_dim) result = value_operation(self.value, other.value) @@ -1940,9 +1940,21 @@ def split(self, indices_or_sections, axis=0) -> List['Quantity']: """ return [Quantity(a, dim=self.dim) for a in jnp.split(self.value, indices_or_sections, axis=axis)] - def take(self, indices, axis=None, mode=None) -> 'Quantity': + def take( + self, + indices, + axis=None, + mode=None, + unique_indices=False, + indices_are_sorted=False, + fill_value=None, + ) -> 'Quantity': """Return an array formed from the elements of a at the given indices.""" - return Quantity(jnp.take(self.value, indices=indices, axis=axis, mode=mode), dim=self.dim) + if isinstance(fill_value, Quantity): + fill_value = fill_value.value + return Quantity(jnp.take(self.value, indices=indices, axis=axis, mode=mode, + unique_indices=unique_indices, indices_are_sorted=indices_are_sorted, + fill_value=fill_value), dim=self.dim) def tolist(self): """Return the array as an ``a.ndim``-levels deep nested list of Python scalars. @@ -3021,8 +3033,8 @@ def new_f(*args, **kwds): ) raise TypeError(error_message) if not have_same_unit(newkeyset[k], newkeyset[au[k]]): - d1 = get_unit(newkeyset[k]) - d2 = get_unit(newkeyset[au[k]]) + d1 = get_dim(newkeyset[k]) + d2 = get_dim(newkeyset[au[k]]) error_message = ( f"Function '{f.__name__}' expected " f"the argument '{k}' to have the same " @@ -3043,13 +3055,13 @@ def new_f(*args, **kwds): f"'{value}'" ) raise DimensionMismatchError( - error_message, get_unit(newkeyset[k]) + error_message, get_dim(newkeyset[k]) ) result = f(*args, **kwds) if "result" in au: if isinstance(au["result"], Callable) and au["result"] != bool: - expected_result = au["result"](*[get_unit(a) for a in args]) + expected_result = au["result"](*[get_dim(a) for a in args]) else: expected_result = au["result"] if au["result"] == bool: @@ -3069,7 +3081,7 @@ def new_f(*args, **kwds): f"unit {unit} but was " f"'{result}'" ) - raise DimensionMismatchError(error_message, get_unit(result)) + raise DimensionMismatchError(error_message, get_dim(result)) return result new_f._orig_func = f diff --git a/brainunit/_unit_test.py b/brainunit/_unit_test.py index fe5c2d7..5380f68 100644 --- a/brainunit/_unit_test.py +++ b/brainunit/_unit_test.py @@ -36,7 +36,7 @@ check_units, fail_for_dimension_mismatch, get_or_create_dimension, - get_unit, + get_dim, get_basic_unit, have_same_unit, in_unit, @@ -74,7 +74,7 @@ def assert_allclose(actual, desired, rtol=4.5e8, atol=0, **kwds): def assert_quantity(q, values, unit): values = jnp.asarray(values) if isinstance(q, Quantity): - assert have_same_unit(q.dim, unit), f"Dimension mismatch: ({get_unit(q)}) ({get_unit(unit)})" + assert have_same_unit(q.dim, unit), f"Dimension mismatch: ({get_dim(q)}) ({get_dim(unit)})" if not jnp.allclose(q.value, values): raise AssertionError(f"Values do not match: {q.value} != {values}") elif isinstance(q, jnp.ndarray): @@ -145,19 +145,19 @@ def test_get_dimensions(): Test various ways of getting/comparing the dimensions of a Array. """ q = 500 * ms - assert get_unit(q) is get_or_create_dimension(q.dim._dims) - assert get_unit(q) is q.dim + assert get_dim(q) is get_or_create_dimension(q.dim._dims) + assert get_dim(q) is q.dim assert q.has_same_unit(3 * second) dims = q.dim assert_equal(dims.get_dimension("time"), 1.0) assert_equal(dims.get_dimension("length"), 0) - assert get_unit(5) is DIMENSIONLESS - assert get_unit(5.0) is DIMENSIONLESS - assert get_unit(np.array(5, dtype=np.int32)) is DIMENSIONLESS - assert get_unit(np.array(5.0)) is DIMENSIONLESS - assert get_unit(np.float32(5.0)) is DIMENSIONLESS - assert get_unit(np.float64(5.0)) is DIMENSIONLESS + assert get_dim(5) is DIMENSIONLESS + assert get_dim(5.0) is DIMENSIONLESS + assert get_dim(np.array(5, dtype=np.int32)) is DIMENSIONLESS + assert get_dim(np.array(5.0)) is DIMENSIONLESS + assert get_dim(np.float32(5.0)) is DIMENSIONLESS + assert get_dim(np.float64(5.0)) is DIMENSIONLESS assert is_scalar_type(5) assert is_scalar_type(5.0) assert is_scalar_type(np.array(5, dtype=np.int32)) diff --git a/brainunit/math/_compat_numpy_array_creation.py b/brainunit/math/_compat_numpy_array_creation.py index 2476adc..be20b91 100644 --- a/brainunit/math/_compat_numpy_array_creation.py +++ b/brainunit/math/_compat_numpy_array_creation.py @@ -21,7 +21,6 @@ import numpy as np from jax import Array -from brainunit._misc import set_module_as from .._base import ( DIMENSIONLESS, Quantity, @@ -29,6 +28,7 @@ fail_for_dimension_mismatch, is_unitless, ) +from .._misc import set_module_as __all__ = [ # array creation @@ -46,17 +46,23 @@ def full( dtype: Optional[Any] = None, ) -> Union[Array, Quantity]: """ - Returns a Quantity of `shape`, filled with `fill_value` if `fill_value` is a Quantity. + Returns a quantity of `shape`, filled with `fill_value` if `fill_value` is a Quantity. else return an array of `shape` filled with `fill_value`. - Args: - shape: sequence of integers, describing the shape of the output array. - fill_value: the value to fill the new array with. - dtype: the type of the output array, or `None`. If not `None`, `fill_value` - will be cast to `dtype`. - - Returns: - Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. + Parameters + ---------- + shape : int or sequence of ints + Shape of the new array, e.g., ``(2, 3)`` or ``2``. + fill_value : scalar, array_like or Quantity + Fill value. + dtype : data-type, optional + The desired data-type for the array The default, None, means ``np.array(fill_value).dtype` + + Returns + ------- + out : quantity or ndarray + Quantity with the given shape if `fill_value` is a Quantity, else an array. + Array of `fill_value` with the given shape, dtype, and order. """ if isinstance(fill_value, Quantity): return Quantity(jnp.full(shape, fill_value.value, dtype=dtype), dim=fill_value.dim) @@ -69,23 +75,31 @@ def eye( M: Optional[int] = None, k: int = 0, dtype: Optional[Any] = None, - unit: Optional[Unit] = None + unit: Optional[Unit] = None, ) -> Union[Array, Quantity]: """ - Returns a Quantity of `shape` and `unit`, representing an identity matrix if `unit` is provided. - else return an identity matrix of `shape`. - - Args: - n: the number of rows (and columns) in the output array. - k: the index of the diagonal: 0 (the default) refers to the main diagonal, - a positive value refers to an upper diagonal, and a negative value to a - lower diagonal. - dtype: the type of the output array, or `None`. If not `None`, elements - will be cast to `dtype`. - unit: the unit of the output array, or `None`. - - Returns: - Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. + Returns a 2-D quantity or array of `shape` and `unit` with ones on the diagonal and zeros elsewhere. + + Parameters + ---------- + N : int + Number of rows in the output. + M : int, optional + Number of columns in the output. If None, defaults to `N`. + k : int, optional + Index of the diagonal: 0 (the default) refers to the main diagonal, + a positive value refers to an upper diagonal, and a negative value + to a lower diagonal. + dtype : data-type, optional + Data-type of the returned array. + unit : Unit, optional + Unit of the returned Quantity. + + Returns + ------- + I : quantity or ndarray of shape (N,M) + An array where all elements are equal to zero, except for the `k`-th + diagonal, whose values are equal to one. """ if unit is not None: assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' @@ -101,17 +115,25 @@ def identity( unit: Optional[Unit] = None ) -> Union[Array, Quantity]: """ - Returns a Quantity of `shape` and `unit`, representing an identity matrix if `unit` is provided. - else return an identity matrix of `shape`. + Return the identity Quantity or array. - Args: - n: the number of rows (and columns) in the output array. - dtype: the type of the output array, or `None`. If not `None`, elements - will be cast to `dtype`. - unit: the unit of the output array, or `None`. + The identity array is a square array with ones on + the main diagonal. - Returns: - Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. + Parameters + ---------- + n : int + Number of rows (and columns) in `n` x `n` output. + dtype : data-type, optional + Data-type of the output. Defaults to ``float``. + unit : Unit, optional + Unit of the returned Quantity. + + Returns + ------- + out : quantity or ndarray + `n` x `n` quantity or array with its main diagonal set to one, + and all other elements 0. """ if unit is not None: assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' @@ -129,21 +151,29 @@ def tri( unit: Optional[Unit] = None ) -> Union[Array, Quantity]: """ - Returns a Quantity of `shape` and `unit`, representing a triangular matrix if `unit` is provided. - else return a triangular matrix of `shape`. - - Args: - n: the number of rows in the output array. - m: the number of columns with default being `n`. - k: the index of the diagonal: 0 (the default) refers to the main diagonal, - a positive value refers to an upper diagonal, and a negative value to a - lower diagonal. - dtype: the type of the output array, or `None`. If not `None`, elements - will be cast to `dtype`. - unit: the unit of the output array, or `None`. - - Returns: - Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. + A quantity or an array with ones at and below the given diagonal and zeros elsewhere. + + Parameters + ---------- + N : int + Number of rows in the array. + M : int, optional + Number of columns in the array. + By default, `M` is taken equal to `N`. + k : int, optional + The sub-diagonal at and below which the array is filled. + `k` = 0 is the main diagonal, while `k` < 0 is below it, + and `k` > 0 is above. The default is 0. + dtype : dtype, optional + Data type of the returned array. The default is float. + unit : Unit, optional + Unit of the returned Quantity. + + Returns + ------- + tri : quantity or ndarray of shape (N, M) + quantity or array with its lower triangle filled with ones and zero elsewhere; + in other words ``T[i,j] == 1`` for ``j <= i + k``, 0 otherwise. """ if unit is not None: assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' @@ -159,17 +189,21 @@ def empty( unit: Optional[Unit] = None ) -> Union[Array, Quantity]: """ - Returns a Quantity of `shape` and `unit`, with uninitialized values if `unit` is provided. - else return an array of `shape` with uninitialized values. + Return a new quantity or array of given shape and type, without initializing entries. - Args: - shape: sequence of integers, describing the shape of the output array. - dtype: the type of the output array, or `None`. If not `None`, elements - will be of type `dtype`. - unit: the unit of the output array, or `None`. + Parameters + ---------- + shape : sequence of int + Shape of the empty quantity or array. + dtype : data-type, optional + Data-type of the output. Defaults to ``float``. + unit : Unit, optional + Unit of the returned Quantity. - Returns: - Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. + Returns + ------- + out : quantity or ndarray + quantity or array of uninitialized (arbitrary) data of the given shape, dtype, and order. """ if unit is not None: assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' @@ -185,17 +219,21 @@ def ones( unit: Optional[Unit] = None ) -> Union[Array, Quantity]: """ - Returns a Quantity of `shape` and `unit`, filled with 1 if `unit` is provided. - else return an array of `shape` filled with 1. + Returns a new quantity or array of given shape and type, filled with ones. - Args: - shape: sequence of integers, describing the shape of the output array. - dtype: the type of the output array, or `None`. If not `None`, elements - will be cast to `dtype`. - unit: the unit of the output array, or `None`. + Parameters + ---------- + shape : sequence of int + Shape of the new quantity or array. + dtype : data-type, optional + The desired data-type for the array. Default is `float`. + unit : Unit, optional + Unit of the returned Quantity. - Returns: - Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. + Returns + ------- + out : quantity or ndarray + Array of ones with the given shape, dtype, and order. """ if unit is not None: assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' @@ -211,17 +249,21 @@ def zeros( unit: Optional[Unit] = None ) -> Union[Array, Quantity]: """ - Returns a Quantity of `shape` and `unit`, filled with 0 if `unit` is provided. - else return an array of `shape` filled with 0. + Returns a new quantity or array of given shape and type, filled with zeros. - Args: - shape: sequence of integers, describing the shape of the output array. - dtype: the type of the output array, or `None`. If not `None`, elements - will be cast to `dtype`. - unit: the unit of the output array, or `None`. + Parameters + ---------- + shape : sequence of int + Shape of the new quantity or array. + dtype : data-type, optional + The desired data-type for the array. Default is `float`. + unit : Unit, optional + Unit of the returned Quantity. - Returns: - Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. + Returns + ------- + out : quantity or ndarray + Array of zeros with the given shape, dtype, and order. """ if unit is not None: assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' @@ -238,24 +280,32 @@ def full_like( shape: Any = None ) -> Union[Quantity, jax.Array]: """ - Return a Quantity if `a` and `fill_value` are Quantities that have the same unit or only `fill_value` is a Quantity. - else return an array of `a` filled with `fill_value`. + Return a new quantity or array with the same shape and type as a given array or quantity, filled with `fill_value`. - Args: - a: array_like, Quantity, shape, or dtype - fill_value: scalar or array_like - dtype: data-type, optional - shape: sequence of ints, optional + Parameters + ---------- + a : quantity or ndarray + The shape and data-type of `a` define these same attributes of the returned quantity or array. + fill_value : quantity or ndarray + Value to fill the new quantity or array with. + dtype : data-type, optional + Overrides the data type of the result. + shape : sequence of int, optional + Overrides the shape of the result. If `shape` is not given, the shape of `a` is used. - Returns: - Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. + Returns + ------- + out : quantity or ndarray + New quantity or array with the same shape and type as `a`, filled with `fill_value`. """ if isinstance(fill_value, Quantity): if isinstance(a, Quantity): fail_for_dimension_mismatch(a, fill_value, error_message="a and fill_value have to have the same units.") - return Quantity(jnp.full_like(a.value, fill_value.value, dtype=dtype, shape=shape), dim=a.dim) + return Quantity(jnp.full_like(a.value, fill_value.value, dtype=dtype, shape=shape), + dim=a.dim) else: - return Quantity(jnp.full_like(a, fill_value.value, dtype=dtype, shape=shape), dim=fill_value.dim) + return Quantity(jnp.full_like(a, fill_value.value, dtype=dtype, shape=shape), + dim=fill_value.dim) else: if isinstance(a, Quantity): return jnp.full_like(a.value, fill_value, dtype=dtype, shape=shape) @@ -264,145 +314,186 @@ def full_like( @set_module_as('brainunit.math') -def diag(a: Union[Quantity, jax.typing.ArrayLike], - k: int = 0, - unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]: +def diag( + v: Union[Quantity, jax.typing.ArrayLike], + k: int = 0, + unit: Optional[Unit] = None +) -> Union[Quantity, jax.Array]: """ Extract a diagonal or construct a diagonal array. - Args: - a: array_like, Quantity - k: int, optional - unit: Unit, optional - - Returns: - Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. - """ - if isinstance(a, Quantity): + Parameters + ---------- + v : quantity or ndarray + If `a` is a 1-D array, `diag` constructs a 2-D array with `v` on the `k`-th diagonal. + If `a` is a 2-D array, `diag` extracts the `k`-th diagonal and returns a 1-D array. + k : int, optional + Diagonal in question. The default is 0. Use `k>0` for diagonals above the main diagonal, and `k<0` for diagonals + below the main diagonal. + unit : Unit, optional + Unit of the returned Quantity. + + Returns + ------- + out : quantity or ndarray + The extracted diagonal or constructed diagonal array. + """ + if isinstance(v, Quantity): if unit is not None: assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' - fail_for_dimension_mismatch(a, unit, error_message="a and unit have to have the same units.") - return Quantity(jnp.diag(a.value, k=k), dim=a.dim) - elif isinstance(a, (jax.Array, np.ndarray)): + fail_for_dimension_mismatch(v, unit, error_message="a and unit have to have the same units.") + return Quantity(jnp.diag(v.value, k=k), dim=v.dim) + elif isinstance(v, (jax.Array, np.ndarray)): if unit is not None: assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' - return jnp.diag(a, k=k) * unit + return jnp.diag(v, k=k) * unit else: - return jnp.diag(a, k=k) + return jnp.diag(v, k=k) else: - return jnp.diag(a, k=k) + return jnp.diag(v, k=k) @set_module_as('brainunit.math') -def tril(a: Union[Quantity, jax.typing.ArrayLike], - k: int = 0, - unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]: +def tril( + m: Union[Quantity, jax.typing.ArrayLike], + k: int = 0, + unit: Optional[Unit] = None +) -> Union[Quantity, jax.Array]: """ Lower triangle of an array. - Args: - a: array_like, Quantity - k: int, optional - unit: Unit, optional + Return a copy of a matrix with the elements above the `k`-th diagonal zeroed. + For quantities or arrays with ``ndim`` exceeding 2, `tril` will apply to the final two axes. + + Parameters + ---------- + m : quantity or ndarray + Input array. + k : int, optional + Diagonal above which to zero elements. `k = 0` is the main diagonal, `k < 0` is below it, and `k > 0` is above. + unit : Unit, optional + Unit of the returned Quantity. - Returns: - Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. + Returns + ------- + out : quantity or ndarray + Lower triangle of `m`, of the same shape and data-type as `m`. """ - if isinstance(a, Quantity): + if isinstance(m, Quantity): if unit is not None: assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' - fail_for_dimension_mismatch(a, unit, error_message="a and unit have to have the same units.") - return Quantity(jnp.tril(a.value, k=k), dim=a.dim) - elif isinstance(a, (jax.Array, np.ndarray)): + fail_for_dimension_mismatch(m, unit, error_message="a and unit have to have the same units.") + return Quantity(jnp.tril(m.value, k=k), dim=m.dim) + elif isinstance(m, (jax.Array, np.ndarray)): if unit is not None: assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' - return jnp.tril(a, k=k) * unit + return jnp.tril(m, k=k) * unit else: - return jnp.tril(a, k=k) + return jnp.tril(m, k=k) else: - return jnp.tril(a, k=k) + return jnp.tril(m, k=k) @set_module_as('brainunit.math') -def triu(a: Union[Quantity, jax.typing.ArrayLike], - k: int = 0, - unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]: +def triu( + m: Union[Quantity, jax.typing.ArrayLike], + k: int = 0, + unit: Optional[Unit] = None +) -> Union[Quantity, jax.Array]: """ - Upper triangle of an array. + Upper triangle of a quantity or an array. + + Return a copy of an array with the elements below the `k`-th diagonal + zeroed. For arrays with ``ndim`` exceeding 2, `triu` will apply to the + final two axes. - Args: - a: array_like, Quantity - k: int, optional - unit: Unit, optional + Please refer to the documentation for `tril` for further details. - Returns: - Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. + See Also + -------- + tril : lower triangle of an array """ - if isinstance(a, Quantity): + if isinstance(m, Quantity): if unit is not None: assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' - fail_for_dimension_mismatch(a, unit, error_message="a and unit have to have the same units.") - return Quantity(jnp.triu(a.value, k=k), dim=a.dim) - elif isinstance(a, (jax.Array, np.ndarray)): + fail_for_dimension_mismatch(m, unit, error_message="a and unit have to have the same units.") + return Quantity(jnp.triu(m.value, k=k), dim=m.dim) + elif isinstance(m, (jax.Array, np.ndarray)): if unit is not None: assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' - return jnp.triu(a, k=k) * unit + return jnp.triu(m, k=k) * unit else: - return jnp.triu(a, k=k) + return jnp.triu(m, k=k) else: - return jnp.triu(a, k=k) + return jnp.triu(m, k=k) @set_module_as('brainunit.math') -def empty_like(a: Union[Quantity, jax.typing.ArrayLike], - dtype: Optional[jax.typing.DTypeLike] = None, - shape: Any = None, - unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]: - """ - Return a Quantity of `a` and `unit`, with uninitialized values if `unit` is provided. - else return an array of `a` with uninitialized values. - - Args: - a: array_like, Quantity, shape, or dtype - dtype: data-type, optional - shape: sequence of ints, optional - unit: Unit, optional - - Returns: - Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. +def empty_like( + prototype: Union[Quantity, jax.typing.ArrayLike], + dtype: Optional[jax.typing.DTypeLike] = None, + shape: Any = None, + unit: Optional[Unit] = None +) -> Union[Quantity, jax.Array]: """ - if isinstance(a, Quantity): + Return a new quantity or array with the same shape and type as a given array. + + Parameters + ---------- + prototype : quantity or ndarray + The shape and data-type of `prototype` define these same attributes of the returned array. + dtype : data-type, optional + Overrides the data type of the result. + shape : int or tuple of ints, optional + Overrides the shape of the result. If not given, `prototype.shape` is used. + unit : Unit, optional + Unit of the returned Quantity. + + Returns + ------- + out : quantity or ndarray + Array of uninitialized (arbitrary) data with the same shape and type as `prototype`. + """ + if isinstance(prototype, Quantity): if unit is not None: assert isinstance(unit, Unit) - fail_for_dimension_mismatch(a, unit, error_message="a and unit have to have the same units.") - return Quantity(jnp.empty_like(a.value, dtype=dtype, shape=shape), dim=a.dim) - elif isinstance(a, (jax.Array, np.ndarray)): + fail_for_dimension_mismatch(prototype, unit, error_message="a and unit have to have the same units.") + return Quantity(jnp.empty_like(prototype.value, dtype=dtype), dim=prototype.dim) + elif isinstance(prototype, (jax.Array, np.ndarray)): if unit is not None: assert isinstance(unit, Unit) - return jnp.empty_like(a, dtype=dtype, shape=shape) * unit + return jnp.empty_like(prototype, dtype=dtype, shape=shape) * unit else: - return jnp.empty_like(a, dtype=dtype, shape=shape) + return jnp.empty_like(prototype, dtype=dtype, shape=shape) else: - return jnp.empty_like(a, dtype=dtype, shape=shape) + return jnp.empty_like(prototype, dtype=dtype, shape=shape) @set_module_as('brainunit.math') -def ones_like(a: Union[Quantity, jax.typing.ArrayLike], - dtype: Optional[jax.typing.DTypeLike] = None, - shape: Any = None, - unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]: +def ones_like( + a: Union[Quantity, jax.typing.ArrayLike], + dtype: Optional[jax.typing.DTypeLike] = None, + shape: Any = None, + unit: Optional[Unit] = None +) -> Union[Quantity, jax.Array]: """ - Return a Quantity of `a` and `unit`, filled with 1 if `unit` is provided. - else return an array of `a` filled with 1. + Return a quantity or an array of ones with the same shape and type as a given array. - Args: - a: array_like, Quantity, shape, or dtype - dtype: data-type, optional - shape: sequence of ints, optional - unit: Unit, optional + Parameters + ---------- + a : quantity or ndarray + The shape and data-type of `a` define these same attributes of the returned array. + dtype : data-type, optional + Overrides the data type of the result. + shape : int or tuple of ints, optional + Overrides the shape of the result. If not given, `a.shape` is used. + unit : Unit, optional + Unit of the returned Quantity. - Returns: - Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. + Returns + ------- + out : quantity or ndarray + Array of ones with the same shape and type as `a`. """ if isinstance(a, Quantity): if unit is not None: @@ -420,22 +511,30 @@ def ones_like(a: Union[Quantity, jax.typing.ArrayLike], @set_module_as('brainunit.math') -def zeros_like(a: Union[Quantity, jax.typing.ArrayLike], - dtype: Optional[jax.typing.DTypeLike] = None, - shape: Any = None, - unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]: +def zeros_like( + a: Union[Quantity, jax.typing.ArrayLike], + dtype: Optional[jax.typing.DTypeLike] = None, + shape: Any = None, + unit: Optional[Unit] = None +) -> Union[Quantity, jax.Array]: """ - Return a Quantity of `a` and `unit`, filled with 0 if `unit` is provided. - else return an array of `a` filled with 0. + Return a quantity or an array of zeros with the same shape and type as a given array. - Args: - a: array_like, Quantity, shape, or dtype - dtype: data-type, optional - shape: sequence of ints, optional - unit: Unit, optional + Parameters + ---------- + a : quantity or ndarray + The shape and data-type of `a` define these same attributes of the returned array. + dtype : data-type, optional + Overrides the data type of the result. + shape : int or tuple of ints, optional + Overrides the shape of the result. If not given, `a.shape` is used. + unit : Unit, optional + Unit of the returned Quantity. - Returns: - Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. + Returns + ------- + out : quantity or ndarray + Array of zeros with the same shape and type as `a`. """ if isinstance(a, Quantity): if unit is not None: @@ -466,14 +565,22 @@ def asarray( (If they have same dimension but different magnitude, the input will be converted to the provided unit.) If unit is not provided, the input will be converted to an array. - Args: - a: array_like, Quantity, or Sequence[Quantity] - dtype: data-type, optional - order: {'C', 'F', 'A', 'K'}, optional - unit: Unit, optional - - Returns: - Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. + Parameters + ---------- + a : quantity, ndarray, list[Quantity], list[ndarray] + Input data, in any form that can be converted to an array. + dtype : data-type, optional + By default, the data-type is inferred from the input data. + order : {'C', 'F', 'A', 'K'}, optional + Whether to use row-major (C-style) or column-major (Fortran-style) memory representation. + Defaults to 'K', which means that the memory layout is used in the order the array elements are stored in memory. + unit : Unit, optional + Unit of the returned Quantity. + + Returns + ------- + out : quantity or array + Array interpretation of `a`. No copy is made if the input is already an array. """ if unit is not None: assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' @@ -515,48 +622,55 @@ def asarray( @set_module_as('brainunit.math') -def arange(*args, **kwargs): +def arange( + start: Union[Quantity, jax.typing.ArrayLike] = None, + stop: Optional[Union[Quantity, jax.typing.ArrayLike]] = None, + step: Optional[Union[Quantity, jax.typing.ArrayLike]] = None, + dtype: Optional[jax.typing.DTypeLike] = None +) -> Union[Quantity, jax.Array]: """ - Return a Quantity of `arange` and `unit`, with uninitialized values if `unit` is provided. + Return evenly spaced values within a given interval. - Args: - start: number, Quantity, optional - stop: number, Quantity, optional - step: number, optional - dtype: dtype, optional - unit: Unit, optional + Parameters + ---------- + start : Quantity or array, optional + Start of the interval. The interval includes this value. The default start value is 0. + stop : Quantity or array + End of the interval. The interval does not include this value, except in some cases where `step` is not an integer + and floating point round-off affects the length of `out`. + step : Quantity or array, optional + Spacing between values. For any output `out`, this is the distance between two adjacent values, `out[i+1] - out[i]`. + The default step size is 1. + dtype : data-type, optional + The type of the output array. If `dtype` is not given, infer the data type from the other input arguments. - Returns: - Union[jax.Array, Quantity]: Quantity if start and stop are Quantities that have the same unit, else an array. + Returns + ------- + out : quantity or array + Array of evenly spaced values. """ - # arange has a bit of a complicated argument structure unfortunately - # we leave the actual checking of the number of arguments to numpy, though - # default values - start = kwargs.pop("start", 0) - step = kwargs.pop("step", 1) - stop = kwargs.pop("stop", None) - if len(args) == 1: - if stop is not None: - raise TypeError("Duplicate definition of 'stop'") - stop = args[0] - elif len(args) == 2: - if start != 0: - raise TypeError("Duplicate definition of 'start'") - if stop is not None: - raise TypeError("Duplicate definition of 'stop'") - start, stop = args - elif len(args) == 3: - if start != 0: - raise TypeError("Duplicate definition of 'start'") + arg_len = len([x for x in [start, stop, step] if x is not None]) + + if arg_len == 1: if stop is not None: raise TypeError("Duplicate definition of 'stop'") - if step != 1: - raise TypeError("Duplicate definition of 'step'") - start, stop, step = args - elif len(args) > 3: + stop = start + start = 0 + elif arg_len == 2: + if start is not None and stop is None: + stop = start + start = 0 + + elif arg_len > 3: raise TypeError("Need between 1 and 3 non-keyword arguments") + # default values + if start is None: + start = 0 + if step is None: + step = 1 + if stop is None: raise TypeError("Missing stop argument.") if stop is not None and not is_unitless(stop): @@ -565,32 +679,27 @@ def arange(*args, **kwargs): fail_for_dimension_mismatch( start, stop, - error_message=( - "Start value {start} and stop value {stop} have to have the same units." - ), + error_message="Start value {start} and stop value {stop} have to have the same units.", start=start, stop=stop, ) fail_for_dimension_mismatch( stop, step, - error_message=( - "Stop value {stop} and step value {step} have to have the same units." - ), + error_message="Stop value {stop} and step value {step} have to have the same units.", stop=stop, step=step, ) + unit = getattr(stop, "dim", DIMENSIONLESS) - # start is a position-only argument in numpy 2.0 - # https://numpy.org/devdocs/release/2.0.0-notes.html#arange-s-start-argument-is-positional-only - # TODO: check whether this is still the case in the final release + if start == 0: return Quantity( jnp.arange( start=start.value if isinstance(start, Quantity) else jnp.asarray(start), stop=stop.value if isinstance(stop, Quantity) else jnp.asarray(stop), step=step.value if isinstance(step, Quantity) else jnp.asarray(step), - **kwargs, + dtype=dtype, ), dim=unit, ) @@ -600,7 +709,7 @@ def arange(*args, **kwargs): start.value if isinstance(start, Quantity) else jnp.asarray(start), stop=stop.value if isinstance(stop, Quantity) else jnp.asarray(stop), step=step.value if isinstance(step, Quantity) else jnp.asarray(step), - **kwargs, + dtype=dtype, ), dim=unit, ) @@ -616,18 +725,30 @@ def linspace( dtype: Optional[jax.typing.DTypeLike] = None ) -> Union[Quantity, jax.Array]: """ - Return a Quantity of `linspace` and `unit`, with uninitialized values if `unit` is provided. - - Args: - start: number, Quantity - stop: number, Quantity - num: int, optional - endpoint: bool, optional - retstep: bool, optional - dtype: dtype, optional - - Returns: - Union[jax.Array, Quantity]: Quantity if start and stop are Quantities that have the same unit, else an array. + Return evenly spaced numbers over a specified interval. + + Returns `num` evenly spaced samples, calculated over the interval [`start`, `stop`]. + The endpoint of the interval can optionally be excluded. + + Parameters + ---------- + start : Quantity or array + The starting value of the sequence. + stop : Quantity or array + The end value of the sequence. + num : int, optional + Number of samples to generate. Default is 50. + endpoint : bool, optional + If True, `stop` is the last sample. Otherwise, it is not included. Default is True. + retstep : bool, optional + If True, return (`samples`, `step`), where `step` is the spacing between samples. + dtype : data-type, optional + The type of the output array. If `dtype` is not given, infer the data type from the other input arguments. + + Returns + ------- + samples : quantity or array + There are `num` equally spaced samples in the closed interval [`start`, `stop`] or the half-open interval [`start`, `stop`). """ fail_for_dimension_mismatch( start, @@ -652,18 +773,29 @@ def logspace(start: Union[Quantity, jax.typing.ArrayLike], base: Optional[float] = 10.0, dtype: Optional[jax.typing.DTypeLike] = None): """ - Return a Quantity of `logspace` and `unit`, with uninitialized values if `unit` is provided. - - Args: - start: number, Quantity - stop: number, Quantity - num: int, optional - endpoint: bool, optional - base: float, optional - dtype: dtype, optional - - Returns: - Union[jax.Array, Quantity]: Quantity if start and stop are Quantities that have the same unit, else an array. + Return numbers spaced evenly on a log scale. + + In linear space, the sequence starts at `base ** start` (`base` to the power of `start`) and ends with `base ** stop` in `num` steps. + + Parameters + ---------- + start : Quantity or array + The starting value of the sequence. + stop : Quantity or array + The end value of the sequence. + num : int, optional + Number of samples to generate. Default is 50. + endpoint : bool, optional + If True, `stop` is the last sample. Otherwise, it is not included. Default is True. + base : float, optional + The base of the log space. The step size between the elements in `ln(samples)` is `base`. + dtype : data-type, optional + The type of the output array. If `dtype` is not given, infer the data type from the other input arguments. + + Returns + ------- + samples : quantity or array + There are `num` equally spaced samples in the closed interval [`start`, `stop`] or the half-open interval [`start`, `stop`). """ fail_for_dimension_mismatch( start, @@ -686,16 +818,27 @@ def fill_diagonal(a: Union[Quantity, jax.typing.ArrayLike], wrap: Optional[bool] = False, inplace: Optional[bool] = False) -> Union[Quantity, jax.Array]: """ - Fill the main diagonal of the given array of `a` with `val`. + Fill the main diagonal of the given array of any dimensionality. + + For an array `a` with `a.ndim >= 2`, the diagonal is the list of locations with indices `a[i, i, ..., i]` + all identical. - Args: - a: array_like, Quantity - val: scalar, Quantity - wrap: bool, optional - unit: Unit, optional + Parameters + ---------- + a : Quantity or array + Array in which to fill the diagonal. + val : Quantity or array + Value to be written on the diagonal. Its type must be compatible with that of the array a. + wrap : bool, optional + For tall matrices in NumPy version 1.6.2 and earlier, the matrix is considered "tall" if `a.shape[0] > a.shape[1]`. + If `wrap` is True, the diagonal is "wrapped" after `a.shape[1]` and continues in the first column. + inplace : bool, optional + If True, the diagonal is filled in-place. Default is False. - Returns: - Union[jax.Array, Quantity]: Quantity if `a` and `val` are Quantities that have the same unit, else an array. + Returns + ------- + out : Quantity or array + The input array with the diagonal filled. """ if isinstance(val, Quantity): if isinstance(a, Quantity): @@ -711,19 +854,29 @@ def fill_diagonal(a: Union[Quantity, jax.typing.ArrayLike], @set_module_as('brainunit.math') -def array_split(ary: Union[Quantity, jax.typing.ArrayLike], - indices_or_sections: Union[int, jax.typing.ArrayLike], - axis: Optional[int] = 0) -> Union[list[Quantity], list[Array]]: +def array_split( + ary: Union[Quantity, jax.typing.ArrayLike], + indices_or_sections: Union[int, jax.typing.ArrayLike], + axis: Optional[int] = 0 +) -> Union[list[Quantity], list[Array]]: """ Split an array into multiple sub-arrays. - Args: - ary: array_like, Quantity - indices_or_sections: int, array_like - axis: int, optional - - Returns: - Union[jax.Array, Quantity]: Quantity if `ary` is a Quantity, else an array. + Parameters + ---------- + ary : Quantity or array + Array to be divided into sub-arrays. + indices_or_sections : int or 1-D array + If `indices_or_sections` is an integer, `ary` is divided into `indices_or_sections` sub-arrays along `axis`. + If such a split is not possible, an error is raised. + If `indices_or_sections` is a 1-D array of sorted integers, the entries indicate where along `axis` the array is split. + axis : int, optional + The axis along which to split, default is 0. + + Returns + ------- + sub-arrays : list of Quantity or list of array + A list of sub-arrays. """ if isinstance(ary, Quantity): return [Quantity(x, dim=ary.dim) for x in jnp.array_split(ary.value, indices_or_sections, axis)] @@ -734,26 +887,41 @@ def array_split(ary: Union[Quantity, jax.typing.ArrayLike], @set_module_as('brainunit.math') -def meshgrid(*xi: Union[Quantity, jax.typing.ArrayLike], - copy: Optional[bool] = True, - sparse: Optional[bool] = False, - indexing: Optional[str] = 'xy'): +def meshgrid( + *xi: Union[Quantity, jax.typing.ArrayLike], + copy: Optional[bool] = True, + sparse: Optional[bool] = False, + indexing: Optional[str] = 'xy' +) -> Union[list[Quantity], list[Array]]: """ Return coordinate matrices from coordinate vectors. - Args: - xi: array_like, Quantity - copy: bool, optional - sparse: bool, optional - indexing: str, optional - - Returns: - Union[jax.Array, Quantity]: Quantity if `xi` are Quantities that have the same unit, else an array. + Make N-D coordinate arrays for vectorized evaluations of N-D scalar/vector fields over N-D grids, + given one-dimensional coordinate arrays x1, x2,..., xn. + + Parameters + ---------- + xi : Quantity or array + 1-D arrays representing the coordinates of a grid. + copy : bool, optional + If True (default), the returned arrays are copies. If False, the view is returned. + sparse : bool, optional + If True, return a sparse grid (meshgrid) instead of a dense grid. + indexing : {'xy', 'ij'}, optional + Cartesian ('xy', default) or matrix ('ij') indexing of output. + + Returns + ------- + X1, X2,..., XN : Quantity or array + For vectors x1, x2,..., 'xn' with lengths Ni=len(xi), return (N1, N2, N3,..., Nn) shaped arrays if indexing='ij' + or (N2, N1, N3,..., Nn) shaped arrays if indexing='xy' with the elements of xi repeated to fill the matrix along + the first dimension for x1, the second for x2 and so on. """ from builtins import all as origin_all if origin_all(isinstance(x, Quantity) for x in xi): fail_for_dimension_mismatch(*xi) - return Quantity(jnp.meshgrid(*[x.value for x in xi], copy=copy, sparse=sparse, indexing=indexing), dim=xi[0].dim) + return [Quantity(x, dim=xi[0].dim) for x in + jnp.meshgrid(*[x.value for x in xi], copy=copy, sparse=sparse, indexing=indexing)] elif origin_all(isinstance(x, (jax.Array, np.ndarray)) for x in xi): return jnp.meshgrid(*xi, copy=copy, sparse=sparse, indexing=indexing) else: @@ -761,19 +929,31 @@ def meshgrid(*xi: Union[Quantity, jax.typing.ArrayLike], @set_module_as('brainunit.math') -def vander(x: Union[Quantity, jax.typing.ArrayLike], - N: Optional[bool] = None, - increasing: Optional[bool] = False) -> Union[Quantity, jax.Array]: +def vander( + x: Union[Quantity, jax.typing.ArrayLike], + N: Optional[bool] = None, + increasing: Optional[bool] = False +) -> Union[Quantity, jax.Array]: """ Generate a Vandermonde matrix. - Args: - x: array_like, Quantity - N: int, optional - increasing: bool, optional - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + The Vandermonde matrix is a matrix with the terms of a geometric progression in each row. + The geometric progression is defined by the vector `x` and the number of columns `N`. + + Parameters + ---------- + x : Quantity or array + 1-D input array. + N : int, optional + Number of columns in the output. If `N` is not specified, a square array is returned (N = len(x)). + increasing : bool, optional + Order of the powers of the columns. If True, the powers increase from left to right, if False (the default), + they are reversed. + + Returns + ------- + out : Quantity or array + Vandermonde matrix. If `increasing` is False, the first column is `x^(N-1)`, the second `x^(N-2)` and so forth. """ if isinstance(x, Quantity): return Quantity(jnp.vander(x.value, N=N, increasing=increasing), dim=x.dim) diff --git a/brainunit/math/_compat_numpy_array_manipulation.py b/brainunit/math/_compat_numpy_array_manipulation.py index 6408b15..2165fd3 100644 --- a/brainunit/math/_compat_numpy_array_manipulation.py +++ b/brainunit/math/_compat_numpy_array_manipulation.py @@ -12,16 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +from __future__ import annotations from collections.abc import Sequence -from typing import (Union, Optional, Tuple, List) +from typing import (Union, Optional, Tuple, List, Any) import jax import jax.numpy as jnp from jax import Array +from jax.tree_util import tree_map -from brainunit._misc import set_module_as -from .._base import Quantity +from .._base import Quantity, fail_for_dimension_mismatch +from .._misc import set_module_as __all__ = [ # array manipulation @@ -35,9 +37,9 @@ 'diagflat', 'diagonal', 'choose', 'ravel', ] + # array manipulation # ------------------ -from jax.tree_util import tree_map def _as_jax_array_(obj): @@ -48,7 +50,7 @@ def _is_leaf(a): return isinstance(a, Quantity) -def func_array_manipulation(fun, *args, return_quantity=True, **kwargs) -> Union[list[Quantity], Quantity, jax.Array]: +def func_array_manipulation(fun, *args, return_quantity=True, **kwargs) -> Any: unit = None if isinstance(args[0], Quantity): unit = args[0].dim @@ -110,19 +112,37 @@ def reshape( order: str = 'C' ) -> Union[Array, Quantity]: """ - Return a reshaped copy of an array or a Quantity. - - Args: - a: input array or Quantity to reshape - shape: integer or sequence of integers giving the new shape, which must match the - size of the input array. If any single dimension is given size ``-1``, it will be - replaced with a value such that the output has the correct size. - order: ``'F'`` or ``'C'``, specifies whether the reshape should apply column-major - (fortran-style, ``"F"``) or row-major (C-style, ``"C"``) order; default is ``"C"``. - brainunit does not support ``order="A"``. - - Returns: - reshaped copy of input array with the specified shape. + Gives a new shape to a quantity or an array without changing its data. + + Parameters + ---------- + a : array_like, Quantity + Array to be reshaped. + shape : int or tuple of ints + The new shape should be compatible with the original shape. If + an integer, then the result will be a 1-D array of that length. + One shape dimension can be -1. In this case, the value is + inferred from the length of the array and remaining dimensions. + order : {'C', 'F', 'A'}, optional + Read the elements of `a` using this index order, and place the + elements into the reshaped array using this index order. 'C' + means to read / write the elements using C-like index order, + with the last axis index changing fastest, back to the first + axis index changing slowest. 'F' means to read / write the + elements using Fortran-like index order, with the first index + changing fastest, and the last index changing slowest. Note that + the 'C' and 'F' options take no account of the memory layout of + the underlying array, and only refer to the order of indexing. + 'A' means to read / write the elements in Fortran-like index + order if `a` is Fortran *contiguous* in memory, C-like order + otherwise. + + Returns + ------- + reshaped_array : ndarray, Quantity + This will be a new view object if possible; otherwise, it will + be a copy. Note there is no guarantee of the *memory layout* (C- or + Fortran- contiguous) of the returned array. """ return func_array_manipulation(jnp.reshape, a, shape, order=order) @@ -134,15 +154,23 @@ def moveaxis( destination: Union[int, Tuple[int, ...]] ) -> Union[Array, Quantity]: """ - Moves axes of an array to new positions. Other axes remain in their original order. + Moves axes of a quantity or an array to new positions. + Other axes remain in their original order. - Args: - a: array_like, Quantity - source: int or sequence of ints - destination: int or sequence of ints + Parameters + ---------- + a : array_like, Quantity + The array whose axes should be reordered. + source : int or sequence of int + Original positions of the axes to move. These must be unique. + destination : int or sequence of int + Destination positions for each of the original axes. These must also be + unique. - Returns: - Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array + Returns + ------- + result : ndarray, Quantity + Array with moved axes. This array is a view of the input array. """ return func_array_manipulation(jnp.moveaxis, a, source, destination) @@ -153,32 +181,47 @@ def transpose( axes: Optional[Union[int, Tuple[int, ...]]] = None ) -> Union[Array, Quantity]: """ - Returns a view of the array with axes transposed. + Permute the dimensions of a quantity or an array. - Args: - a: array_like, Quantity - axes: tuple or list of ints, optional + Parameters + ---------- + a : array_like, Quantity + Input array. + axes : list of ints, optional + By default, reverse the dimensions, otherwise permute the axes + according to the values given. - Returns: - Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array + Returns + ------- + p : ndarray, Quantity + `a` with its axes permuted. A view is returned whenever + possible. """ return func_array_manipulation(jnp.transpose, a, axes) @set_module_as('brainunit.math') def swapaxes( - a: Union[Array, Quantity], axis1: int, axis2: int + a: Union[Array, Quantity], + axis1: int, + axis2: int ) -> Union[Array, Quantity]: """ - Interchanges two axes of an array. + Interchange two axes of a quantity or an array. - Args: - a: array_like, Quantity - axis1: int - axis2: int + Parameters + ---------- + a : array_like, Quantity + Input array. + axis1 : int + First axis. + axis2 : int + Second axis. - Returns: - Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array + Returns + ------- + a_swapped : ndarray, Quantity + a new array where the axes are swapped. """ return func_array_manipulation(jnp.swapaxes, a, axis1, axis2) @@ -186,53 +229,86 @@ def swapaxes( @set_module_as('brainunit.math') def concatenate( arrays: Union[Sequence[Array], Sequence[Quantity]], - axis: Optional[int] = None + axis: Optional[int] = None, + dtype: Optional[jax.typing.DTypeLike] = None ) -> Union[Array, Quantity]: """ - Join a sequence of arrays along an existing axis. + Join a sequence of quantities or arrays along an existing axis. - Args: - arrays: sequence of array_like, Quantity - axis: int, optional + Parameters + ---------- + arrays : sequence of array_like, Quantity + The arrays must have the same shape, except in the dimension corresponding + to `axis` (the first, by default). + axis : int, optional + The axis along which the arrays will be joined. Default is 0. + dtype : dtype, optional + If provided, the concatenation will be done using this dtype. Otherwise, the + array with the highest precision will be used. - Returns: - Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array + Returns + ------- + res : ndarray, Quantity + The concatenated array. The type of the array is the same as that of the + first array passed in. """ - return func_array_manipulation(jnp.concatenate, arrays, axis=axis) + return func_array_manipulation(jnp.concatenate, arrays, axis=axis, dtype=dtype) @set_module_as('brainunit.math') def stack( arrays: Union[Sequence[Array], Sequence[Quantity]], - axis: int = 0 + axis: int = 0, + out: Optional[Union[Quantity, jax.typing.ArrayLike]] = None, + dtype: Optional[jax.typing.DTypeLike] = None ) -> Union[Array, Quantity]: """ - Join a sequence of arrays along a new axis. + Join a sequence of quantities or arrays along a new axis. - Args: - arrays: sequence of array_like, Quantity - axis: int + Parameters + ---------- + arrays : sequence of array_like, Quantity + The arrays must have the same shape. + axis : int, optional + The axis in the result array along which the input arrays are stacked. + out : Quantity, jax.typing.ArrayLike, optional + If provided, the destination to place the result. The shape must be + correct, matching that of what stack would have returned if no out + argument were specified. + dtype : dtype, optional + If provided, the concatenation will be done using this dtype. Otherwise, the + array with the highest precision will be used. - Returns: - Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array + Returns + ------- + res : ndarray, Quantity + The stacked array has one more dimension than the input arrays. """ - return func_array_manipulation(jnp.stack, arrays, axis=axis) + return func_array_manipulation(jnp.stack, arrays, axis=axis, out=out, dtype=dtype) @set_module_as('brainunit.math') def vstack( - arrays: Union[Sequence[Array], Sequence[Quantity]] + tup: Union[Sequence[Array], Sequence[Quantity]], + dtype: Optional[jax.typing.DTypeLike] = None ) -> Union[Array, Quantity]: """ - Stack arrays in sequence vertically (row wise). + Stack quantities or arrays in sequence vertically (row wise). - Args: - arrays: sequence of array_like, Quantity + Parameters + ---------- + tup : sequence of array_like, Quantity + The arrays must have the same shape along all but the first axis. + dtype : dtype, optional + If provided, the concatenation will be done using this dtype. Otherwise, the + array with the highest precision will be used. - Returns: - Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.numpy.Array + Returns + ------- + res : ndarray, Quantity + The array formed by stacking the given arrays. """ - return func_array_manipulation(jnp.vstack, arrays) + return func_array_manipulation(jnp.vstack, tup, dtype=dtype) row_stack = vstack @@ -240,50 +316,73 @@ def vstack( @set_module_as('brainunit.math') def hstack( - arrays: Union[Sequence[Array], Sequence[Quantity]] + arrays: Union[Sequence[Array], Sequence[Quantity]], + dtype: Optional[jax.typing.DTypeLike] = None ) -> Union[Array, Quantity]: """ - Stack arrays in sequence horizontally (column wise). + Stack quantities arrays in sequence horizontally (column wise). - Args: - arrays: sequence of array_like, Quantity + Parameters + ---------- + arrays : sequence of array_like, Quantity + The arrays must have the same shape along all but the second axis. + dtype : dtype, optional + If provided, the concatenation will be done using this dtype. Otherwise, the + array with the highest precision will be used. - Returns: - Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array + Returns + ------- + res : ndarray, Quantity + The array formed by stacking the given arrays. """ - return func_array_manipulation(jnp.hstack, arrays) + return func_array_manipulation(jnp.hstack, arrays, dtype=dtype) @set_module_as('brainunit.math') def dstack( - arrays: Union[Sequence[Array], Sequence[Quantity]] + arrays: Union[Sequence[Array], Sequence[Quantity]], + dtype: Optional[jax.typing.DTypeLike] = None ) -> Union[Array, Quantity]: """ - Stack arrays in sequence depth wise (along third axis). + Stack quantities or arrays in sequence depth wise (along third axis). - Args: - arrays: sequence of array_like, Quantity + Parameters + ---------- + arrays : sequence of array_like, Quantity + The arrays must have the same shape along all but the third axis. + dtype : dtype, optional + If provided, the concatenation will be done using this dtype. Otherwise, the + array with the highest precision will be used. - Returns: - Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array + Returns + ------- + res : ndarray, Quantity + The array formed by stacking the given arrays. """ - return func_array_manipulation(jnp.dstack, arrays) + return func_array_manipulation(jnp.dstack, arrays, dtype=dtype) @set_module_as('brainunit.math') def column_stack( - arrays: Union[Sequence[Array], Sequence[Quantity]] + tup: Union[Sequence[Array], Sequence[Quantity]] ) -> Union[Array, Quantity]: """ Stack 1-D arrays as columns into a 2-D array. - Args: - arrays: sequence of 1-D or 2-D array_like, Quantity + Take a sequence of 1-D arrays and stack them as columns to make a single + 2-D array. 2-D arrays are stacked as-is, just like with hstack. + + Parameters + ---------- + tup : sequence of 1-D array_like, Quantity + 1-D arrays to stack as columns. - Returns: - Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array + Returns + ------- + res : ndarray, Quantity + The array formed by stacking the given arrays. """ - return func_array_manipulation(jnp.column_stack, arrays) + return func_array_manipulation(jnp.column_stack, tup) @set_module_as('brainunit.math') @@ -293,15 +392,28 @@ def split( axis: int = 0 ) -> Union[List[Array], List[Quantity]]: """ - Split an array into multiple sub-arrays. - - Args: - a: array_like, Quantity - indices_or_sections: int or 1-D array - axis: int, optional - - Returns: - Union[jax.Array, Quantity] a list of Quantity if a is a Quantity, otherwise a list of jax.Array + Split quantity or array into a list of multiple sub-arrays. + + Parameters + ---------- + a : array_like, Quantity + Array to be divided into sub-arrays. + indices_or_sections : int or 1-D array + If `indices_or_sections` is an integer, N, the array will be divided into + N equal arrays along `axis`. If such a split is not possible, an error is + raised. If `indices_or_sections` is a 1-D array of sorted integers, the + entries indicate where along `axis` the array is split. For example, + `[2, 3]` would, for `axis=0`, result in + - `a[:2]` + - `a[2:3]` + - `a[3:]` + axis : int, optional + The axis along which to split, default is 0. + + Returns + ------- + res : list of ndarrays, Quantity + A list of sub-arrays. """ return func_array_manipulation(jnp.split, a, indices_or_sections, axis=axis) @@ -312,14 +424,22 @@ def dsplit( indices_or_sections: Union[int, Sequence[int]] ) -> Union[List[Array], List[Quantity]]: """ - Split array along third axis (depth). + Split a quantity or an array into multiple sub-arrays along the 3rd axis (depth). - Args: - a: array_like, Quantity - indices_or_sections: int or 1-D array + Parameters + ---------- + a : array_like, Quantity + Array to be divided into sub-arrays. + indices_or_sections : int or 1-D array + If `indices_or_sections` is an integer, N, the array will be divided into + N equal arrays along the third axis (depth). If such a split is not possible, + an error is raised. If `indices_or_sections` is a 1-D array of sorted integers, + the entries indicate where along the third axis the array is split. - Returns: - Union[jax.Array, Quantity] a list of Quantity if a is a Quantity, otherwise a list of jax.Array + Returns + ------- + res : list of ndarrays, Quantity + A list of sub-arrays. """ return func_array_manipulation(jnp.dsplit, a, indices_or_sections) @@ -330,14 +450,22 @@ def hsplit( indices_or_sections: Union[int, Sequence[int]] ) -> Union[List[Array], List[Quantity]]: """ - Split an array into multiple sub-arrays horizontally (column-wise). + Split a quantity or an array into multiple sub-arrays horizontally (column-wise). - Args: - a: array_like, Quantity - indices_or_sections: int or 1-D array + Parameters + ---------- + a : array_like, Quantity + Array to be divided into sub-arrays. + indices_or_sections : int or 1-D array + If `indices_or_sections` is an integer, N, the array will be divided into + N equal arrays along the second axis. If such a split is not possible, an + error is raised. If `indices_or_sections` is a 1-D array of sorted integers, + the entries indicate where along the second axis the array is split. - Returns: - Union[jax.Array, Quantity] a list of Quantity if a is a Quantity, otherwise a list of jax.Array + Returns + ------- + res : list of ndarrays, Quantity + A list of sub-arrays. """ return func_array_manipulation(jnp.hsplit, a, indices_or_sections) @@ -348,14 +476,22 @@ def vsplit( indices_or_sections: Union[int, Sequence[int]] ) -> Union[List[Array], List[Quantity]]: """ - Split an array into multiple sub-arrays vertically (row-wise). + Split a quantity or an array into multiple sub-arrays vertically (row-wise). - Args: - a: array_like, Quantity - indices_or_sections: int or 1-D array + Parameters + ---------- + a : array_like, Quantity + Array to be divided into sub-arrays. + indices_or_sections : int or 1-D array + If `indices_or_sections` is an integer, N, the array will be divided into + N equal arrays along the first axis. If such a split is not possible, an + error is raised. If `indices_or_sections` is a 1-D array of sorted integers, + the entries indicate where along the first axis the array is split. - Returns: - Union[jax.Array, Quantity] a list of Quantity if a is a Quantity, otherwise a list of jax.Array + Returns + ------- + res : list of ndarrays, Quantity + A list of sub-arrays. """ return func_array_manipulation(jnp.vsplit, a, indices_or_sections) @@ -366,14 +502,19 @@ def tile( reps: Union[int, Tuple[int, ...]] ) -> Union[Array, Quantity]: """ - Construct an array by repeating A the number of times given by reps. + Construct a quantity or an array by repeating A the number of times given by reps. - Args: - A: array_like, Quantity - reps: array_like + Parameters + ---------- + A : array_like, Quantity + The input array. + reps : array_like + The number of repetitions of A along each axis. - Returns: - Union[jax.Array, Quantity] a Quantity if A is a Quantity, otherwise a jax.Array + Returns + ------- + res : ndarray, Quantity + The tiled output array. """ return func_array_manipulation(jnp.tile, A, reps) @@ -382,20 +523,30 @@ def tile( def repeat( a: Union[Array, Quantity], repeats: Union[int, Tuple[int, ...]], - axis: Optional[int] = None + axis: Optional[int] = None, + total_repeat_length: Optional[int] = None ) -> Union[Array, Quantity]: """ - Repeat elements of an array. + Repeat elements of a quantity or an array. - Args: - a: array_like, Quantity - repeats: array_like - axis: int, optional + Parameters + ---------- + a : array_like, Quantity + Input array. + repeats : int or tuple of ints + The number of repetitions for each element. `repeats` is broadcasted to fit the shape of the given axis. + axis : int, optional + The axis along which to repeat values. By default, use the flattened input array, and return a flat output array. + total_repeat_length : int, optional + The total length of the repeated array. If `total_repeat_length` is not None, the output array + will have the length of `total_repeat_length`. - Returns: - Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array + Returns + ------- + res : ndarray, Quantity + Output array which has the same shape as `a`, except along the given axis. """ - return func_array_manipulation(jnp.repeat, a, repeats, axis=axis) + return func_array_manipulation(jnp.repeat, a, repeats, axis=axis, total_repeat_length=total_repeat_length) @set_module_as('brainunit.math') @@ -404,23 +555,57 @@ def unique( return_index: bool = False, return_inverse: bool = False, return_counts: bool = False, - axis: Optional[int] = None -) -> Union[Array, Quantity]: - """ - Find the unique elements of an array. - - Args: - a: array_like, Quantity - return_index: bool, optional - return_inverse: bool, optional - return_counts: bool, optional - axis: int or None, optional - - Returns: - Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array - """ - return func_array_manipulation(jnp.unique, a, return_index=return_index, return_inverse=return_inverse, - return_counts=return_counts, axis=axis) + axis: Optional[int] = None, + *, + equal_nan: bool = False, + size: Optional[int] = None, + fill_value: Optional[jax.typing.ArrayLike, Quantity] = None +) -> list[Array, Quantity] | Array | Quantity: + """ + Find the unique elements of a quantity or an array. + + Parameters + ---------- + a : array_like, Quantity + Input array. + return_index : bool, optional + If True, also return the indices of `a` (along the specified axis, if provided) that result in the unique array. + return_inverse : bool, optional + If True, also return the indices of the unique array (for the specified axis, if provided) + that can be used to reconstruct `a`. + return_counts : bool, optional + If True, also return the number of times each unique item appears in `a`. + axis : int, optional + The axis along which to operate. If None, the array is flattened before use. Default is None. + equal_nan : bool, optional + Whether to compare NaN's as equal. If True, NaN's in `a` will be considered equal to each other in the unique array. + size : int, optional + The length of the output array. If `size` is not None, the output array will have the length of `size`. + fill_value : scalar, optional + The value to use for missing values. If `fill_value` is not None, the output array will have the length of `size`. + + Returns + ------- + res : ndarray, Quantity + The sorted unique values. + """ + if isinstance(a, Quantity): + if fill_value is not None: + fail_for_dimension_mismatch(fill_value, a) + fill_value = fill_value.value + result = jnp.unique(a.value, return_index=return_index, return_inverse=return_inverse, return_counts=return_counts, + axis=axis, equal_nan=equal_nan, size=size, fill_value=fill_value) + if isinstance(result, tuple): + output = [] + output.append(Quantity(result[0], dim=a.dim)) + for r in result[1:]: + output.append(r) + return output + else: + return Quantity(result, dim=a.dim) + else: + return jnp.unique(a, return_index=return_index, return_inverse=return_inverse, return_counts=return_counts, + axis=axis, equal_nan=equal_nan, size=size, fill_value=fill_value) @set_module_as('brainunit.math') @@ -430,15 +615,23 @@ def append( axis: Optional[int] = None ) -> Union[Array, Quantity]: """ - Append values to the end of an array. + Append values to the end of a quantity or an array. - Args: - arr: array_like, Quantity - values: array_like, Quantity - axis: int, optional + Parameters + ---------- + arr : array_like, Quantity + Values are appended to a copy of this array. + values : array_like, Quantity + These values are appended to a copy of `arr`. + It must be of the correct shape (the same shape as `arr`, excluding `axis`). + axis : int, optional + The axis along which `values` are appended. If `axis` is None, `values` is flattened before use. - Returns: - Union[jax.Array, Quantity] a Quantity if arr and values are Quantity, otherwise a jax.Array + Returns + ------- + res : ndarray, Quantity + A copy of `arr` with `values` appended to `axis`. Note that `append` does not occur in-place: + a new array is allocated and filled. """ return func_array_manipulation(jnp.append, arr, values, axis=axis) @@ -449,14 +642,19 @@ def flip( axis: Optional[Union[int, Tuple[int, ...]]] = None ) -> Union[Array, Quantity]: """ - Reverse the order of elements in an array along the given axis. + Reverse the order of elements in a quantity or an array along the given axis. - Args: - m: array_like, Quantity - axis: int or tuple of ints, optional + Parameters + ---------- + m : array_like, Quantity + Input array. + axis : int or tuple of ints, optional + Axis or axes along which to flip over. The default, axis=None, will flip over all of the axes of the input array. - Returns: - Union[jax.Array, Quantity] a Quantity if m is a Quantity, otherwise a jax.Array + Returns + ------- + res : ndarray, Quantity + A view of `m` with the entries of axis reversed. Since a view is returned, this operation is done in constant time. """ return func_array_manipulation(jnp.flip, m, axis=axis) @@ -466,13 +664,17 @@ def fliplr( m: Union[Array, Quantity] ) -> Union[Array, Quantity]: """ - Flip array in the left/right direction. + Flip quantity or array in the left/right direction. - Args: - m: array_like, Quantity + Parameters + ---------- + m : array_like, Quantity + Input array. - Returns: - Union[jax.Array, Quantity] a Quantity if m is a Quantity, otherwise a jax.Array + Returns + ------- + res : ndarray, Quantity + A view of `m` with the columns reversed. Since a view is returned, this operation is done in constant time. """ return func_array_manipulation(jnp.fliplr, m) @@ -482,13 +684,17 @@ def flipud( m: Union[Array, Quantity] ) -> Union[Array, Quantity]: """ - Flip array in the up/down direction. + Flip quantity or array in the up/down direction. - Args: - m: array_like, Quantity + Parameters + ---------- + m : array_like, Quantity + Input array. - Returns: - Union[jax.Array, Quantity] a Quantity if m is a Quantity, otherwise a jax.Array + Returns + ------- + res : ndarray, Quantity + A view of `m` with the rows reversed. """ return func_array_manipulation(jnp.flipud, m) @@ -500,15 +706,24 @@ def roll( axis: Optional[Union[int, Tuple[int, ...]]] = None ) -> Union[Array, Quantity]: """ - Roll array elements along a given axis. - - Args: - a: array_like, Quantity - shift: int or tuple of ints - axis: int or tuple of ints, optional - - Returns: - Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array + Roll quantity or array elements along a given axis. + + Parameters + ---------- + a : array_like, Quantity + Input array. + shift : int or tuple of ints + The number of places by which elements are shifted. If a tuple, then `axis` must be a tuple of the same size, + and each of the given axes is shifted by the corresponding number. If an int while `axis` is a tuple of ints, + then the same value is used for all given axes. + axis : int or tuple of ints, optional + Axis or axes along which elements are shifted. By default, the array is flattened before shifting, after which + the original shape is restored. + + Returns + ------- + res : ndarray, Quantity + Output array, with the same shape as `a`. """ return func_array_manipulation(jnp.roll, a, shift, axis=axis) @@ -518,13 +733,17 @@ def atleast_1d( *arys: Union[Array, Quantity] ) -> Union[Array, Quantity]: """ - View inputs as arrays with at least one dimension. + View inputs as quantities or arrays with at least one dimension. - Args: - *args: array_like, Quantity + Parameters + ---------- + *args : array_like, Quantity + One or more input arrays or quantities. - Returns: - Union[jax.Array, Quantity] a Quantity if any input is a Quantity, otherwise a jax.Array + Returns + ------- + res : ndarray, Quantity + An array or a quantity, or a tuple of arrays or quantities, each with `a.ndim >= 1`. """ return func_array_manipulation(jnp.atleast_1d, *arys) @@ -534,13 +753,17 @@ def atleast_2d( *arys: Union[Array, Quantity] ) -> Union[Array, Quantity]: """ - View inputs as arrays with at least two dimensions. + View inputs as quantities or arrays with at least two dimensions. - Args: - *args: array_like, Quantity + Parameters + ---------- + *args : array_like, Quantity + One or more input arrays or quantities. - Returns: - Union[jax.Array, Quantity] a Quantity if any input is a Quantity, otherwise a jax.Array + Returns + ------- + res : ndarray, Quantity + An array or a quantity, or a tuple of arrays or quantities, each with `a.ndim >= 2`. """ return func_array_manipulation(jnp.atleast_2d, *arys) @@ -550,13 +773,17 @@ def atleast_3d( *arys: Union[Array, Quantity] ) -> Union[Array, Quantity]: """ - View inputs as arrays with at least three dimensions. + View inputs as quantities or arrays with at least three dimensions. - Args: - *args: array_like, Quantity + Parameters + ---------- + *args : array_like, Quantity + One or more input arrays or quantities. - Returns: - Union[jax.Array, Quantity] a Quantity if any input is a Quantity, otherwise a jax.Array + Returns + ------- + res : ndarray, Quantity + An array or a quantity, or a tuple of arrays or quantities, each with `a.ndim >= 3`. """ return func_array_manipulation(jnp.atleast_3d, *arys) @@ -567,14 +794,19 @@ def expand_dims( axis: int ) -> Union[Array, Quantity]: """ - Expand the shape of an array. + Expand the shape of a quantity or an array. - Args: - a: array_like, Quantity - axis: int or tuple of ints + Parameters + ---------- + a : array_like, Quantity + Input array. + axis : int + Position in the expanded axes where the new axis is placed. - Returns: - Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array + Returns + ------- + res : ndarray, Quantity + View of `a` with the number of dimensions increased by one. """ return func_array_manipulation(jnp.expand_dims, a, axis) @@ -585,14 +817,20 @@ def squeeze( axis: Optional[Union[int, Tuple[int, ...]]] = None ) -> Union[Array, Quantity]: """ - Remove single-dimensional entries from the shape of an array. + Remove single-dimensional entries from the shape of a quantity or an array. - Args: - a: array_like, Quantity - axis: None or int or tuple of ints, optional + Parameters + ---------- + a : array_like, Quantity + Input data. + axis : None or int or tuple of ints, optional + Selects a subset of the single-dimensional entries in the shape. If an axis is selected with shape entry greater + than one, an error is raised. - Returns: - Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array + Returns + ------- + res : ndarray, Quantity + An array with the same data as `a`, but with a lower dimension. """ return func_array_manipulation(jnp.squeeze, a, axis) @@ -601,24 +839,36 @@ def squeeze( def sort( a: Union[Array, Quantity], axis: Optional[int] = -1, + *, kind: None = None, order: None = None, stable: bool = True, descending: bool = False, ) -> Union[Array, Quantity]: """ - Return a sorted copy of an array. - - Args: - a: array_like, Quantity - axis: int or None, optional - kind: {'quicksort', 'mergesort', 'heapsort', 'stable'}, optional - order: str or list of str, optional - stable: bool, optional - descending: bool, optional - - Returns: - Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array + Return a sorted copy of a quantity or an array. + + Parameters + ---------- + a : array_like, Quantity + Array or quantity to be sorted. + axis : int or None, optional + Axis along which to sort. If None, the array is flattened before sorting. The default is -1, which sorts along + the last axis. + kind : {'quicksort', 'mergesort', 'heapsort', 'stable'}, optional + Sorting algorithm. The default is 'quicksort'. + order : str or list of str, optional + When `a` is a quantity, it can be a string or a sequence of strings, which is interpreted as an order the quantity + should be sorted. The default is None. + stable : bool, optional + Whether to use a stable sorting algorithm. The default is True. + descending : bool, optional + Whether to sort in descending order. The default is False. + + Returns + ------- + res : ndarray, Quantity + Sorted copy of the input array. """ return func_array_manipulation(jnp.sort, a, axis=axis, kind=kind, order=order, stable=stable, descending=descending) @@ -627,26 +877,43 @@ def sort( def argsort( a: Union[Array, Quantity], axis: Optional[int] = -1, + *, kind: None = None, order: None = None, stable: bool = True, descending: bool = False, ) -> Array: """ - Returns the indices that would sort an array. - - Args: - a: array_like, Quantity - axis: int or None, optional - kind: {'quicksort', 'mergesort', 'heapsort', 'stable'}, optional - order: str or list of str, optional - stable: bool, optional - descending: bool, optional - - Returns: - Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array - """ - return func_array_manipulation(jnp.argsort, a, axis=axis, kind=kind, order=order, stable=stable, + Returns the indices that would sort an array or a quantity. + + Parameters + ---------- + a : array_like, Quantity + Array or quantity to be sorted. + axis : int or None, optional + Axis along which to sort. If None, the array is flattened before sorting. The default is -1, which sorts along + the last axis. + kind : {'quicksort', 'mergesort', 'heapsort', 'stable'}, optional + Sorting algorithm. The default is 'None'. + order : str or list of str, optional + When `a` is a quantity, it can be a string or a sequence of strings, which is interpreted as an order the quantity + should be sorted. The default is None. + stable : bool, optional + Whether to use a stable sorting algorithm. The default is True. + descending : bool, optional + Whether to sort in descending order. The default is False. + + Returns + ------- + res : ndarray + Array of indices that sort the array. + """ + return func_array_manipulation(jnp.argsort, + a, + axis=axis, + kind=kind, + order=order, + stable=stable, descending=descending) @@ -654,58 +921,114 @@ def argsort( def max( a: Union[Array, Quantity], axis: Optional[int] = None, - keepdims: bool = False + keepdims: bool = False, + initial: Optional[Union[int, float, Quantity]] = None, + where: Optional[Array] = None, ) -> Union[Array, Quantity]: """ - Return the maximum of an array or maximum along an axis. - - Args: - a: array_like, Quantity - axis: int or tuple of ints, optional - keepdims: bool, optional - - Returns: - Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array - """ - return func_array_manipulation(jnp.max, a, axis=axis, keepdims=keepdims) + Return the maximum of a quantity or an array or maximum along an axis. + + Parameters + ---------- + a : array_like, Quantity + Array or quantity containing numbers whose maximum is desired. + axis : int or None, optional + Axis or axes along which to operate. By default, flattened input is used. + keepdims : bool, optional + If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this + option, the result will broadcast correctly against the input array. + initial : scalar, optional + The minimum value of an output element. Must be present to allow computation on empty slice. + See `numpy.ufunc.reduce`. + where : array_like, optional + Values of True indicate to calculate the ufunc at that position, values of False indicate to leave the value in the + output alone. + + Returns + ------- + res : ndarray, Quantity + Maximum of `a`. If `axis` is None, the result is a scalar value. If `axis` is given, the result is an array of + dimension `a.ndim - 1`. + """ + if isinstance(a, Quantity): + if initial is not None: + fail_for_dimension_mismatch(initial, a) + initial = initial.value + return Quantity(jnp.max(a.value, axis=axis, keepdims=keepdims, initial=initial, where=where), dim=a.dim) + else: + return jnp.max(a, axis=axis, keepdims=keepdims, initial=initial, where=where) @set_module_as('brainunit.math') def min( a: Union[Array, Quantity], axis: Optional[int] = None, - keepdims: bool = False + keepdims: bool = False, + initial: Optional[Union[int, float, Quantity]] = None, + where: Optional[Array] = None, ) -> Union[Array, Quantity]: """ - Return the minimum of an array or minimum along an axis. - - Args: - a: array_like, Quantity - axis: int or tuple of ints, optional - keepdims: bool, optional - - Returns: - Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array - """ - return func_array_manipulation(jnp.min, a, axis=axis, keepdims=keepdims) + Return the minimum of a quantity or an array or minimum along an axis. + + Parameters + ---------- + a : array_like, Quantity + Array or quantity containing numbers whose minimum is desired. + axis : int or None, optional + Axis or axes along which to operate. By default, flattened input is used. + keepdims : bool, optional + If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this + option, the result will broadcast correctly against the input array. + initial : scalar, optional + The maximum value of an output element. Must be present to allow computation on empty slice. + See `numpy.ufunc.reduce`. + where : array_like, optional + Values of True indicate to calculate the ufunc at that position, values of False indicate to leave the value in the + output alone. + + Returns + ------- + res : ndarray, Quantity + Minimum of `a`. If `axis` is None, the result is a scalar value. If `axis` is given, the result is an array of + dimension `a.ndim - 1`. + """ + if isinstance(a, Quantity): + if initial is not None: + fail_for_dimension_mismatch(initial, a) + initial = initial.value + return Quantity(jnp.min(a.value, axis=axis, keepdims=keepdims, initial=initial, where=where), dim=a.dim) + else: + return jnp.min(a, axis=axis, keepdims=keepdims, initial=initial, where=where) @set_module_as('brainunit.math') def choose( a: Union[Array, Quantity], - choices: Sequence[Union[Array, Quantity]] + choices: Sequence[Union[Array, Quantity]], + mode: str = 'raise', ) -> Union[Array, Quantity]: """ - Use an index array to construct a new array from a set of choices. + Construct a quantity or an array from an index array and a set of arrays to choose from. - Args: - a: array_like, Quantity - choices: array_like, Quantity + Parameters + ---------- + a : array_like, Quantity + This array must be an integer array of the same shape as `choices`. The elements of `a` are used to select elements + from `choices`. + choices : sequence of array_like, Quantity + Choice arrays. `a` and all `choices` must be broadcastable to the same shape. + mode : {'raise', 'wrap', 'clip'}, optional + Specifies how indices outside [0, n-1] will be treated: + - 'raise' : raise an error (default) + - 'wrap' : wrap around + - 'clip' : clip to the range [0, n-1] - Returns: - Union[jax.Array, Quantity] a Quantity if a and choices are Quantity, otherwise a jax.Array + Returns + ------- + res : ndarray, Quantity + The constructed array. The shape is identical to the shape of `a`, and the data type is the data type of `choices`. """ - return func_array_manipulation(jnp.choose, a, choices) + return func_array_manipulation(jnp.choose, a, choices, mode=mode) @set_module_as('brainunit.math') @@ -713,13 +1036,18 @@ def block( arrays: Sequence[Union[Array, Quantity]] ) -> Union[Array, Quantity]: """ - Assemble an nd-array from nested lists of blocks. + Assemble a quantity or an array from nested lists of blocks. - Args: - arrays: sequence of array_like, Quantity + Parameters + ---------- + arrays : sequence of array_like, Quantity + Each element in `arrays` can itself be a nested sequence of arrays, in which case the blocks in the corresponding + cells are recursively stacked as the elements of the resulting array. - Returns: - Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array + Returns + ------- + res : ndarray, Quantity + The array constructed from the given blocks. """ return func_array_manipulation(jnp.block, arrays) @@ -728,20 +1056,43 @@ def block( def compress( condition: Union[Array, Quantity], a: Union[Array, Quantity], - axis: Optional[int] = None + axis: Optional[int] = None, + *, + size: Optional[int] = None, + fill_value: Optional[jax.typing.ArrayLike] = None, ) -> Union[Array, Quantity]: """ - Return selected slices of an array along given axis. - - Args: - condition: array_like, Quantity - a: array_like, Quantity - axis: int, optional - - Returns: - Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array - """ - return func_array_manipulation(jnp.compress, condition, a, axis=axis) + Return selected slices of a quantity or an array along given axis. + + Parameters + ---------- + condition : array_like, Quantity + An array of boolean values that selects which slices to return. If the shape of condition is not the same as `a`, + it must be broadcastable to `a`. + a : array_like, Quantity + Input array. + axis : int or None, optional + The axis along which to take slices. If axis is None, `condition` must be a 1-D array with the same length as `a`. + If axis is an integer, `condition` must be broadcastable to the same shape as `a` along all axes except `axis`. + size : int, optional + The length of the returned axis. By default, the length of the input array along the axis is used. + fill_value : scalar, optional + The value to use for elements in the output array that are not selected. If None, the output array has the same + type as `a` and is filled with zeros. + + Returns + ------- + res : ndarray, Quantity + A new array that has the same number of dimensions as `a`, and the same shape as `a` with axis `axis` removed. + """ + if isinstance(a, Quantity): + if fill_value is not None: + fail_for_dimension_mismatch(fill_value, a.dim) + fill_value = fill_value.value + else: + fill_value = 0 + return Quantity(jnp.compress(condition, a.value, axis, size=size, fill_value=fill_value), dim=a.dim) + return jnp.compress(condition, a, axis, size=size, fill_value=0) @set_module_as('brainunit.math') @@ -750,14 +1101,19 @@ def diagflat( k: int = 0 ) -> Union[Array, Quantity]: """ - Create a two-dimensional array with the flattened input as a diagonal. + Create a two-dimensional a quantity or array with the flattened input as a diagonal. - Args: - v: array_like, Quantity - k: int, optional + Parameters + ---------- + v : array_like, Quantity + Input data, which is flattened and set as the `k`-th diagonal of the output. + k : int, optional + Diagonal in question. The default is 0. - Returns: - Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array + Returns + ------- + res : ndarray, Quantity + The 2-D output array. """ return func_array_manipulation(jnp.diagflat, v, k) @@ -768,88 +1124,139 @@ def diagflat( def argmax( a: Union[Array, Quantity], axis: Optional[int] = None, - out: Optional[Array] = None + keepdims: Optional[bool] = None ) -> Array: """ Returns indices of the max value along an axis. - Args: - a: array_like, Quantity - axis: int, optional - out: array, optional + Parameters + ---------- + a : array_like, Quantity + Input data. + axis : int, optional + By default, the index is into the flattened array, otherwise along the specified axis. + keepdims : bool, optional + If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this + option, the result will broadcast correctly against the input array. - Returns: - jax.Array: an array (does not return a Quantity) + Returns + ------- + res : ndarray + Array of indices into the array. It has the same shape as `a.shape` with the dimension along `axis` removed. """ - return func_array_manipulation(jnp.argmax, a, axis=axis, out=out, return_quantity=False) + return func_array_manipulation(jnp.argmax, a, axis=axis, keepdim=keepdims, return_quantity=False) @set_module_as('brainunit.math') def argmin( a: Union[Array, Quantity], axis: Optional[int] = None, - out: Optional[Array] = None + keepdims: Optional[bool] = None ) -> Array: """ Returns indices of the min value along an axis. - Args: - a: array_like, Quantity - axis: int, optional - out: array, optional + Parameters + ---------- + a : array_like, Quantity + Input data. + axis : int, optional + By default, the index is into the flattened array, otherwise along the specified axis. + keepdims : bool, optional + If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this + option, the result will broadcast correctly against the input array. - Returns: - jax.Array: an array (does not return a Quantity) + Returns + ------- + res : ndarray + Array of indices into the array. It has the same shape as `a.shape` with the dimension along `axis` removed. """ - return func_array_manipulation(jnp.argmin, a, axis=axis, out=out, return_quantity=False) + return func_array_manipulation(jnp.argmin, a, axis=axis, keepdims=keepdims, return_quantity=False) @set_module_as('brainunit.math') def argwhere( - a: Union[Array, Quantity] + a: Union[Array, Quantity], + *, + size: Optional[int] = None, + fill_value: Optional[jax.typing.ArrayLike] = None, ) -> Array: """ - Find indices of non-zero elements. + Find the indices of array elements that are non-zero, grouped by element. - Args: - a: array_like, Quantity + Parameters + ---------- + a : array_like, Quantity + Input data. + size : int, optional + The length of the returned axis. By default, the length of the input array along the axis is used. + fill_value : scalar, optional + The value to use for elements in the output array that are not selected. If None, the output array has the same + type as `a` and is filled with zeros. - Returns: - jax.Array: an array (does not return a Quantity) + Returns + ------- + res : ndarray + The indices of elements that are non-zero. The indices are grouped by element. """ - return func_array_manipulation(jnp.argwhere, a, return_quantity=False) + return func_array_manipulation(jnp.argwhere, a, size=size, fill_value=fill_value, return_quantity=False) @set_module_as('brainunit.math') def nonzero( - a: Union[Array, Quantity] + a: Union[Array, Quantity], + *, + size: Optional[int] = None, + fill_value: Optional[jax.typing.ArrayLike] = None, ) -> Tuple[Array, ...]: """ Return the indices of the elements that are non-zero. - Args: - a: array_like, Quantity + Parameters + ---------- + a : array_like, Quantity + Input data. + size : int, optional + The length of the returned axis. By default, the length of the input array along the axis is used. + fill_value : scalar, optional + The value to use for elements in the output array that are not selected. If None, the output array has the same + type as `a` and is filled with zeros. - Returns: - jax.Array: an array (does not return a Quantity) + Returns + ------- + res : tuple of ndarrays + Indices of elements that are non-zero along the specified axis. Each array in the tuple has the same shape as the + input array. """ - return func_array_manipulation(jnp.nonzero, a, return_quantity=False) + return func_array_manipulation(jnp.nonzero, a, size=size, fill_value=fill_value, return_quantity=False) @set_module_as('brainunit.math') def flatnonzero( - a: Union[Array, Quantity] + a: Union[Array, Quantity], + *, + size: Optional[int] = None, + fill_value: Optional[jax.typing.ArrayLike] = None, ) -> Array: """ - Return indices that are non-zero in the flattened version of a. + Return indices that are non-zero in the flattened version of the input quantity or array. - Args: - a: array_like, Quantity + Parameters + ---------- + a : array_like, Quantity + Input data. + size : int, optional + The length of the returned axis. By default, the length of the input array along the axis is used. + fill_value : scalar, optional + The value to use for elements in the output array that are not selected. If None, the output array has the same + type as `a` and is filled with zeros. - Returns: - jax.Array: an array (does not return a Quantity) + Returns + ------- + res : ndarray + Output array, containing the indices of the elements of `a.ravel()` that are non-zero. """ - return func_array_manipulation(jnp.flatnonzero, a, return_quantity=False) + return func_array_manipulation(jnp.flatnonzero, a, size=size, fill_value=fill_value, return_quantity=False) @set_module_as('brainunit.math') @@ -857,55 +1264,107 @@ def searchsorted( a: Union[Array, Quantity], v: Union[Array, Quantity], side: str = 'left', - sorter: Optional[Array] = None + sorter: Optional[Array] = None, + *, + method: Optional[str] = 'scan' ) -> Array: """ Find indices where elements should be inserted to maintain order. - Args: - a: array_like, Quantity - v: array_like, Quantity - side: {'left', 'right'}, optional - - Returns: - jax.Array: an array (does not return a Quantity) - """ - return func_array_manipulation(jnp.searchsorted, a, v, side=side, sorter=sorter, return_quantity=False) + Find the indices into a sorted array `a` such that, if the corresponding elements in `v` were inserted before the + indices, the order of `a` would be preserved. + + Parameters + ---------- + a : array_like, Quantity + Input array. It must be sorted in ascending order. + v : array_like, Quantity + Values to insert into `a`. + side : {'left', 'right'}, optional + If 'left', the index of the first suitable location found is given. If 'right', return the last such index. If + there is no suitable index, return either 0 or N (where N is the length of `a`). + sorter : 1-D array_like, optional + Optional array of integer indices that sort array `a` into ascending order. They are typically the result of + `argsort`. + method : str + One of 'scan' (default), 'scan_unrolled', 'sort' or 'compare_all'. Controls the method used by the + implementation: 'scan' tends to be more performant on CPU (particularly when ``a`` is + very large), 'scan_unrolled' is more performant on GPU at the expense of additional compile time, + 'sort' is often more performant on accelerator backends like GPU and TPU + (particularly when ``v`` is very large), and 'compare_all' can be most performant + when ``a`` is very small. The default is 'scan'. + + Returns + ------- + out : ndarray + Array of insertion points with the same shape as `v`. + """ + return func_array_manipulation(jnp.searchsorted, a, v, side=side, sorter=sorter, method=method, return_quantity=False) @set_module_as('brainunit.math') def extract( - condition: Union[Array, Quantity], - arr: Union[Array, Quantity] -) -> Array: + condition: Array, + arr: Union[Array, Quantity], + *, + size: Optional[int] = None, + fill_value: Optional[jax.typing.ArrayLike | Quantity] = None, +) -> Array | Quantity: """ Return the elements of an array that satisfy some condition. - Args: - condition: array_like, Quantity - arr: array_like, Quantity - - Returns: - jax.Array: an array (does not return a Quantity) - """ - return func_array_manipulation(jnp.extract, condition, arr, return_quantity=False) + Parameters + ---------- + condition : array_like, Quantity + An array of boolean values that selects which elements to extract. + arr : array_like, Quantity + The array from which to extract elements. + size: int + optional static size for output. Must be specified in order for ``extract`` + to be compatible with JAX transformations like :func:`~jax.jit` or :func:`~jax.vmap`. + fill_value: array_like + if ``size`` is specified, fill padded entries with this value (default: 0). + + Returns + ------- + res : ndarray + The extracted elements. The shape of `res` is the same as that of `condition`. + """ + if isinstance(arr, Quantity): + if fill_value is not None: + fail_for_dimension_mismatch(fill_value, arr) + fill_value = fill_value.value + else: + fill_value = 0 + return Quantity(jnp.extract(condition, arr.value, size=size, fill_value=fill_value), dim=arr.dim) + return jnp.extract(condition, arr, size=size, fill_value=0) @set_module_as('brainunit.math') def count_nonzero( - a: Union[Array, Quantity], axis: Optional[int] = None + a: Union[Array, Quantity], + axis: Optional[int] = None, + keepdims: Optional[bool] = None ) -> Array: """ - Counts the number of non-zero values in the array a. + Count the number of non-zero values in the quantity or array `a`. - Args: - a: array_like, Quantity - axis: int or tuple of ints, optional + Parameters + ---------- + a : array_like, Quantity + The array for which to count non-zeros. + axis : int, optional + The axis along which to count the non-zeros. If `None`, count non-zeros over the entire array. + keepdims : bool, optional + If this is set to `True`, the axes which are counted are left in the result as dimensions with size one. With this + option, the result will broadcast correctly against the original array. - Returns: - jax.Array: an array (does not return a Quantity) + Returns + ------- + res : ndarray + Number of non-zero values in the quantity or array along a given axis. """ - return func_array_manipulation(jnp.count_nonzero, a, axis=axis, return_quantity=False) + return func_array_manipulation(jnp.count_nonzero, a, axis=axis, keepdims=keepdims, return_quantity=False) amax = max @@ -929,14 +1388,24 @@ def diagonal( """ Return specified diagonals. - Args: - a: array_like, Quantity - offset: int, optional - axis1: int, optional - axis2: int, optional - - Returns: - Union[jax.Array, Quantity]: a Quantity if a is a Quantity, otherwise a jax.numpy.Array + Parameters + ---------- + a : array_like, Quantity + Array from which the diagonals are taken. + offset : int, optional + Offset of the diagonal from the main diagonal. Can be positive or negative. Defaults to main diagonal (0). + axis1 : int, optional + Axis to be used as the first axis of the 2-D sub-arrays from which the diagonals should be taken. Defaults to first + axis (0). + axis2 : int, optional + Axis to be used as the second axis of the 2-D sub-arrays from which the diagonals should be taken. Defaults to + second axis (1). + + Returns + ------- + res : ndarray + The extracted diagonals. The shape of the output is determined by considering the shape of the input array with + the specified axis removed. """ return function_to_method(jnp.diagonal, a, offset, axis1, axis2) @@ -947,13 +1416,23 @@ def ravel( order: str = 'C' ) -> Union[jax.Array, Quantity]: """ - Return a contiguous flattened array. - - Args: - a: array_like, Quantity - order: {'C', 'F', 'A', 'K'}, optional - - Returns: - Union[jax.Array, Quantity]: a Quantity if a is a Quantity, otherwise a jax.numpy.Array + Return a contiguous flattened quantity or array. + + Parameters + ---------- + a : array_like, Quantity + Input array. The elements in `a` are read in the order specified by `order`, and packed as a 1-D array. + order : {'C', 'F', 'A', 'K'}, optional + The elements of `a` are read using this index order. 'C' means to index the elements in row-major, C-style order, + with the last axis index changing fastest, back to the first axis index changing slowest. 'F' means to index the + elements in column-major, Fortran-style order, with the first index changing fastest, and the last index changing + slowest. 'A' means to read the elements in Fortran-like index order if `a` is Fortran contiguous in memory, C-like + order otherwise. 'K' means to read the elements in the order they occur in memory, except for reversing the data + when strides are negative. By default, 'C' index order is used. + + Returns + ------- + res : ndarray, Quantity + The flattened quantity or array. The shape of the output is the same as `a`, but the array is 1-D. """ return function_to_method(jnp.ravel, a, order) diff --git a/brainunit/math/_compat_numpy_funcs_accept_unitless.py b/brainunit/math/_compat_numpy_funcs_accept_unitless.py index 4bb03e3..da293e3 100644 --- a/brainunit/math/_compat_numpy_funcs_accept_unitless.py +++ b/brainunit/math/_compat_numpy_funcs_accept_unitless.py @@ -13,14 +13,14 @@ # limitations under the License. # ============================================================================== -from typing import (Union) +from typing import (Union, Optional, Tuple) import jax import jax.numpy as jnp from jax import Array -from brainunit._misc import set_module_as from .._base import (Quantity, fail_for_dimension_mismatch, ) +from .._misc import set_module_as __all__ = [ # math funcs only accept unitless (unary) @@ -51,29 +51,37 @@ def funcs_only_accept_unitless_unary(func, x, *args, **kwargs): @set_module_as('brainunit.math') -def exp(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Array, Quantity]: +def exp(x: Union[Quantity, jax.typing.ArrayLike]) -> Array: """ - Calculate the exponential of all elements in the input array. + Calculate the exponential of all elements in the input quantity or array. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.exp, x) @set_module_as('brainunit.math') -def exp2(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Array, Quantity]: +def exp2(x: Union[Quantity, jax.typing.ArrayLike]) -> Array: """ - Calculate 2 raised to the power of the input elements. + Calculate 2**p for all p in the input quantity or array. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.exp2, x) @@ -83,11 +91,15 @@ def expm1(x: Union[Array, Quantity]) -> Array: """ Calculate the exponential of the input elements minus 1. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.expm1, x) @@ -97,11 +109,15 @@ def log(x: Union[Array, Quantity]) -> Array: """ Natural logarithm, element-wise. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.log, x) @@ -111,11 +127,15 @@ def log10(x: Union[Array, Quantity]) -> Array: """ Base-10 logarithm of the input elements. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.log10, x) @@ -125,11 +145,15 @@ def log1p(x: Union[Array, Quantity]) -> Array: """ Natural logarithm of 1 + the input elements. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.log1p, x) @@ -139,11 +163,15 @@ def log2(x: Union[Array, Quantity]) -> Array: """ Base-2 logarithm of the input elements. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.log2, x) @@ -153,11 +181,15 @@ def arccos(x: Union[Array, Quantity]) -> Array: """ Compute the arccosine of the input elements. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.arccos, x) @@ -167,11 +199,15 @@ def arccosh(x: Union[Array, Quantity]) -> Array: """ Compute the hyperbolic arccosine of the input elements. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.arccosh, x) @@ -181,11 +217,15 @@ def arcsin(x: Union[Array, Quantity]) -> Array: """ Compute the arcsine of the input elements. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.arcsin, x) @@ -195,11 +235,15 @@ def arcsinh(x: Union[Array, Quantity]) -> Array: """ Compute the hyperbolic arcsine of the input elements. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.arcsinh, x) @@ -209,11 +253,15 @@ def arctan(x: Union[Array, Quantity]) -> Array: """ Compute the arctangent of the input elements. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.arctan, x) @@ -223,11 +271,15 @@ def arctanh(x: Union[Array, Quantity]) -> Array: """ Compute the hyperbolic arctangent of the input elements. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.arctanh, x) @@ -237,11 +289,15 @@ def cos(x: Union[Array, Quantity]) -> Array: """ Compute the cosine of the input elements. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.cos, x) @@ -251,11 +307,15 @@ def cosh(x: Union[Array, Quantity]) -> Array: """ Compute the hyperbolic cosine of the input elements. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.cosh, x) @@ -265,11 +325,15 @@ def sin(x: Union[Array, Quantity]) -> Array: """ Compute the sine of the input elements. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.sin, x) @@ -279,11 +343,15 @@ def sinc(x: Union[Array, Quantity]) -> Array: """ Compute the sinc function of the input elements. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.sinc, x) @@ -293,11 +361,15 @@ def sinh(x: Union[Array, Quantity]) -> Array: """ Compute the hyperbolic sine of the input elements. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.sinh, x) @@ -307,11 +379,15 @@ def tan(x: Union[Array, Quantity]) -> Array: """ Compute the tangent of the input elements. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.tan, x) @@ -321,11 +397,15 @@ def tanh(x: Union[Array, Quantity]) -> Array: """ Compute the hyperbolic tangent of the input elements. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.tanh, x) @@ -335,11 +415,15 @@ def deg2rad(x: Union[Array, Quantity]) -> Array: """ Convert angles from degrees to radians. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.deg2rad, x) @@ -349,11 +433,15 @@ def rad2deg(x: Union[Array, Quantity]) -> Array: """ Convert angles from radians to degrees. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.rad2deg, x) @@ -363,11 +451,15 @@ def degrees(x: Union[Array, Quantity]) -> Array: """ Convert angles from radians to degrees. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.degrees, x) @@ -377,11 +469,15 @@ def radians(x: Union[Array, Quantity]) -> Array: """ Convert angles from degrees to radians. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.radians, x) @@ -391,11 +487,15 @@ def angle(x: Union[Array, Quantity]) -> Array: """ Return the angle of the complex argument. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.angle, x) @@ -405,10 +505,8 @@ def angle(x: Union[Array, Quantity]) -> Array: def funcs_only_accept_unitless_binary(func, x, y, *args, **kwargs): - if isinstance(x, Quantity): - x_value = x.value - if isinstance(y, Quantity): - y_value = y.value + x_value = x.value if isinstance(x, Quantity) else x + y_value = y.value if isinstance(y, Quantity) else y if isinstance(x, Quantity) or isinstance(y, Quantity): fail_for_dimension_mismatch( x, @@ -430,12 +528,17 @@ def hypot(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array: """ Given the “legs” of a right triangle, return its hypotenuse. - Args: - x: array_like, Quantity - y: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. + y : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_binary(jnp.hypot, x, y) @@ -445,12 +548,17 @@ def arctan2(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array: """ Element-wise arc tangent of `x1/x2` choosing the quadrant correctly. - Args: - x1: array_like, Quantity - x2: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. + y : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_binary(jnp.arctan2, x, y) @@ -460,12 +568,17 @@ def logaddexp(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array: """ Logarithm of the sum of exponentiations of the inputs. - Args: - x1: array_like, Quantity - x2: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. + y : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_binary(jnp.logaddexp, x, y) @@ -475,67 +588,260 @@ def logaddexp2(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array: """ Logarithm of the sum of exponentiations of the inputs in base-2. - Args: - x1: array_like, Quantity - x2: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. + y : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_binary(jnp.logaddexp2, x, y) @set_module_as('brainunit.math') -def percentile(a: Union[Array, Quantity], q: Union[Array, Quantity], *args, **kwargs) -> Array: - """ - Compute the nth percentile of the input array along the specified axis. - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array - """ - return funcs_only_accept_unitless_binary(jnp.percentile, a, q, *args, **kwargs) - - -@set_module_as('brainunit.math') -def nanpercentile(a: Union[Array, Quantity], q: Union[Array, Quantity], *args, **kwargs) -> Array: - """ - Compute the nth percentile of the input array along the specified axis, ignoring NaNs. - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array - """ - return funcs_only_accept_unitless_binary(jnp.nanpercentile, a, q, *args, **kwargs) - - -@set_module_as('brainunit.math') -def quantile(a: Union[Array, Quantity], q: Union[Array, Quantity], *args, **kwargs) -> Array: - """ - Compute the qth quantile of the input array along the specified axis. - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array - """ - return funcs_only_accept_unitless_binary(jnp.quantile, a, q, *args, **kwargs) - - -@set_module_as('brainunit.math') -def nanquantile(a: Union[Array, Quantity], q: Union[Array, Quantity], *args, **kwargs) -> Array: - """ - Compute the qth quantile of the input array along the specified axis, ignoring NaNs. - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array - """ - return funcs_only_accept_unitless_binary(jnp.nanquantile, a, q, *args, **kwargs) +def percentile( + a: Union[Array, Quantity], + q: Union[Array, Quantity], + axis: Optional[Union[int, Tuple[int]]] = None, + out: Optional[Union[Quantity, jax.typing.ArrayLike]] = None, + overwrite_input: Optional[bool] = None, + method: str = 'linear', + keepdims: Optional[bool] = False, +) -> Array: + """ + Compute the q-th percentile of the data along the specified axis. + + Returns the q-th percentile(s) of the array elements. + + Parameters + ---------- + a : array_like, Quantity + Input array or Quantity. + q : array_like, Quantity + Percentile or sequence of percentiles to compute, which must be between 0 and 100 inclusive. + out : array_like, Quantity, optional + Alternative output array in which to place the result. + It must have the same shape and buffer length as the expected output but the type will be cast if necessary. + overwrite_input : bool, optional + If True, then allow the input array a to be modified by intermediate calculations, to save memory. + method : str, optional + This parameter specifies the method to use for estimating the + percentile. There are many different methods, some unique to NumPy. + See the notes for explanation. The options sorted by their R type + as summarized in the H&F paper [1]_ are: + + 1. 'inverted_cdf' + 2. 'averaged_inverted_cdf' + 3. 'closest_observation' + 4. 'interpolated_inverted_cdf' + 5. 'hazen' + 6. 'weibull' + 7. 'linear' (default) + 8. 'median_unbiased' + 9. 'normal_unbiased' + + The first three methods are discontinuous. NumPy further defines the + following discontinuous variations of the default 'linear' (7.) option: + + * 'lower' + * 'higher', + * 'midpoint' + * 'nearest' + keepdims : bool, optional + If this is set to True, the axes which are reduced are left in the result as dimensions with size one. + + Returns + ------- + out : jax.Array + Output array. + """ + return funcs_only_accept_unitless_binary(jnp.percentile, a, q, axis=axis, out=out, overwrite_input=overwrite_input, + method=method, keepdims=keepdims) + + +@set_module_as('brainunit.math') +def nanpercentile( + a: Union[Array, Quantity], + q: Union[Array, Quantity], + axis: Optional[Union[int, Tuple[int]]] = None, + out: Optional[Union[Quantity, jax.typing.ArrayLike]] = None, + overwrite_input: Optional[bool] = None, + method: str = 'linear', + keepdims: Optional[bool] = False, +) -> Array: + """ + Compute the q-th percentile of the data along the specified axis, while ignoring nan values. + + Returns the q-th percentile(s) of the array elements, while ignoring nan values. + + Parameters + ---------- + a : array_like, Quantity + Input array or Quantity. + q : array_like, Quantity + Percentile or sequence of percentiles to compute, which must be between 0 and 100 inclusive. + out : array_like, Quantity, optional + Alternative output array in which to place the result. + It must have the same shape and buffer length as the expected output but the type will be cast if necessary. + overwrite_input : bool, optional + If True, then allow the input array a to be modified by intermediate calculations, to save memory. + method : str, optional + This parameter specifies the method to use for estimating the + percentile. There are many different methods, some unique to NumPy. + See the notes for explanation. The options sorted by their R type + as summarized in the H&F paper [1]_ are: + + 1. 'inverted_cdf' + 2. 'averaged_inverted_cdf' + 3. 'closest_observation' + 4. 'interpolated_inverted_cdf' + 5. 'hazen' + 6. 'weibull' + 7. 'linear' (default) + 8. 'median_unbiased' + 9. 'normal_unbiased' + + The first three methods are discontinuous. NumPy further defines the + following discontinuous variations of the default 'linear' (7.) option: + + * 'lower' + * 'higher', + * 'midpoint' + * 'nearest' + keepdims : bool, optional + If this is set to True, the axes which are reduced are left in the result as dimensions with size one. + + Returns + ------- + out : jax.Array + Output array. + """ + return funcs_only_accept_unitless_binary(jnp.nanpercentile, a, q, axis=axis, out=out, overwrite_input=overwrite_input, + method=method, keepdims=keepdims) + + +@set_module_as('brainunit.math') +def quantile( + a: Union[Array, Quantity], + q: Union[Array, Quantity], + axis: Optional[Union[int, Tuple[int]]] = None, + out: Optional[Union[Quantity, jax.typing.ArrayLike]] = None, + overwrite_input: Optional[bool] = None, + method: str = 'linear', + keepdims: Optional[bool] = False, +) -> Array: + """ + Compute the q-th percentile of the data along the specified axis. + + Returns the q-th percentile(s) of the array elements. + + Parameters + ---------- + a : array_like, Quantity + Input array or Quantity. + q : array_like, Quantity + Percentile or sequence of percentiles to compute, which must be between 0 and 100 inclusive. + out : array_like, Quantity, optional + Alternative output array in which to place the result. + It must have the same shape and buffer length as the expected output but the type will be cast if necessary. + overwrite_input : bool, optional + If True, then allow the input array a to be modified by intermediate calculations, to save memory. + method : str, optional + This parameter specifies the method to use for estimating the + percentile. There are many different methods, some unique to NumPy. + See the notes for explanation. The options sorted by their R type + as summarized in the H&F paper [1]_ are: + + 1. 'inverted_cdf' + 2. 'averaged_inverted_cdf' + 3. 'closest_observation' + 4. 'interpolated_inverted_cdf' + 5. 'hazen' + 6. 'weibull' + 7. 'linear' (default) + 8. 'median_unbiased' + 9. 'normal_unbiased' + + The first three methods are discontinuous. NumPy further defines the + following discontinuous variations of the default 'linear' (7.) option: + + * 'lower' + * 'higher', + * 'midpoint' + * 'nearest' + keepdims : bool, optional + If this is set to True, the axes which are reduced are left in the result as dimensions with size one. + + Returns + ------- + out : jax.Array + Output array. + """ + return funcs_only_accept_unitless_binary(jnp.quantile, a, q, axis=axis, out=out, overwrite_input=overwrite_input, + method=method, keepdims=keepdims) + + +@set_module_as('brainunit.math') +def nanquantile( + a: Union[Array, Quantity], + q: Union[Array, Quantity], + axis: Optional[Union[int, Tuple[int]]] = None, + out: Optional[Union[Quantity, jax.typing.ArrayLike]] = None, + overwrite_input: Optional[bool] = None, + method: str = 'linear', + keepdims: Optional[bool] = False, +) -> Array: + """ + Compute the q-th percentile of the data along the specified axis, while ignoring nan values. + + Returns the q-th percentile(s) of the array elements, while ignoring nan values. + + Parameters + ---------- + a : array_like, Quantity + Input array or Quantity. + q : array_like, Quantity + Percentile or sequence of percentiles to compute, which must be between 0 and 100 inclusive. + out : array_like, Quantity, optional + Alternative output array in which to place the result. + It must have the same shape and buffer length as the expected output but the type will be cast if necessary. + overwrite_input : bool, optional + If True, then allow the input array a to be modified by intermediate calculations, to save memory. + method : str, optional + This parameter specifies the method to use for estimating the + percentile. There are many different methods, some unique to NumPy. + See the notes for explanation. The options sorted by their R type + as summarized in the H&F paper [1]_ are: + + 1. 'inverted_cdf' + 2. 'averaged_inverted_cdf' + 3. 'closest_observation' + 4. 'interpolated_inverted_cdf' + 5. 'hazen' + 6. 'weibull' + 7. 'linear' (default) + 8. 'median_unbiased' + 9. 'normal_unbiased' + + The first three methods are discontinuous. NumPy further defines the + following discontinuous variations of the default 'linear' (7.) option: + + * 'lower' + * 'higher', + * 'midpoint' + * 'nearest' + keepdims : bool, optional + If this is set to True, the axes which are reduced are left in the result as dimensions with size one. + + Returns + ------- + out : jax.Array + Output array. + """ + return funcs_only_accept_unitless_binary(jnp.nanquantile, a, q, axis=axis, out=out, overwrite_input=overwrite_input, + method=method, keepdims=keepdims) diff --git a/brainunit/math/_compat_numpy_funcs_bit_operation.py b/brainunit/math/_compat_numpy_funcs_bit_operation.py index 65d9881..7ccb523 100644 --- a/brainunit/math/_compat_numpy_funcs_bit_operation.py +++ b/brainunit/math/_compat_numpy_funcs_bit_operation.py @@ -19,11 +19,10 @@ import numpy as np from jax import Array -from brainunit._misc import set_module_as from .._base import Quantity +from .._misc import set_module_as __all__ = [ - # Elementwise bit operations (unary) 'bitwise_not', 'invert', @@ -37,7 +36,7 @@ def elementwise_bit_operation_unary(func, x, *args, **kwargs): if isinstance(x, Quantity): - raise ValueError(f'Expected integers, got {x}') + raise ValueError(f'Expected arrays, got {x}') elif isinstance(x, (jax.Array, np.ndarray)): return func(x, *args, **kwargs) else: @@ -49,11 +48,15 @@ def bitwise_not(x: Union[Quantity, jax.typing.ArrayLike]) -> Array: """ Compute the bit-wise NOT of an array, element-wise. - Args: - x: array_like + Parameters + ---------- + x: array_like, quantity + Input array. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return elementwise_bit_operation_unary(jnp.bitwise_not, x) @@ -63,11 +66,15 @@ def invert(x: Union[Quantity, jax.typing.ArrayLike]) -> Array: """ Compute bit-wise inversion, or bit-wise NOT, element-wise. - Args: - x: array_like + Parameters + ---------- + x: array_like, quantity + Input array. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return elementwise_bit_operation_unary(jnp.invert, x) @@ -78,7 +85,7 @@ def invert(x: Union[Quantity, jax.typing.ArrayLike]) -> Array: def elementwise_bit_operation_binary(func, x, y, *args, **kwargs): if isinstance(x, Quantity) or isinstance(y, Quantity): - raise ValueError(f'Expected integers, got {x} and {y}') + raise ValueError(f'Expected array, got {x} and {y}') elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray, int, float)): return func(x, y, *args, **kwargs) else: @@ -93,12 +100,17 @@ def bitwise_and( """ Compute the bit-wise AND of two arrays element-wise. - Args: - x: array_like - y: array_like - - Returns: - jax.Array: an array + Parameters + ---------- + x: array_like, quantity + Input array. + y: array_like, quantity + Input array. + + Returns + ------- + out : jax.Array + Output array. """ return elementwise_bit_operation_binary(jnp.bitwise_and, x, y) @@ -111,12 +123,17 @@ def bitwise_or( """ Compute the bit-wise OR of two arrays element-wise. - Args: - x: array_like - y: array_like - - Returns: - jax.Array: an array + Parameters + ---------- + x: array_like, quantity + Input array. + y: array_like, quantity + Input array. + + Returns + ------- + out : jax.Array + Output array. """ return elementwise_bit_operation_binary(jnp.bitwise_or, x, y) @@ -129,12 +146,17 @@ def bitwise_xor( """ Compute the bit-wise XOR of two arrays element-wise. - Args: - x: array_like - y: array_like - - Returns: - jax.Array: an array + Parameters + ---------- + x: array_like, quantity + Input array. + y: array_like, quantity + Input array. + + Returns + ------- + out : jax.Array + Output array. """ return elementwise_bit_operation_binary(jnp.bitwise_xor, x, y) @@ -147,12 +169,17 @@ def left_shift( """ Shift the bits of an integer to the left. - Args: - x: array_like - y: array_like - - Returns: - jax.Array: an array + Parameters + ---------- + x: array_like, quantity + Input array. + y: array_like, quantity + Input array. + + Returns + ------- + out : jax.Array + Output array. """ return elementwise_bit_operation_binary(jnp.left_shift, x, y) @@ -165,11 +192,16 @@ def right_shift( """ Shift the bits of an integer to the right. - Args: - x: array_like - y: array_like - - Returns: - jax.Array: an array + Parameters + ---------- + x: array_like, quantity + Input array. + y: array_like, quantity + Input array. + + Returns + ------- + out : jax.Array + Output array. """ return elementwise_bit_operation_binary(jnp.right_shift, x, y) diff --git a/brainunit/math/_compat_numpy_funcs_change_unit.py b/brainunit/math/_compat_numpy_funcs_change_unit.py index eb431e4..59bab36 100644 --- a/brainunit/math/_compat_numpy_funcs_change_unit.py +++ b/brainunit/math/_compat_numpy_funcs_change_unit.py @@ -13,23 +13,22 @@ # limitations under the License. # ============================================================================== from collections.abc import Sequence -from typing import (Union, Optional) +from typing import (Union, Optional, Tuple, Any) import jax import jax.numpy as jnp import numpy as np -from brainunit._misc import set_module_as from ._compat_numpy_get_attribute import isscalar -from .._base import (DIMENSIONLESS, - Quantity, ) +from .._base import (DIMENSIONLESS, Quantity, ) from .._base import _return_check_unitless +from .._misc import set_module_as __all__ = [ # math funcs change unit (unary) 'reciprocal', 'prod', 'product', 'nancumprod', 'nanprod', 'cumprod', - 'cumproduct', 'var', 'nanvar', 'cbrt', 'square', 'frexp', 'sqrt', + 'cumproduct', 'var', 'nanvar', 'cbrt', 'square', 'sqrt', # math funcs change unit (binary) 'multiply', 'divide', 'power', 'cross', 'ldexp', @@ -56,13 +55,22 @@ def reciprocal( x: Union[Quantity, jax.typing.ArrayLike] ) -> Union[Quantity, jax.Array]: """ - Return the reciprocal of the argument. + Return the reciprocal of the argument, element-wise. - Args: - x: array_like, Quantity + Calculates ``1/x``. - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + Parameters + ---------- + x : array_like, Quantity + Input array. + + Returns + ------- + y : ndarray, Quantity + Return array. + This is a scalar if `x` is a scalar. + + This is a Quantity if the reciprocal of the unit of `x` is not dimensionless. """ return funcs_change_unit_unary(jnp.reciprocal, lambda x: x ** -1, @@ -71,68 +79,133 @@ def reciprocal( @set_module_as('brainunit.math') def var( - x: Union[Quantity, jax.typing.ArrayLike], + a: Union[Quantity, jax.typing.ArrayLike], axis: Optional[Union[int, Sequence[int]]] = None, + dtype: Optional[Any] = None, ddof: int = 0, - keepdims: bool = False + keepdims: bool = False, + *, + where: Optional[jax.typing.ArrayLike] = None ) -> Union[Quantity, jax.Array]: """ Compute the variance along the specified axis. - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the square of the unit of `x`, else an array. + Returns the variance of the array elements, a measure of the spread of a + distribution. The variance is computed for the flattened array by + default, otherwise over the specified axis. + + Parameters + ---------- + a : array_like, Quantity + Array containing numbers whose variance is desired. If `a` is not an + array, a conversion is attempted. + axis : None or int or tuple of ints, optional + Axis or axes along which the variance is computed. The default is to + compute the variance of the flattened array. + + If this is a tuple of ints, a variance is performed over multiple axes, + instead of a single axis or all the axes as before. + dtype : data-type, optional + Type to use in computing the variance. For arrays of integer type + the default is `float64`; for arrays of float types it is the same as + the array type. + ddof : int, optional + "Delta Degrees of Freedom": the divisor used in the calculation is + ``N - ddof``, where ``N`` represents the number of elements. By + default `ddof` is zero. + keepdims : bool, optional + If this is set to True, the axes which are reduced are left + in the result as dimensions with size one. With this option, + the result will broadcast correctly against the input array. + + If the default value is passed, then `keepdims` will not be + passed through to the `var` method of sub-classes of + `ndarray`, however any non-default value will be. If the + sub-class' method does not implement `keepdims` any + exceptions will be raised. + where : array_like of bool, optional + Elements to include in the variance. See `~numpy.ufunc.reduce` for + details. + + Returns + ------- + variance : ndarray, quantity, see dtype parameter above + If ``out=None``, returns a new array containing the variance; + otherwise, a reference to the output array is returned. + + This is a Quantity if the square of the unit of `a` is not dimensionless. """ return funcs_change_unit_unary(jnp.var, lambda x: x ** 2, - x, + a, axis=axis, + dtype=dtype, ddof=ddof, - keepdims=keepdims) + keepdims=keepdims, + where=where) @set_module_as('brainunit.math') def nanvar( x: Union[Quantity, jax.typing.ArrayLike], axis: Optional[Union[int, Sequence[int]]] = None, + dtype: Optional[Any] = None, ddof: int = 0, - keepdims: bool = False + keepdims: bool = False, + where: Optional[jax.typing.ArrayLike] = None ) -> Union[Quantity, jax.Array]: """ - Compute the variance along the specified axis, ignoring NaNs. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the square of the unit of `x`, else an array. + Compute the variance along the specified axis, while ignoring NaNs. + + Returns the variance of the array elements, a measure of the spread of + a distribution. The variance is computed for the flattened array by + default, otherwise over the specified axis. + + For all-NaN slices or slices with zero degrees of freedom, NaN is + returned and a `RuntimeWarning` is raised. + + Parameters + ---------- + x : array_like, Quantity + Array containing numbers whose variance is desired. If `a` is not an + array, a conversion is attempted. + axis : {int, tuple of int, None}, optional + Axis or axes along which the variance is computed. The default is to compute + the variance of the flattened array. + dtype : data-type, optional + Type to use in computing the variance. For arrays of integer type + the default is `float64`; for arrays of float types it is the same as + the array type. + ddof : int, optional + "Delta Degrees of Freedom": the divisor used in the calculation is + ``N - ddof``, where ``N`` represents the number of non-NaN + elements. By default `ddof` is zero. + keepdims : bool, optional + If this is set to True, the axes which are reduced are left + in the result as dimensions with size one. With this option, + the result will broadcast correctly against the original `a`. + where : array_like of bool, optional + Elements to include in the variance. See `~numpy.ufunc.reduce` for + details. + + Returns + ------- + variance : ndarray, quantity, see dtype parameter above + If `out` is None, return a new array containing the variance, + otherwise return a reference to the output array. If ddof is >= the + number of non-NaN elements in a slice or the slice contains only + NaNs, then the result for that slice is NaN. + + This is a Quantity if the square of the unit of `a` is not dimensionless. """ return funcs_change_unit_unary(jnp.nanvar, lambda x: x ** 2, x, axis=axis, + dtype=dtype, ddof=ddof, - keepdims=keepdims) - - -@set_module_as('brainunit.math') -def frexp( - x: Union[Quantity, jax.typing.ArrayLike] -) -> Union[Quantity, jax.Array]: - """ - Decompose a floating-point number into its mantissa and exponent. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Tuple of Quantity if the final unit is the product of the unit of `x` and 2 raised to the power of the exponent, else a tuple of arrays. - """ - return funcs_change_unit_unary(jnp.frexp, - lambda x: x * 2 ** -1, - x) + keepdims=keepdims, + where=where) @set_module_as('brainunit.math') @@ -142,11 +215,23 @@ def sqrt( """ Compute the square root of each element. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + The values whose square-roots are required. + + Returns + ------- + y : ndarray, quantity + An array of the same shape as `x`, containing the positive + square-root of each element in `x`. If any element in `x` is + complex, a complex array is returned (and the square-roots of + negative reals are calculated). If all of the elements in `x` + are real, so is `y`, with negative elements returning ``nan``. + If `out` was provided, `y` is a reference to it. + This is a scalar if `x` is a scalar. - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the square root of the unit of `x`, else an array. + This is a Quantity if the square root of the unit of `x` is not dimensionless. """ return funcs_change_unit_unary(jnp.sqrt, lambda x: x ** 0.5, @@ -160,11 +245,20 @@ def cbrt( """ Compute the cube root of each element. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + The values whose cube-roots are required. - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the cube root of the unit of `x`, else an array. + Returns + ------- + y : ndarray, quantity + An array of the same shape as `x`, containing the cube + cube-root of each element in `x`. + If `out` was provided, `y` is a reference to it. + This is a scalar if `x` is a scalar. + + This is a Quantity if the cube root of the unit of `x` is not dimensionless. """ return funcs_change_unit_unary(jnp.cbrt, lambda x: x ** (1 / 3), @@ -178,11 +272,18 @@ def square( """ Compute the square of each element. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input data. + + Returns + ------- + out : ndarray, quantity or scalar + Element-wise `x*x`, of the same shape and dtype as `x`. + This is a scalar if `x` is a scalar. - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the square of the unit of `x`, else an array. + This is a Quantity if the square of the unit of `x` is not dimensionless. """ return funcs_change_unit_unary(jnp.square, lambda x: x ** 2, @@ -193,7 +294,6 @@ def square( def prod(x: Union[Quantity, jax.typing.ArrayLike], axis: Optional[int] = None, dtype: Optional[jax.typing.DTypeLike] = None, - out: None = None, keepdims: Optional[bool] = False, initial: Union[Quantity, jax.typing.ArrayLike] = None, where: Union[Quantity, jax.typing.ArrayLike] = None, @@ -201,24 +301,54 @@ def prod(x: Union[Quantity, jax.typing.ArrayLike], """ Return the product of array elements over a given axis. - Args: - x: array_like, Quantity - axis: int, optional - dtype: dtype, optional - out: array, optional - keepdims: bool, optional - initial: array_like, Quantity, optional - where: array_like, Quantity, optional - promote_integers: bool, optional - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + Parameters + ---------- + x : array_like, Quantity + Input data. + axis : None or int or tuple of ints, optional + Axis or axes along which a product is performed. The default, + axis=None, will calculate the product of all the elements in the + input array. If axis is negative it counts from the last to the + first axis. + + If axis is a tuple of ints, a product is performed on all of the + axes specified in the tuple instead of a single axis or all the + axes as before. + dtype : dtype, optional + The type of the returned array, as well as of the accumulator in + which the elements are multiplied. The dtype of `a` is used by + default unless `a` has an integer dtype of less precision than the + default platform integer. In that case, if `a` is signed then the + platform integer is used while if `a` is unsigned then an unsigned + integer of the same precision as the platform integer is used. + keepdims : bool, optional + If this is set to True, the axes which are reduced are left in the + result as dimensions with size one. With this option, the result + will broadcast correctly against the input array. + + If the default value is passed, then `keepdims` will not be + passed through to the `prod` method of sub-classes of + `ndarray`, however any non-default value will be. If the + sub-class' method does not implement `keepdims` any + exceptions will be raised. + initial : scalar, optional + The starting value for this product. See `~numpy.ufunc.reduce` for details. + where : array_like of bool, optional + Elements to include in the product. See `~numpy.ufunc.reduce` for details. + + Returns + ------- + product_along_axis : ndarray, see `dtype` parameter above. + An array shaped as `a` but with the specified axis removed. + Returns a reference to `out` if specified. + + This is a Quantity if the product of the unit of `x` is not dimensionless. """ if isinstance(x, Quantity): - return x.prod(axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where, + return x.prod(axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where, promote_integers=promote_integers) else: - return jnp.prod(x, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where, + return jnp.prod(x, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where, promote_integers=promote_integers) @@ -226,29 +356,63 @@ def prod(x: Union[Quantity, jax.typing.ArrayLike], def nanprod(x: Union[Quantity, jax.typing.ArrayLike], axis: Optional[int] = None, dtype: Optional[jax.typing.DTypeLike] = None, - out: None = None, keepdims: bool = False, initial: Union[Quantity, jax.typing.ArrayLike] = None, where: Union[Quantity, jax.typing.ArrayLike] = None): """ Return the product of array elements over a given axis treating Not a Numbers (NaNs) as one. - Args: - x: array_like, Quantity - axis: int, optional - dtype: dtype, optional - out: array, optional - keepdims: bool, optional - initial: array_like, Quantity, optional - where: array_like, Quantity, optional - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + Parameters + ---------- + x : array_like, Quantity + Input data. + axis : None or int or tuple of ints, optional + Axis or axes along which a product is performed. The default, + axis=None, will calculate the product of all the elements in the + input array. If axis is negative it counts from the last to the + first axis. + + If axis is a tuple of ints, a product is performed on all of the + axes specified in the tuple instead of a single axis or all the + axes as before. + dtype : dtype, optional + The type of the returned array, as well as of the accumulator in + which the elements are multiplied. The dtype of `a` is used by + default unless `a` has an integer dtype of less precision than the + default platform integer. In that case, if `a` is signed then the + platform integer is used while if `a` is unsigned then an unsigned + integer of the same precision as the platform integer is used. + out : ndarray, optional + Alternative output array in which to place the result. It must have + the same shape as the expected output, but the type of the output + values will be cast if necessary. + keepdims : bool, optional + If this is set to True, the axes which are reduced are left in the + result as dimensions with size one. With this option, the result + will broadcast correctly against the input array. + + If the default value is passed, then `keepdims` will not be + passed through to the `prod` method of sub-classes of + `ndarray`, however any non-default value will be. If the + sub-class' method does not implement `keepdims` any + exceptions will be raised. + initial : scalar, optional + The starting value for this product. See `~numpy.ufunc.reduce` for details. + where : array_like of bool, optional + Elements to include in the product. See `~numpy.ufunc.reduce` for details. + + Returns + ------- + product_along_axis : ndarray, see `dtype` parameter above. + An array shaped as `a` but with the specified axis removed. + Returns a reference to `out` if specified. + + This is a Quantity if the product of the unit of `x` is not dimensionless. """ if isinstance(x, Quantity): - return x.nanprod(axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where) + return x.nanprod(axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where) else: - return jnp.nanprod(x, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where) + return jnp.nanprod(x, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where) product = prod @@ -257,47 +421,81 @@ def nanprod(x: Union[Quantity, jax.typing.ArrayLike], @set_module_as('brainunit.math') def cumprod(x: Union[Quantity, jax.typing.ArrayLike], axis: Optional[int] = None, - dtype: Optional[jax.typing.DTypeLike] = None, - out: None = None) -> Union[Quantity, jax.typing.ArrayLike]: + dtype: Optional[jax.typing.DTypeLike] = None) -> Union[Quantity, jax.typing.ArrayLike]: """ Return the cumulative product of elements along a given axis. - Args: - x: array_like, Quantity - axis: int, optional - dtype: dtype, optional - out: array, optional - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + Parameters + ---------- + x : array_like, Quantity + Input array. + axis : int, optional + Axis along which the cumulative product is computed. By default + the input is flattened. + dtype : dtype, optional + Type of the returned array, as well as of the accumulator in which + the elements are multiplied. If *dtype* is not specified, it + defaults to the dtype of `a`, unless `a` has an integer dtype with + a precision less than that of the default platform integer. In + that case, the default platform integer is used instead. + out : ndarray, optional + Alternative output array in which to place the result. It must + have the same shape and buffer length as the expected output + but the type of the resulting values will be cast if necessary. + + Returns + ------- + cumprod : ndarray, quantity + A new array holding the result is returned unless `out` is + specified, in which case a reference to out is returned. + + This is a Quantity if the product of the unit of `x` is not dimensionless. """ if isinstance(x, Quantity): - return x.cumprod(axis=axis, dtype=dtype, out=out) + return x.cumprod(axis=axis, dtype=dtype) else: - return jnp.cumprod(x, axis=axis, dtype=dtype, out=out) + return jnp.cumprod(x, axis=axis, dtype=dtype) @set_module_as('brainunit.math') -def nancumprod(x: Union[Quantity, jax.typing.ArrayLike], - axis: Optional[int] = None, - dtype: Optional[jax.typing.DTypeLike] = None, - out: None = None) -> Union[Quantity, jax.typing.ArrayLike]: +def nancumprod( + x: Union[Quantity, jax.typing.ArrayLike], + axis: Optional[int] = None, + dtype: Optional[jax.typing.DTypeLike] = None +) -> Union[Quantity, jax.typing.ArrayLike]: """ Return the cumulative product of elements along a given axis treating Not a Numbers (NaNs) as one. - Args: - x: array_like, Quantity - axis: int, optional - dtype: dtype, optional - out: array, optional - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + Parameters + ---------- + x : array_like, Quantity + Input array. + axis : int, optional + Axis along which the cumulative product is computed. By default + the input is flattened. + dtype : dtype, optional + Type of the returned array, as well as of the accumulator in which + the elements are multiplied. If *dtype* is not specified, it + defaults to the dtype of `a`, unless `a` has an integer dtype with + a precision less than that of the default platform integer. In + that case, the default platform integer is used instead. + out : ndarray, optional + Alternative output array in which to place the result. It must + have the same shape and buffer length as the expected output + but the type of the resulting values will be cast if necessary. + + Returns + ------- + cumprod : ndarray, quantity + A new array holding the result is returned unless `out` is + specified, in which case a reference to out is returned. + + This is a Quantity if the product of the unit of `x` is not dimensionless. """ if isinstance(x, Quantity): - return x.nancumprod(axis=axis, dtype=dtype, out=out) + return x.nancumprod(axis=axis, dtype=dtype) else: - return jnp.nancumprod(x, axis=axis, dtype=dtype, out=out) + return jnp.nancumprod(x, axis=axis, dtype=dtype) cumproduct = cumprod @@ -332,12 +530,19 @@ def multiply( """ Multiply arguments element-wise. - Args: - x: array_like, Quantity - y: array_like, Quantity + Parameters + ---------- + x, y : array_like, Quantity + Input arrays to be multiplied. + If ``x.shape != y.shape``, they must be broadcastable to a common - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. + Returns + ------- + out : ndarray, Quantity + The product of `x` and `y`, element-wise. + This is a scalar if both `x` and `y` are scalars. + + This is a Quantity if the product of the unit of `x` and the unit of `y` is not dimensionless. """ return funcs_change_unit_binary(jnp.multiply, lambda x, y: x * y, @@ -352,11 +557,19 @@ def divide( """ Divide arguments element-wise. - Args: - x: array_like, Quantity + Parameters + ---------- + x, y : array_like, Quantity + Input arrays to be divided. + If ``x.shape != y.shape``, they must be broadcastable to a common + + Returns + ------- + out : ndarray, Quantity + The quotient of `x` and `y`, element-wise. + This is a scalar if both `x` and `y` are scalars. - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. + This is a Quantity if the product of the unit of `x` and the unit of `y` is not dimensionless. """ return funcs_change_unit_binary(jnp.divide, lambda x, y: x / y, @@ -365,38 +578,82 @@ def divide( @set_module_as('brainunit.math') def cross( - x: Union[Quantity, jax.typing.ArrayLike], - y: Union[Quantity, jax.typing.ArrayLike] + a: Union[Quantity, jax.typing.ArrayLike], + b: Union[Quantity, jax.typing.ArrayLike], + axisa: int = -1, + axisb: int = -1, + axisc: int = -1, + axis: Optional[int] = None ) -> Union[Quantity, jax.typing.ArrayLike]: """ Return the cross product of two (arrays of) vectors. - Args: - x: array_like, Quantity - y: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. + The cross product of `a` and `b` in :math:`R^3` is a vector perpendicular + to both `a` and `b`. If `a` and `b` are arrays of vectors, the vectors + are defined by the last axis of `a` and `b` by default, and these axes + can have dimensions 2 or 3. Where the dimension of either `a` or `b` is + 2, the third component of the input vector is assumed to be zero and the + cross product calculated accordingly. In cases where both input vectors + have dimension 2, the z-component of the cross product is returned. + + Parameters + ---------- + a : array_like, Quantity + Components of the first vector(s). + b : array_like, Quantity + Components of the second vector(s). + axisa : int, optional + Axis of `a` that defines the vector(s). By default, the last axis. + axisb : int, optional + Axis of `b` that defines the vector(s). By default, the last axis. + axisc : int, optional + Axis of `c` containing the cross product vector(s). Ignored if + both input vectors have dimension 2, as the return is scalar. + By default, the last axis. + axis : int, optional + If defined, the axis of `a`, `b` and `c` that defines the vector(s) + and cross product(s). Overrides `axisa`, `axisb` and `axisc`. + + Returns + ------- + c : ndarray, Quantity + Vector cross product(s). + + This is a Quantity if the cross product of the unit of `a` and the unit of `b` is not dimensionless. """ return funcs_change_unit_binary(jnp.cross, lambda x, y: x * y, - x, y) + a, b, + axisa=axisa, axisb=axisb, axisc=axisc, axis=axis) @set_module_as('brainunit.math') def ldexp( x: Union[Quantity, jax.typing.ArrayLike], - y: Union[Quantity, jax.typing.ArrayLike] + y: jax.typing.ArrayLike ) -> Union[Quantity, jax.typing.ArrayLike]: """ - Return x1 * 2**x2, element-wise. + Returns x * 2**y, element-wise. - Args: - x: array_like, Quantity - y: array_like, Quantity + The mantissas `x` and twos exponents `y` are used to construct + floating point numbers ``x * 2**y``. - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and 2 raised to the power of the unit of `y`, else an array. + Parameters + ---------- + x : array_like, Quantity + Array of multipliers. + y : array_like, int + Array of twos exponents. + If ``x.shape != y.shape``, they must be broadcastable to a common + shape (which becomes the shape of the output). + + Returns + ------- + out : ndarray, quantity or scalar + The result of ``x * 2**y``. + This is a scalar if both `x` and `y` are scalars. + + This is a Quantity if the product of the square of the unit of `x` and the unit of `y` is not dimensionless. """ return funcs_change_unit_binary(jnp.ldexp, lambda x, y: x * 2 ** y, @@ -411,12 +668,22 @@ def true_divide( """ Returns a true division of the inputs, element-wise. - Args: - x: array_like, Quantity - y: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Dividend array. + y : array_like, Quantity + Divisor array. + If ``x.shape != y.shape``, they must be broadcastable to a common + shape (which becomes the shape of the output). + + Returns + ------- + out : ndarray, quantity or scalar + The quotient ``x/y``, element-wise. + This is a scalar if both `x` and `y` are scalars. - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. + This is a Quantity if the division of the unit of `x` and the unit of `y` is not dimensionless. """ return funcs_change_unit_binary(jnp.true_divide, lambda x, y: x / y, @@ -427,16 +694,30 @@ def true_divide( def divmod( x: Union[Quantity, jax.typing.ArrayLike], y: Union[Quantity, jax.typing.ArrayLike] -) -> Union[Quantity, jax.typing.ArrayLike]: +) -> Tuple[Union[Quantity, jax.typing.ArrayLike], Union[Quantity, jax.typing.ArrayLike]]: """ Return element-wise quotient and remainder simultaneously. - - Args: - x: array_like, Quantity - y: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. + ``bu.divmod(x, y)`` is equivalent to ``(x // y, x % y)``, but faster + because it avoids redundant work. It is used to implement the Python + built-in function ``divmod`` on NumPy arrays. + + Parameters + ---------- + x : array_like, Quantity + Dividend array. + y : array_like, Quantity + Divisor array. + If ``x.shape != y.shape``, they must be broadcastable to a common + shape (which becomes the shape of the output). + + Returns + ------- + out1 : ndarray, quantity or scalar + Element-wise quotient resulting from floor division. + This is a scalar if both `x` and `y` are scalars. + out2 : ndarray, quantity or scalar + Element-wise remainder from floor division. + This is a scalar if both `x` and `y` are scalars. """ return funcs_change_unit_binary(jnp.divmod, lambda x, y: x / y, @@ -445,22 +726,59 @@ def divmod( @set_module_as('brainunit.math') def convolve( - x: Union[Quantity, jax.typing.ArrayLike], - y: Union[Quantity, jax.typing.ArrayLike] + a: Union[Quantity, jax.typing.ArrayLike], + v: Union[Quantity, jax.typing.ArrayLike], + mode: str = 'full', + *, + precision: Any = None, + preferred_element_type: Optional[jax.typing.DTypeLike] = None ) -> Union[Quantity, jax.typing.ArrayLike]: """ Returns the discrete, linear convolution of two one-dimensional sequences. - Args: - x: array_like, Quantity - y: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. + The convolution operator is often seen in signal processing, where it + models the effect of a linear time-invariant system on a signal [1]_. In + probability theory, the sum of two independent random variables is + distributed according to the convolution of their individual + distributions. + + If `v` is longer than `a`, the arrays are swapped before computation. + + Parameters + ---------- + a : (N,) array_like, Quantity + First one-dimensional input array. + v : (M,) array_like, Quantity + Second one-dimensional input array. + mode : {'full', 'valid', 'same'}, optional + 'full': + By default, mode is 'full'. This returns the convolution + at each point of overlap, with an output shape of (N+M-1,). At + the end-points of the convolution, the signals do not overlap + completely, and boundary effects may be seen. + 'same': + Mode 'same' returns output of length ``max(M, N)``. Boundary + effects are still visible. + 'valid': + Mode 'valid' returns output of length + ``max(M, N) - min(M, N) + 1``. The convolution product is only given + for points where the signals overlap completely. Values outside + the signal boundary have no effect. + + Returns + ------- + out : ndarray, quantity or scalar + Discrete, linear convolution of `a` and `v`. + This is a scalar if both `a` and `v` are scalars. + + This is a Quantity if the convolution of the unit of `a` and the unit of `v` is not dimensionless. """ return funcs_change_unit_binary(jnp.convolve, lambda x, y: x * y, - x, y) + a, v, + mode=mode, + precision=precision, + preferred_element_type=preferred_element_type) @set_module_as('brainunit.math') @@ -469,12 +787,32 @@ def power(x: Union[Quantity, jax.typing.ArrayLike], """ First array elements raised to powers from second array, element-wise. - Args: - x: array_like, Quantity - y: array_like, Quantity + Raise each base in `x` to the positionally-corresponding power in + `y`. `x` and `y` must be broadcastable to the same shape. + + An integer type raised to a negative integer power will raise a + ``ValueError``. + + Negative values raised to a non-integral value will return ``nan``. + To get complex results, cast the input to complex, or specify the + ``dtype`` to be ``complex`` (see the example below). + + Parameters + ---------- + x : array_like, Quantity + The bases. + y : array_like, Quantity + The exponents. + If ``x.shape != y.shape``, they must be broadcastable to a common + shape (which becomes the shape of the output). - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. + Returns + ------- + out : ndarray, quantity or scalar + The bases in `x` raised to the exponents in `y`. + This is a scalar if both `x` and `y` are scalars. + + This is a Quantity if the unit of `x` raised to the unit of `y` is not dimensionless. """ if isinstance(x, Quantity) and isinstance(y, Quantity): return _return_check_unitless(Quantity(jnp.power(x.value, y.value), dim=x.dim ** y.dim)) @@ -493,13 +831,24 @@ def floor_divide(x: Union[Quantity, jax.typing.ArrayLike], y: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: """ Return the largest integer smaller or equal to the division of the inputs. - - Args: - x: array_like, Quantity - y: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. + It is equivalent to the Python ``//`` operator and pairs with the + Python ``%`` (`remainder`), function so that ``a = a % b + b * (a // b)`` + up to roundoff. + + Parameters + ---------- + x : array_like, Quantity + Numerator. + y : array_like, Quantity + Denominator. + If ``x.shape != y.shape``, they must be broadcastable to a common + shape (which becomes the shape of the output). + + Returns + ------- + out : ndarray + out = floor(`x`/`y`) + This is a scalar if both `x` and `y` are scalars. """ if isinstance(x, Quantity) and isinstance(y, Quantity): return _return_check_unitless(Quantity(jnp.floor_divide(x.value, y.value), dim=x.dim / y.dim)) @@ -519,12 +868,33 @@ def float_power(x: Union[Quantity, jax.typing.ArrayLike], """ First array elements raised to powers from second array, element-wise. - Args: - x: array_like, Quantity - y: array_like - - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. + Raise each base in `x` to the positionally-corresponding power in `y`. + `x` and `y` must be broadcastable to the same shape. This differs from + the power function in that integers, float16, and float32 are promoted to + floats with a minimum precision of float64 so that the result is always + inexact. The intent is that the function will return a usable result for + negative powers and seldom overflow for positive powers. + + Negative values raised to a non-integral value will return ``nan``. + To get complex results, cast the input to complex, or specify the + ``dtype`` to be ``complex`` (see the example below). + + Parameters + ---------- + x : array_like, Quantity + The bases. + y : array_like + The exponents. + If ``x.shape != y.shape``, they must be broadcastable to a common + shape (which becomes the shape of the output). + + Returns + ------- + out : ndarray + The bases in `x` raised to the exponents in `y`. + This is a scalar if both `x` and `y` are scalars. + + This is a Quantity if the unit of `x` raised to the unit of `y` is not dimensionless. """ if isinstance(y, Quantity): assert isscalar(y), f'{jnp.float_power.__name__} only supports scalar exponent' @@ -540,14 +910,29 @@ def float_power(x: Union[Quantity, jax.typing.ArrayLike], def remainder(x: Union[Quantity, jax.typing.ArrayLike], y: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: """ - Return element-wise remainder of division. - - Args: - x: array_like, Quantity - y: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the remainder of the unit of `x` and the unit of `y`, else an array. + Returns the element-wise remainder of division. + + Computes the remainder complementary to the `floor_divide` function. It is + equivalent to the Python modulus operator``x1 % x2`` and has the same sign + as the divisor `x2`. The MATLAB function equivalent to ``np.remainder`` + is ``mod``. + + Parameters + ---------- + x : array_like, Quantity + Dividend array. + y : array_like, Quantity + Divisor array. + If ``x1.shape != x2.shape``, they must be broadcastable to a common + shape (which becomes the shape of the output). + + Returns + ------- + out : ndarray, Quantity + The element-wise remainder of the quotient ``floor_divide(x1, x2)``. + This is a scalar if both `x1` and `x2` are scalars. + + This is a Quantity if division of `x1` by `x2` is not dimensionless. """ if isinstance(x, Quantity) and isinstance(y, Quantity): return _return_check_unitless(Quantity(jnp.remainder(x.value, y.value), dim=x.dim / y.dim)) diff --git a/brainunit/math/_compat_numpy_funcs_indexing.py b/brainunit/math/_compat_numpy_funcs_indexing.py index bb84218..9940574 100644 --- a/brainunit/math/_compat_numpy_funcs_indexing.py +++ b/brainunit/math/_compat_numpy_funcs_indexing.py @@ -18,13 +18,10 @@ import jax.numpy as jnp import numpy as np -from brainunit._misc import set_module_as -from .._base import (Quantity, - fail_for_dimension_mismatch, - is_unitless, ) +from .._base import (Quantity, fail_for_dimension_mismatch, is_unitless, ) +from .._misc import set_module_as __all__ = [ - # indexing funcs 'where', 'tril_indices', 'tril_indices_from', 'triu_indices', 'triu_indices_from', 'take', 'select', @@ -34,9 +31,42 @@ # indexing funcs # -------------- @set_module_as('brainunit.math') -def where(condition: Union[bool, jax.typing.ArrayLike], - *args: Union[Quantity, jax.typing.ArrayLike], - **kwds) -> Union[Quantity, jax.Array]: +def where( + condition: Union[bool, jax.typing.ArrayLike], + *args: Union[Quantity, jax.typing.ArrayLike], + **kwds +) -> Union[Quantity, jax.Array]: + """ + where(condition, [x, y], /) + + Return elements chosen from `x` or `y` depending on `condition`. + + .. note:: + When only `condition` is provided, this function is a shorthand for + ``np.asarray(condition).nonzero()``. Using `nonzero` directly should be + preferred, as it behaves correctly for subclasses. The rest of this + documentation covers only the case where all three arguments are + provided. + + Parameters + ---------- + condition : array_like, bool, + Where True, yield `x`, otherwise yield `y`. + x, y : array_like, Quantity + Values from which to choose. `x`, `y` and `condition` need to be + broadcastable to some shape. + + Returns + ------- + out : ndarray + An array with elements from `x` where `condition` is True, and elements + from `y` elsewhere. + + See Also + -------- + choose + nonzero : The function that is called when x and y are omitted + """ condition = jnp.asarray(condition) if len(args) == 0: # nothing to do @@ -73,13 +103,19 @@ def where(condition: Union[bool, jax.typing.ArrayLike], tril_indices.__doc__ = """ Return the indices for the lower-triangle of an (n, m) array. - Args: - n: int - m: int - k: int, optional - - Returns: - tuple[jax.Array]: tuple[array] + Parameters + ---------- + n : int + The row dimension of the arrays for which the returned indices will be valid. + m : int + The column dimension of the arrays for which the returned indices will be valid. + k : int, optional + Diagonal above which to zero elements. k = 0 is the main diagonal, k < 0 subdiagonal and k > 0 superdiagonal. + + Returns + ------- + out : tuple[jax.Array] + tuple of arrays """ @@ -89,12 +125,17 @@ def tril_indices_from(arr: Union[Quantity, jax.typing.ArrayLike], """ Return the indices for the lower-triangle of an (n, m) array. - Args: - arr: array_like, Quantity - k: int, optional + Parameters + ---------- + arr : array_like, Quantity + The arrays for which the returned indices will be valid. + k : int, optional + Diagonal above which to zero elements. k = 0 is the main diagonal, k < 0 subdiagonal and k > 0 superdiagonal. - Returns: - tuple[jax.Array]: tuple[array] + Returns + ------- + out : tuple[jax.Array] + tuple of arrays """ if isinstance(arr, Quantity): return jnp.tril_indices_from(arr.value, k=k) @@ -106,13 +147,19 @@ def tril_indices_from(arr: Union[Quantity, jax.typing.ArrayLike], triu_indices.__doc__ = """ Return the indices for the upper-triangle of an (n, m) array. - Args: - n: int - m: int - k: int, optional - - Returns: - tuple[jax.Array]: tuple[array] + Parameters + ---------- + n : int + The row dimension of the arrays for which the returned indices will be valid. + m : int + The column dimension of the arrays for which the returned indices will be valid. + k : int, optional + Diagonal above which to zero elements. k = 0 is the main diagonal, k < 0 subdiagonal and k > 0 superdiagonal. + + Returns + ------- + out : tuple[jax.Array] + tuple of arrays """ @@ -122,12 +169,17 @@ def triu_indices_from(arr: Union[Quantity, jax.typing.ArrayLike], """ Return the indices for the upper-triangle of an (n, m) array. - Args: - arr: array_like, Quantity - k: int, optional + Parameters + ---------- + arr : array_like, Quantity + The arrays for which the returned indices will be valid. + k : int, optional + Diagonal above which to zero elements. k = 0 is the main diagonal, k < 0 subdiagonal and k > 0 superdiagonal. - Returns: - tuple[jax.Array]: tuple[array] + Returns + ------- + out : tuple[jax.Array] + tuple of arrays """ if isinstance(arr, Quantity): return jnp.triu_indices_from(arr.value, k=k) @@ -136,29 +188,109 @@ def triu_indices_from(arr: Union[Quantity, jax.typing.ArrayLike], @set_module_as('brainunit.math') -def take(a: Union[Quantity, jax.typing.ArrayLike], - indices: Union[Quantity, jax.typing.ArrayLike], - axis: Optional[int] = None, - mode: Optional[str] = None) -> Union[Quantity, jax.Array]: +def take( + a: Union[Quantity, jax.typing.ArrayLike], + indices: Union[Quantity, jax.typing.ArrayLike], + axis: Optional[int] = None, + mode: Optional[str] = None, + unique_indices: bool = False, + indices_are_sorted: bool = False, + fill_value: Optional[Union[Quantity, jax.typing.ArrayLike]] = None, +) -> Union[Quantity, jax.Array]: + ''' + Take elements from an array along an axis. + + When axis is not None, this function does the same thing as "fancy" + indexing (indexing arrays using arrays); however, it can be easier to use + if you need elements along a given axis. A call such as + ``np.take(arr, indices, axis=3)`` is equivalent to + ``arr[:,:,:,indices,...]``. + + Explained without fancy indexing, this is equivalent to the following use + of `ndindex`, which sets each of ``ii``, ``jj``, and ``kk`` to a tuple of + indices:: + + Ni, Nk = a.shape[:axis], a.shape[axis+1:] + Nj = indices.shape + for ii in ndindex(Ni): + for jj in ndindex(Nj): + for kk in ndindex(Nk): + out[ii + jj + kk] = a[ii + (indices[jj],) + kk] + + Parameters + ---------- + a : array_like (Ni..., M, Nk...) + The source array. + indices : array_like (Nj...) + The indices of the values to extract. + + Also allow scalars for indices. + axis : int, optional + The axis over which to select values. By default, the flattened + input array is used. + mode : string, default="fill" + Out-of-bounds indexing mode. The default mode="fill" returns invalid values + (e.g. NaN) for out-of bounds indices (see also ``fill_value`` below). + For more discussion of mode options, see :attr:`jax.numpy.ndarray.at`. + fill_value : optional + The fill value to return for out-of-bounds slices when mode is 'fill'. Ignored + otherwise. Defaults to NaN for inexact types, the largest negative value for + signed types, the largest positive value for unsigned types, and True for booleans. + unique_indices : bool, default=False + If True, the implementation will assume that the indices are unique, + which can result in more efficient execution on some backends. + indices_are_sorted : bool, default=False + If True, the implementation will assume that the indices are sorted in + ascending order, which can lead to more efficient execution on some backends. + + Returns + ------- + out : ndarray (Ni..., Nj..., Nk...) + The returned array has the same type as `a`. + ''' if isinstance(a, Quantity): - return a.take(indices, axis=axis, mode=mode) + return a.take(indices, axis=axis, mode=mode, unique_indices=unique_indices, + indices_are_sorted=indices_are_sorted, fill_value=fill_value) else: - return jnp.take(a, indices, axis=axis, mode=mode) + return jnp.take(a, indices, axis=axis, mode=mode, unique_indices=unique_indices, + indices_are_sorted=indices_are_sorted, fill_value=fill_value) @set_module_as('brainunit.math') -def select(condlist: list[Union[jax.typing.ArrayLike]], - choicelist: Union[Quantity, jax.typing.ArrayLike], - default: int = 0) -> Union[Quantity, jax.Array]: - from builtins import all as origin_all - from builtins import any as origin_any - if origin_all(isinstance(choice, Quantity) for choice in choicelist): - if origin_any(choice.dim != choicelist[0].dim for choice in choicelist): +def select( + condlist: list[Union[jax.typing.ArrayLike]], + choicelist: Union[Quantity, jax.typing.ArrayLike], + default: int = 0 +) -> Union[Quantity, jax.Array]: + ''' + Return an array drawn from elements in choicelist, depending on conditions. + + Parameters + ---------- + condlist : list of bool ndarrays + The list of conditions which determine from which array in `choicelist` + the output elements are taken. When multiple conditions are satisfied, + the first one encountered in `condlist` is used. + choicelist : list of ndarrays or Quantity + The list of arrays from which the output elements are taken. It has + to be of the same length as `condlist`. + default : scalar, optional + The element inserted in `output` when all conditions evaluate to False. + + Returns + ------- + output : ndarray, Quantity + The output at position m is the m-th element of the array in + `choicelist` where the m-th element of the corresponding array in + `condlist` is True. + ''' + if all(isinstance(choice, Quantity) for choice in choicelist): + if any(choice.dim != choicelist[0].dim for choice in choicelist): raise ValueError("All choices must have the same unit") else: return Quantity(jnp.select(condlist, [choice.value for choice in choicelist], default=default), dim=choicelist[0].dim) - elif origin_all(isinstance(choice, (jax.Array, np.ndarray)) for choice in choicelist): + elif all(isinstance(choice, (jax.Array, np.ndarray)) for choice in choicelist): return jnp.select(condlist, choicelist, default=default) else: raise ValueError(f"Unsupported types : {type(condlist)} and {type(choicelist)} for select") diff --git a/brainunit/math/_compat_numpy_funcs_keep_unit.py b/brainunit/math/_compat_numpy_funcs_keep_unit.py index 0e56ec4..51cb8a2 100644 --- a/brainunit/math/_compat_numpy_funcs_keep_unit.py +++ b/brainunit/math/_compat_numpy_funcs_keep_unit.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -from typing import (Union) +from typing import (Union, Optional, Sequence) import jax import jax.numpy as jnp import numpy as np -from brainunit._misc import set_module_as -from .._base import Quantity +from .._base import Quantity, fail_for_dimension_mismatch +from .._misc import set_module_as __all__ = [ # math funcs keep unit (unary) @@ -57,11 +57,15 @@ def real(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array] """ Return the real part of the complex argument. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array. - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + Returns + ------- + out : jax.Array, Quantity + Quantity if `x` is a Quantity, else an array. """ return funcs_keep_unit_unary(jnp.real, x) @@ -71,11 +75,15 @@ def imag(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array] """ Return the imaginary part of the complex argument. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array. - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + Returns + ------- + out : jax.Array, Quantity + Quantity if `x` is a Quantity, else an array. """ return funcs_keep_unit_unary(jnp.imag, x) @@ -85,11 +93,15 @@ def conj(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array] """ Return the complex conjugate of the argument. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array. - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + Returns + ------- + out : jax.Array, Quantity + Quantity if `x` is a Quantity, else an array. """ return funcs_keep_unit_unary(jnp.conj, x) @@ -99,11 +111,15 @@ def conjugate(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.A """ Return the complex conjugate of the argument. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array. - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + Returns + ------- + out : jax.Array, Quantity + Quantity if `x` is a Quantity, else an array. """ return funcs_keep_unit_unary(jnp.conjugate, x) @@ -113,11 +129,15 @@ def negative(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Ar """ Return the negative of the argument. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array. - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + Returns + ------- + out : jax.Array, Quantity + Quantity if `x` is a Quantity, else an array. """ return funcs_keep_unit_unary(jnp.negative, x) @@ -127,11 +147,15 @@ def positive(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Ar """ Return the positive of the argument. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array. - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + Returns + ------- + out : jax.Array, Quantity + Quantity if `x` is a Quantity, else an array. """ return funcs_keep_unit_unary(jnp.positive, x) @@ -141,11 +165,15 @@ def abs(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: """ Return the absolute value of the argument. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array. - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + Returns + ------- + out : jax.Array, Quantity + Quantity if `x` is a Quantity, else an array. """ return funcs_keep_unit_unary(jnp.abs, x) @@ -155,41 +183,68 @@ def round_(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Arra """ Round an array to the nearest integer. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array. - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + Returns + ------- + out : jax.Array, Quantity + Quantity if `x` is a Quantity, else an array. """ return funcs_keep_unit_unary(jnp.round_, x) @set_module_as('brainunit.math') -def around(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def around( + x: Union[Quantity, jax.typing.ArrayLike], + decimals: int = 0, + out: Optional[Union[Quantity, jax.typing.ArrayLike]] = None +) -> Union[Quantity, jax.Array]: """ Round an array to the nearest integer. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array. + decimals : int, optional + Number of decimal places to round to (default is 0). + out : array_like, Quantity, optional + Output array. - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + Returns + ------- + out : jax.Array, Quantity + Quantity if `x` is a Quantity, else an array. """ return funcs_keep_unit_unary(jnp.around, x) @set_module_as('brainunit.math') -def round(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def round( + x: Union[Quantity, jax.typing.ArrayLike], + decimals: int = 0, +) -> Union[Quantity, jax.Array]: """ Round an array to the nearest integer. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array. + decimals : int, optional + Number of decimal places to round to (default is 0). + out : array_like, Quantity, optional + Output array. - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + Returns + ------- + out : jax.Array, Quantity + Quantity if `x` is a Quantity, else an array. """ - return funcs_keep_unit_unary(jnp.round, x) + return funcs_keep_unit_unary(jnp.round, x, decimals=decimals) @set_module_as('brainunit.math') @@ -197,11 +252,15 @@ def rint(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array] """ Round an array to the nearest integer. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array. - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + Returns + ------- + out : jax.Array, Quantity + Quantity if `x` is a Quantity, else an array. """ return funcs_keep_unit_unary(jnp.rint, x) @@ -211,11 +270,15 @@ def floor(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array """ Return the floor of the argument. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array. - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + Returns + ------- + out : jax.Array, Quantity + Quantity if `x` is a Quantity, else an array. """ return funcs_keep_unit_unary(jnp.floor, x) @@ -225,11 +288,15 @@ def ceil(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array] """ Return the ceiling of the argument. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array. - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + Returns + ------- + out : jax.Array, Quantity + Quantity if `x` is a Quantity, else an array. """ return funcs_keep_unit_unary(jnp.ceil, x) @@ -239,97 +306,234 @@ def trunc(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array """ Return the truncated value of the argument. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array. - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + Returns + ------- + out : jax.Array, Quantity + Quantity if `x` is a Quantity, else an array. """ return funcs_keep_unit_unary(jnp.trunc, x) @set_module_as('brainunit.math') -def fix(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def fix( + x: Union[Quantity, jax.typing.ArrayLike], +) -> Union[Quantity, jax.Array]: """ Return the nearest integer towards zero. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array. - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + Returns + ------- + out : jax.Array, Quantity + Quantity if `x` is a Quantity, else an array. """ return funcs_keep_unit_unary(jnp.fix, x) @set_module_as('brainunit.math') -def sum(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def sum( + x: Union[Quantity, jax.typing.ArrayLike], + axis: Union[int, Sequence[int], None] = None, + dtype: Union[jax.typing.DTypeLike, None] = None, + keepdims: bool = False, + initial: Union[jax.typing.ArrayLike, None] = None, + where: Union[jax.typing.ArrayLike, None] = None, + promote_integers: bool = True +) -> Union[Quantity, jax.Array]: """ Return the sum of the array elements. - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. - """ - return funcs_keep_unit_unary(jnp.sum, x) + Parameters + ---------- + x : array_like, Quantity + Input array. + axis : None or int or tuple of ints, optional + Axis or axes along which a sum is performed. The default, + axis=None, will sum all of the elements of the input array. If + axis is negative it counts from the last to the first axis. + + If axis is a tuple of ints, a sum is performed on all of the axes + specified in the tuple instead of a single axis or all the axes as + before. + dtype : dtype, optional + The type of the returned array and of the accumulator in which the + elements are summed. The dtype of `a` is used by default unless `a` + has an integer dtype of less precision than the default platform + integer. In that case, if `a` is signed then the platform integer + is used while if `a` is unsigned then an unsigned integer of the + same precision as the platform integer is used. + keepdims : bool, optional + If this is set to True, the axes which are reduced are left + in the result as dimensions with size one. With this option, + the result will broadcast correctly against the input array. + + If the default value is passed, then `keepdims` will not be + passed through to the `sum` method of sub-classes of + `ndarray`, however any non-default value will be. If the + sub-class' method does not implement `keepdims` any + exceptions will be raised. + initial : scalar, optional + Starting value for the sum. See `~numpy.ufunc.reduce` for details. + where : array_like of bool, optional + Elements to include in the sum. See `~numpy.ufunc.reduce` for details. + promote_integers : bool, optional + If True, and if the accumulator is an integer type, then the + + Returns + ------- + out : jax.Array, Quantity + Quantity if `x` is a Quantity, else an array. + """ + return funcs_keep_unit_unary(jnp.sum, x, + axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, + where=where, promote_integers=promote_integers) @set_module_as('brainunit.math') -def nancumsum(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def nancumsum( + x: Union[Quantity, jax.typing.ArrayLike], + axis: Union[int, Sequence[int], None] = None, + dtype: Union[jax.typing.DTypeLike, None] = None, +) -> Union[Quantity, jax.Array]: """ Return the cumulative sum of the array elements, ignoring NaNs. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array. + axis : int, optional + Axis along which the cumulative sum is computed. The default + (None) is to compute the cumsum over the flattened array. + dtype : dtype, optional + Type of the returned array and of the accumulator in which the + elements are summed. If `dtype` is not specified, it defaults + to the dtype of `a`, unless `a` has an integer dtype with a + precision less than that of the default platform integer. In + that case, the default platform integer is used. - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + Returns + ------- + out : jax.Array, Quantity + Quantity if `x` is a Quantity, else an array. """ - return funcs_keep_unit_unary(jnp.nancumsum, x) + return funcs_keep_unit_unary(jnp.nancumsum, x, axis=axis, dtype=dtype) @set_module_as('brainunit.math') -def nansum(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def nansum( + x: Union[Quantity, jax.typing.ArrayLike], + axis: Union[int, Sequence[int], None] = None, + dtype: Union[jax.typing.DTypeLike, None] = None, + keepdims: bool = False, + initial: Union[jax.typing.ArrayLike, None] = None, + where: Union[jax.typing.ArrayLike, None] = None, +) -> Union[Quantity, jax.Array]: """ Return the sum of the array elements, ignoring NaNs. - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. - """ - return funcs_keep_unit_unary(jnp.nansum, x) + Parameters + ---------- + x : array_like, Quantity + Input array. + axis : {int, tuple of int, None}, optional + Axis or axes along which the sum is computed. The default is to compute the + sum of the flattened array. + dtype : data-type, optional + The type of the returned array and of the accumulator in which the + elements are summed. By default, the dtype of `a` is used. An + exception is when `a` has an integer type with less precision than + the platform (u)intp. In that case, the default will be either + (u)int32 or (u)int64 depending on whether the platform is 32 or 64 + bits. For inexact inputs, dtype must be inexact. + keepdims : bool, optional + If this is set to True, the axes which are reduced are left + in the result as dimensions with size one. With this option, + the result will broadcast correctly against the original `a`. + + If the value is anything but the default, then + `keepdims` will be passed through to the `mean` or `sum` methods + of sub-classes of `ndarray`. If the sub-classes methods + does not implement `keepdims` any exceptions will be raised. + initial : scalar, optional + Starting value for the sum. See `~numpy.ufunc.reduce` for details. + where : array_like of bool, optional + Elements to include in the sum. See `~numpy.ufunc.reduce` for details. + + Returns + ------- + out : jax.Array, Quantity + Quantity if `x` is a Quantity, else an array. + """ + return funcs_keep_unit_unary(jnp.nansum, x, + axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, + where=where) @set_module_as('brainunit.math') -def cumsum(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def cumsum( + x: Union[Quantity, jax.typing.ArrayLike], + axis: Union[int, Sequence[int], None] = None, + dtype: Union[jax.typing.DTypeLike, None] = None, +) -> Union[Quantity, jax.Array]: """ Return the cumulative sum of the array elements. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array. + axis : int, optional + Axis along which the cumulative sum is computed. The default + (None) is to compute the cumsum over the flattened array. + dtype : dtype, optional + Type of the returned array and of the accumulator in which the + elements are summed. If `dtype` is not specified, it defaults + to the dtype of `a`, unless `a` has an integer dtype with a + precision less than that of the default platform integer. In + that case, the default platform integer is used. - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + Returns + ------- + out : jax.Array, Quantity + Quantity if `x` is a Quantity, else an array. """ - return funcs_keep_unit_unary(jnp.cumsum, x) + return funcs_keep_unit_unary(jnp.cumsum, x, axis=axis, dtype=dtype) @set_module_as('brainunit.math') -def ediff1d(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def ediff1d( + x: Union[Quantity, jax.typing.ArrayLike], + to_end: jax.typing.ArrayLike = None, + to_begin: jax.typing.ArrayLike = None +) -> Union[Quantity, jax.Array]: """ Return the differences between consecutive elements of the array. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array. + to_end : array_like, optional + Number(s) to append at the end of the returned differences. + to_begin : array_like, optional + Number(s) to prepend at the beginning of the returned differences. - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + Returns + ------- + out : jax.Array, Quantity + Quantity if `x` is a Quantity, else an array. """ - return funcs_keep_unit_unary(jnp.ediff1d, x) + return funcs_keep_unit_unary(jnp.ediff1d, x, to_end=to_end, to_begin=to_begin) @set_module_as('brainunit.math') @@ -337,11 +541,15 @@ def absolute(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Ar """ Return the absolute value of the argument. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array. - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + Returns + ------- + out : jax.Array, Quantity + Quantity if `x` is a Quantity, else an array. """ return funcs_keep_unit_unary(jnp.absolute, x) @@ -351,179 +559,528 @@ def fabs(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array] """ Return the absolute value of the argument. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array. - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + Returns + ------- + out : jax.Array, Quantity + Quantity if `x` is a Quantity, else an array. """ return funcs_keep_unit_unary(jnp.fabs, x) @set_module_as('brainunit.math') -def median(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def median( + x: Union[Quantity, jax.typing.ArrayLike], + axis: Union[int, Sequence[int], None] = None, + overwrite_input: bool = False, + keepdims: bool = False +) -> Union[Quantity, jax.Array]: """ Return the median of the array elements. - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. - """ - return funcs_keep_unit_unary(jnp.median, x) + Parameters + ---------- + x : array_like, Quantity + Input array. + axis : {int, sequence of int, None}, optional + Axis or axes along which the medians are computed. The default + is to compute the median along a flattened version of the array. + A sequence of axes is supported since version 1.9.0. + overwrite_input : bool, optional + If True, then allow use of memory of input array `a` for + calculations. The input array will be modified by the call to + `median`. This will save memory when you do not need to preserve + the contents of the input array. Treat the input as undefined, + but it will probably be fully or partially sorted. Default is + False. If `overwrite_input` is ``True`` and `a` is not already an + `ndarray`, an error will be raised. + keepdims : bool, optional + If this is set to True, the axes which are reduced are left + in the result as dimensions with size one. With this option, + the result will broadcast correctly against the original `arr`. + + Returns + ------- + out : jax.Array, Quantity + Quantity if `x` is a Quantity, else an array. + """ + return funcs_keep_unit_unary(jnp.median, x, axis=axis, overwrite_input=overwrite_input, keepdims=keepdims) @set_module_as('brainunit.math') -def nanmin(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def nanmin( + x: Union[Quantity, jax.typing.ArrayLike], + axis: Union[int, Sequence[int], None] = None, + keepdims: bool = False, + initial: Union[jax.typing.ArrayLike, None] = None, + where: Union[jax.typing.ArrayLike, None] = None, +) -> Union[Quantity, jax.Array]: """ Return the minimum of the array elements, ignoring NaNs. - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. - """ - return funcs_keep_unit_unary(jnp.nanmin, x) + Parameters + ---------- + x : array_like, Quantity + Input array. + axis : {int, tuple of int, None}, optional + Axis or axes along which the minimum is computed. The default is to compute + the minimum of the flattened array. + keepdims : bool, optional + If this is set to True, the axes which are reduced are left + in the result as dimensions with size one. With this option, + the result will broadcast correctly against the original `a`. + + If the value is anything but the default, then + `keepdims` will be passed through to the `min` method + of sub-classes of `ndarray`. If the sub-classes methods + does not implement `keepdims` any exceptions will be raised. + initial : scalar, optional + The maximum value of an output element. Must be present to allow + computation on empty slice. See `~numpy.ufunc.reduce` for details. + where : array_like of bool, optional + Elements to compare for the minimum. See `~numpy.ufunc.reduce` + for details. + + Returns + ------- + out : jax.Array, Quantity + Quantity if `x` is a Quantity, else an array. + """ + return funcs_keep_unit_unary(jnp.nanmin, x, axis=axis, keepdims=keepdims, initial=initial, where=where) @set_module_as('brainunit.math') -def nanmax(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def nanmax( + x: Union[Quantity, jax.typing.ArrayLike], + axis: Union[int, Sequence[int], None] = None, + keepdims: bool = False, + initial: Union[jax.typing.ArrayLike, None] = None, + where: Union[jax.typing.ArrayLike, None] = None, +) -> Union[Quantity, jax.Array]: """ Return the maximum of the array elements, ignoring NaNs. - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. - """ - return funcs_keep_unit_unary(jnp.nanmax, x) + Parameters + ---------- + x : array_like, Quantity + Input array. + axis : {int, tuple of int, None}, optional + Axis or axes along which the minimum is computed. The default is to compute + the minimum of the flattened array. + keepdims : bool, optional + If this is set to True, the axes which are reduced are left + in the result as dimensions with size one. With this option, + the result will broadcast correctly against the original `a`. + + If the value is anything but the default, then + `keepdims` will be passed through to the `min` method + of sub-classes of `ndarray`. If the sub-classes methods + does not implement `keepdims` any exceptions will be raised. + initial : scalar, optional + The maximum value of an output element. Must be present to allow + computation on empty slice. See `~numpy.ufunc.reduce` for details. + where : array_like of bool, optional + Elements to compare for the minimum. See `~numpy.ufunc.reduce` + for details. + + Returns + ------- + out : jax.Array, Quantity + Quantity if `x` is a Quantity, else an array. + """ + return funcs_keep_unit_unary(jnp.nanmax, x, axis=axis, keepdims=keepdims, initial=initial, where=where) @set_module_as('brainunit.math') -def ptp(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def ptp( + x: Union[Quantity, jax.typing.ArrayLike], + axis: Union[int, Sequence[int], None] = None, + keepdims: bool = False, +) -> Union[Quantity, jax.Array]: """ Return the range of the array elements (maximum - minimum). - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array. + axis : None or int or tuple of ints, optional + Axis along which to find the peaks. By default, flatten the + array. `axis` may be negative, in + which case it counts from the last to the first axis. - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + If this is a tuple of ints, a reduction is performed on multiple + axes, instead of a single axis or all the axes as before. + keepdims : bool, optional + If this is set to True, the axes which are reduced are left + in the result as dimensions with size one. With this option, + the result will broadcast correctly against the input array. + + If the default value is passed, then `keepdims` will not be + passed through to the `ptp` method of sub-classes of + `ndarray`, however any non-default value will be. If the + sub-class' method does not implement `keepdims` any + exceptions will be raised. + + Returns + ------- + out : jax.Array, Quantity + Quantity if `x` is a Quantity, else an array. """ - return funcs_keep_unit_unary(jnp.ptp, x) + return funcs_keep_unit_unary(jnp.ptp, x, axis=axis, keepdims=keepdims) @set_module_as('brainunit.math') -def average(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def average( + x: Union[Quantity, jax.typing.ArrayLike], + axis: Union[int, Sequence[int], None] = None, + weights: Union[jax.typing.ArrayLike, None] = None, + returned: bool = False, + keepdims: bool = False +) -> Union[Quantity, jax.Array]: """ Return the weighted average of the array elements. - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. - """ - return funcs_keep_unit_unary(jnp.average, x) + Parameters + ---------- + x : array_like, Quantity + Input array. + axis : None or int or tuple of ints, optional + Axis or axes along which to average `a`. The default, + axis=None, will average over all of the elements of the input array. + If axis is negative it counts from the last to the first axis. + + If axis is a tuple of ints, averaging is performed on all of the axes + specified in the tuple instead of a single axis or all the axes as + before. + weights : array_like, optional + An array of weights associated with the values in `a`. Each value in + `a` contributes to the average according to its associated weight. + The weights array can either be 1-D (in which case its length must be + the size of `a` along the given axis) or of the same shape as `a`. + If `weights=None`, then all data in `a` are assumed to have a + weight equal to one. The 1-D calculation is:: + + avg = sum(a * weights) / sum(weights) + + The only constraint on `weights` is that `sum(weights)` must not be 0. + returned : bool, optional + Default is `False`. If `True`, the tuple (`average`, `sum_of_weights`) + is returned, otherwise only the average is returned. + If `weights=None`, `sum_of_weights` is equivalent to the number of + elements over which the average is taken. + keepdims : bool, optional + If this is set to True, the axes which are reduced are left + in the result as dimensions with size one. With this option, + the result will broadcast correctly against the original `a`. + *Note:* `keepdims` will not work with instances of `numpy.matrix` + or other classes whose methods do not support `keepdims`. + + Returns + ------- + out : jax.Array, Quantity + Quantity if `x` is a Quantity, else an array. + """ + return funcs_keep_unit_unary(jnp.average, x, axis=axis, weights=weights, returned=returned, keepdims=keepdims) @set_module_as('brainunit.math') -def mean(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def mean( + x: Union[Quantity, jax.typing.ArrayLike], + axis: Union[int, Sequence[int], None] = None, + dtype: Union[jax.typing.DTypeLike, None] = None, + keepdims: bool = False, *, + where: Union[jax.typing.ArrayLike, None] = None +) -> Union[Quantity, jax.Array]: """ Return the mean of the array elements. - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. - """ - return funcs_keep_unit_unary(jnp.mean, x) + Parameters + ---------- + x : array_like, Quantity + Input array. + axis : None or int or tuple of ints, optional + Axis or axes along which the means are computed. The default is to + compute the mean of the flattened array. + + If this is a tuple of ints, a mean is performed over multiple axes, + instead of a single axis or all the axes as before. + dtype : data-type, optional + Type to use in computing the mean. For integer inputs, the default + is `float64`; for floating point inputs, it is the same as the + input dtype. + keepdims : bool, optional + If this is set to True, the axes which are reduced are left + in the result as dimensions with size one. With this option, + the result will broadcast correctly against the input array. + + If the default value is passed, then `keepdims` will not be + passed through to the `mean` method of sub-classes of + `ndarray`, however any non-default value will be. If the + sub-class' method does not implement `keepdims` any + exceptions will be raised. + where : array_like of bool, optional + Elements to include in the mean. + + Returns + ------- + out : jax.Array, Quantity + Quantity if `x` is a Quantity, else an array. + """ + return funcs_keep_unit_unary(jnp.mean, x, axis=axis, dtype=dtype, keepdims=keepdims, where=where) @set_module_as('brainunit.math') -def std(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def std( + x: Union[Quantity, jax.typing.ArrayLike], + axis: Union[int, Sequence[int], None] = None, + dtype: Union[jax.typing.DTypeLike, None] = None, + ddof: int = 0, + keepdims: bool = False, *, + where: Union[jax.typing.ArrayLike, None] = None +) -> Union[Quantity, jax.Array]: """ Return the standard deviation of the array elements. - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. - """ - return funcs_keep_unit_unary(jnp.std, x) + Parameters + ---------- + x : array_like, Quantity + Input array. + axis : None or int or tuple of ints, optional + Axis or axes along which the standard deviation is computed. The + default is to compute the standard deviation of the flattened array. + + If this is a tuple of ints, a standard deviation is performed over + multiple axes, instead of a single axis or all the axes as before. + dtype : dtype, optional + Type to use in computing the standard deviation. For arrays of + integer type the default is float64, for arrays of float types it is + the same as the array type. + ddof : int, optional + Means Delta Degrees of Freedom. The divisor used in calculations + is ``N - ddof``, where ``N`` represents the number of elements. + By default `ddof` is zero. + keepdims : bool, optional + If this is set to True, the axes which are reduced are left + in the result as dimensions with size one. With this option, + the result will broadcast correctly against the input array. + + If the default value is passed, then `keepdims` will not be + passed through to the `std` method of sub-classes of + `ndarray`, however any non-default value will be. If the + sub-class' method does not implement `keepdims` any + exceptions will be raised. + where : array_like of bool, optional + Elements to include in the standard deviation. + See `~numpy.ufunc.reduce` for details. + + Returns + ------- + out : jax.Array, Quantity + Quantity if `x` is a Quantity, else an array. + """ + return funcs_keep_unit_unary(jnp.std, x, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where) @set_module_as('brainunit.math') -def nanmedian(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def nanmedian( + x: Union[Quantity, jax.typing.ArrayLike], + axis: Union[int, tuple[int, ...], None] = None, + overwrite_input: bool = False, + keepdims: bool = False +) -> Union[Quantity, jax.Array]: """ Return the median of the array elements, ignoring NaNs. - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. - """ - return funcs_keep_unit_unary(jnp.nanmedian, x) + Parameters + ---------- + x : array_like, Quantity + Input array. + axis : {int, sequence of int, None}, optional + Axis or axes along which the medians are computed. The default + is to compute the median along a flattened version of the array. + A sequence of axes is supported since version 1.9.0. + overwrite_input : bool, optional + If True, then allow use of memory of input array `a` for + calculations. The input array will be modified by the call to + `median`. This will save memory when you do not need to preserve + the contents of the input array. Treat the input as undefined, + but it will probably be fully or partially sorted. Default is + False. If `overwrite_input` is ``True`` and `a` is not already an + `ndarray`, an error will be raised. + keepdims : bool, optional + If this is set to True, the axes which are reduced are left + in the result as dimensions with size one. With this option, + the result will broadcast correctly against the original `a`. + + If this is anything but the default value it will be passed + through (in the special case of an empty array) to the + `mean` function of the underlying array. If the array is + a sub-class and `mean` does not have the kwarg `keepdims` this + will raise a RuntimeError. + + Returns + ------- + out : jax.Array, Quantity + Quantity if `x` is a Quantity, else an array. + """ + return funcs_keep_unit_unary(jnp.nanmedian, x, axis=axis, overwrite_input=overwrite_input, keepdims=keepdims) @set_module_as('brainunit.math') -def nanmean(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def nanmean( + x: Union[Quantity, jax.typing.ArrayLike], + axis: Union[int, Sequence[int], None] = None, + dtype: Union[jax.typing.DTypeLike, None] = None, + keepdims: bool = False, *, + where: Union[jax.typing.ArrayLike, None] = None +) -> Union[Quantity, jax.Array]: """ Return the mean of the array elements, ignoring NaNs. - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. - """ - return funcs_keep_unit_unary(jnp.nanmean, x) + Parameters + ---------- + x : array_like, Quantity + Input array. + axis : None or int or tuple of ints, optional + Axis or axes along which the means are computed. The default is to + compute the mean of the flattened array. + + If this is a tuple of ints, a mean is performed over multiple axes, + instead of a single axis or all the axes as before. + dtype : data-type, optional + Type to use in computing the mean. For integer inputs, the default + is `float64`; for floating point inputs, it is the same as the + input dtype. + keepdims : bool, optional + If this is set to True, the axes which are reduced are left + in the result as dimensions with size one. With this option, + the result will broadcast correctly against the input array. + + If the default value is passed, then `keepdims` will not be + passed through to the `mean` method of sub-classes of + `ndarray`, however any non-default value will be. If the + sub-class' method does not implement `keepdims` any + exceptions will be raised. + where : array_like of bool, optional + Elements to include in the mean. + + Returns + ------- + out : jax.Array, Quantity + Quantity if `x` is a Quantity, else an array. + """ + return funcs_keep_unit_unary(jnp.nanmean, x, axis=axis, dtype=dtype, keepdims=keepdims, where=where) @set_module_as('brainunit.math') -def nanstd(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def nanstd( + x: Union[Quantity, jax.typing.ArrayLike], + axis: Union[int, Sequence[int], None] = None, + dtype: Union[jax.typing.DTypeLike, None] = None, + ddof: int = 0, + keepdims: bool = False, *, + where: Union[jax.typing.ArrayLike, None] = None +) -> Union[Quantity, jax.Array]: """ Return the standard deviation of the array elements, ignoring NaNs. - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. - """ - return funcs_keep_unit_unary(jnp.nanstd, x) + Parameters + ---------- + x : array_like, Quantity + Input array. + axis : None or int or tuple of ints, optional + Axis or axes along which the standard deviation is computed. The + default is to compute the standard deviation of the flattened array. + + If this is a tuple of ints, a standard deviation is performed over + multiple axes, instead of a single axis or all the axes as before. + dtype : dtype, optional + Type to use in computing the standard deviation. For arrays of + integer type the default is float64, for arrays of float types it is + the same as the array type. + ddof : int, optional + Means Delta Degrees of Freedom. The divisor used in calculations + is ``N - ddof``, where ``N`` represents the number of elements. + By default `ddof` is zero. + keepdims : bool, optional + If this is set to True, the axes which are reduced are left + in the result as dimensions with size one. With this option, + the result will broadcast correctly against the input array. + + If the default value is passed, then `keepdims` will not be + passed through to the `std` method of sub-classes of + `ndarray`, however any non-default value will be. If the + sub-class' method does not implement `keepdims` any + exceptions will be raised. + where : array_like of bool, optional + Elements to include in the standard deviation. + See `~numpy.ufunc.reduce` for details. + + Returns + ------- + out : jax.Array, Quantity + Quantity if `x` is a Quantity, else an array. + """ + return funcs_keep_unit_unary(jnp.nanstd, x, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, + where=where) @set_module_as('brainunit.math') -def diff(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def diff( + x: Union[Quantity, jax.typing.ArrayLike], + n: int = 1, axis: int = -1, + prepend: Union[jax.typing.ArrayLike, None] = None, + append: Union[jax.typing.ArrayLike, None] = None +) -> Union[Quantity, jax.Array]: """ Return the differences between consecutive elements of the array. - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. - """ - return funcs_keep_unit_unary(jnp.diff, x) + Parameters + ---------- + x : array_like, Quantity + Input array. + n : int, optional + The number of times values are differenced. If zero, the input + is returned as-is. + axis : int, optional + The axis along which the difference is taken, default is the + last axis. + prepend, append : array_like, optional + Values to prepend or append to `a` along axis prior to + performing the difference. Scalar values are expanded to + arrays with length 1 in the direction of axis and the shape + of the input array in along all other axes. Otherwise the + dimension and shape must match `a` except along axis. + + Returns + ------- + out : jax.Array, Quantity + Quantity if `x` is a Quantity, else an array. + """ + return funcs_keep_unit_unary(jnp.diff, x, n=n, axis=axis, prepend=prepend, append=append) @set_module_as('brainunit.math') -def modf(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def modf( + x: Union[Quantity, jax.typing.ArrayLike], +) -> Union[Quantity, jax.Array]: """ Return the fractional and integer parts of the array elements. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array. - Returns: - Union[jax.Array, Quantity]: Quantity tuple if `x` is a Quantity, else an array tuple. + Returns + ------- + out : jax.Array, Quantity + Quantity if `x` is a Quantity, else an array. """ return funcs_keep_unit_unary(jnp.modf, x) @@ -531,8 +1088,15 @@ def modf(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array] # math funcs keep unit (binary) # ----------------------------- -def funcs_keep_unit_binary(func, x1, x2, *args, **kwargs): +def funcs_keep_unit_binary( + func, + x1, x2, + *args, + check_same_dim=True, + **kwargs): if isinstance(x1, Quantity) and isinstance(x2, Quantity): + if check_same_dim: + fail_for_dimension_mismatch(x1, x2, func.__name__) return Quantity(func(x1.value, x2.value, *args, **kwargs), dim=x1.dim) elif isinstance(x1, (jax.Array, np.ndarray)) and isinstance(x2, (jax.Array, np.ndarray)): return func(x1, x2, *args, **kwargs) @@ -541,17 +1105,24 @@ def funcs_keep_unit_binary(func, x1, x2, *args, **kwargs): @set_module_as('brainunit.math') -def fmod(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: +def fmod(x1: Union[Quantity, jax.Array], + x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: """ Return the element-wise remainder of division. - Args: - x1: array_like, Quantity - x2: array_like, Quantity + Parameters + ---------- + x1: array_like, Quantity + Input array. + x2: array_like, Quantity + Input array. - Returns: - Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. + Returns + ------- + out : jax.Array, Quantity + Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. """ + # TODO: Consider different unit of x1 and x2 return funcs_keep_unit_binary(jnp.fmod, x1, x2) @@ -560,13 +1131,19 @@ def mod(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union """ Return the element-wise modulus of division. - Args: - x1: array_like, Quantity - x2: array_like, Quantity + Parameters + ---------- + x1: array_like, Quantity + Input array. + x2: array_like, Quantity + Input array. - Returns: - Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. + Returns + ------- + out : jax.Array, Quantity + Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. """ + # TODO: Consider different unit of x1 and x2 return funcs_keep_unit_binary(jnp.mod, x1, x2) @@ -575,29 +1152,41 @@ def copysign(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> """ Return a copy of the first array elements with the sign of the second array. - Args: - x1: array_like, Quantity - x2: array_like, Quantity + Parameters + ---------- + x1: array_like, Quantity + Input array. + x2: array_like, Quantity + Input array. - Returns: - Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. + Returns + ------- + out : jax.Array, Quantity + Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. """ return funcs_keep_unit_binary(jnp.copysign, x1, x2) @set_module_as('brainunit.math') -def heaviside(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: +def heaviside(x1: Union[Quantity, jax.Array], + x2: jax.typing.ArrayLike) -> Union[Quantity, jax.Array]: """ Compute the Heaviside step function. - Args: - x1: array_like, Quantity - x2: array_like, Quantity + Parameters + ---------- + x1: array_like, Quantity + Input array. + x2: array_like, Quantity + Input array. - Returns: - Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. + Returns + ------- + out : jax.Array, Quantity + Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. """ - return funcs_keep_unit_binary(jnp.heaviside, x1, x2) + x1 = x1.value if isinstance(x1, Quantity) else x1 + return jnp.heaviside(x1, x2) @set_module_as('brainunit.math') @@ -605,12 +1194,17 @@ def maximum(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> U """ Element-wise maximum of array elements. - Args: - x1: array_like, Quantity - x2: array_like, Quantity + Parameters + ---------- + x1: array_like, Quantity + Input array. + x2: array_like, Quantity + Input array. - Returns: - Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. + Returns + ------- + out : jax.Array, Quantity + Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. """ return funcs_keep_unit_binary(jnp.maximum, x1, x2) @@ -620,12 +1214,17 @@ def minimum(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> U """ Element-wise minimum of array elements. - Args: - x1: array_like, Quantity - x2: array_like, Quantity + Parameters + ---------- + x1: array_like, Quantity + Input array. + x2: array_like, Quantity + Input array. - Returns: - Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. + Returns + ------- + out : jax.Array, Quantity + Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. """ return funcs_keep_unit_binary(jnp.minimum, x1, x2) @@ -635,12 +1234,17 @@ def fmax(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Unio """ Element-wise maximum of array elements ignoring NaNs. - Args: - x1: array_like, Quantity - x2: array_like, Quantity + Parameters + ---------- + x1: array_like, Quantity + Input array. + x2: array_like, Quantity + Input array. - Returns: - Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. + Returns + ------- + out : jax.Array, Quantity + Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. """ return funcs_keep_unit_binary(jnp.fmax, x1, x2) @@ -650,12 +1254,17 @@ def fmin(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Unio """ Element-wise minimum of array elements ignoring NaNs. - Args: - x1: array_like, Quantity - x2: array_like, Quantity + Parameters + ---------- + x1: array_like, Quantity + Input array. + x2: array_like, Quantity + Input array. - Returns: - Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. + Returns + ------- + out : jax.Array, Quantity + Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. """ return funcs_keep_unit_binary(jnp.fmin, x1, x2) @@ -665,12 +1274,17 @@ def lcm(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union """ Return the least common multiple of `x1` and `x2`. - Args: - x1: array_like, Quantity - x2: array_like, Quantity + Parameters + ---------- + x1: array_like, Quantity + Input array. + x2: array_like, Quantity + Input array. - Returns: - Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. + Returns + ------- + out : jax.Array, Quantity + Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. """ return funcs_keep_unit_binary(jnp.lcm, x1, x2) @@ -680,12 +1294,17 @@ def gcd(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union """ Return the greatest common divisor of `x1` and `x2`. - Args: - x1: array_like, Quantity - x2: array_like, Quantity + Parameters + ---------- + x1: array_like, Quantity + Input array. + x2: array_like, Quantity + Input array. - Returns: - Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. + Returns + ------- + out : jax.Array, Quantity + Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. """ return funcs_keep_unit_binary(jnp.gcd, x1, x2) @@ -693,12 +1312,14 @@ def gcd(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union # math funcs keep unit (n-ary) # ---------------------------- @set_module_as('brainunit.math') -def interp(x: Union[Quantity, jax.typing.ArrayLike], - xp: Union[Quantity, jax.typing.ArrayLike], - fp: Union[Quantity, jax.typing.ArrayLike], - left: Union[Quantity, jax.typing.ArrayLike] = None, - right: Union[Quantity, jax.typing.ArrayLike] = None, - period: Union[Quantity, jax.typing.ArrayLike] = None) -> Union[Quantity, jax.Array]: +def interp( + x: Union[Quantity, jax.typing.ArrayLike], + xp: Union[Quantity, jax.typing.ArrayLike], + fp: Union[Quantity, jax.typing.ArrayLike], + left: Union[Quantity, jax.typing.ArrayLike] = None, + right: Union[Quantity, jax.typing.ArrayLike] = None, + period: Union[Quantity, jax.typing.ArrayLike] = None +) -> Union[Quantity, jax.Array]: """ One-dimensional linear interpolation. @@ -713,32 +1334,30 @@ def interp(x: Union[Quantity, jax.typing.ArrayLike], Returns: Union[jax.Array, Quantity]: Quantity if `x`, `xp`, and `fp` are Quantities that have the same unit, else an array. """ - unit = None - if isinstance(x, Quantity) or isinstance(xp, Quantity) or isinstance(fp, Quantity): - unit = x.dim if isinstance(x, Quantity) else xp.dim if isinstance(xp, Quantity) else fp.dim - if isinstance(x, Quantity): - x_value = x.value - else: - x_value = x - if isinstance(xp, Quantity): - xp_value = xp.value - else: - xp_value = xp - if isinstance(fp, Quantity): - fp_value = fp.value + if isinstance(x, Quantity) and isinstance(xp, Quantity) and isinstance(fp, Quantity): + fail_for_dimension_mismatch(x, xp) + fail_for_dimension_mismatch(x, fp) + unit = x.dim + if isinstance(left, Quantity): + fail_for_dimension_mismatch(x, left) + left = left.value + if isinstance(right, Quantity): + fail_for_dimension_mismatch(x, right) + right = right.value + if isinstance(period, Quantity): + fail_for_dimension_mismatch(x, period) + period = period.value + return Quantity(jnp.interp(x.value, xp.value, fp.value, left, right, period), dim=unit) else: - fp_value = fp - result = jnp.interp(x_value, xp_value, fp_value, left=left, right=right, period=period) - if unit is not None: - return Quantity(result, dim=unit) - else: - return result + return jnp.interp(x, xp, fp, left, right, period) @set_module_as('brainunit.math') -def clip(a: Union[Quantity, jax.typing.ArrayLike], - a_min: Union[Quantity, jax.typing.ArrayLike], - a_max: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def clip( + a: Union[Quantity, jax.typing.ArrayLike], + a_min: Union[Quantity, jax.typing.ArrayLike], + a_max: Union[Quantity, jax.typing.ArrayLike] +) -> Union[Quantity, jax.Array]: """ Clip (limit) the values in an array. @@ -751,22 +1370,10 @@ def clip(a: Union[Quantity, jax.typing.ArrayLike], Union[jax.Array, Quantity]: Quantity if `a`, `a_min`, and `a_max` are Quantities that have the same unit, else an array. """ unit = None - if isinstance(a, Quantity) or isinstance(a_min, Quantity) or isinstance(a_max, Quantity): - unit = a.dim if isinstance(a, Quantity) else a_min.dim if isinstance(a_min, Quantity) else a_max.dim - if isinstance(a, Quantity): - a_value = a.value - else: - a_value = a - if isinstance(a_min, Quantity): - a_min_value = a_min.value - else: - a_min_value = a_min - if isinstance(a_max, Quantity): - a_max_value = a_max.value - else: - a_max_value = a_max - result = jnp.clip(a_value, a_min_value, a_max_value) - if unit is not None: - return Quantity(result, dim=unit) + if isinstance(a, Quantity) and isinstance(a_min, Quantity) and isinstance(a_max, Quantity): + fail_for_dimension_mismatch(a, a_min) + fail_for_dimension_mismatch(a, a_max) + unit = a.dim + return Quantity(jnp.clip(a.value, a_min.value, a_max.value), dim=unit) else: - return result + return jnp.clip(a, a_min, a_max) diff --git a/brainunit/math/_compat_numpy_funcs_logic.py b/brainunit/math/_compat_numpy_funcs_logic.py index 9e985f3..1a1a2e4 100644 --- a/brainunit/math/_compat_numpy_funcs_logic.py +++ b/brainunit/math/_compat_numpy_funcs_logic.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +from __future__ import annotations + from typing import (Union, Optional) import jax @@ -19,9 +21,8 @@ import numpy as np from jax import Array -from brainunit._misc import set_module_as -from .._base import (Quantity, - fail_for_dimension_mismatch, ) +from .._base import (Quantity, fail_for_dimension_mismatch) +from .._misc import set_module_as __all__ = [ # logic funcs (unary) @@ -40,7 +41,7 @@ def logic_func_unary(func, x, *args, **kwargs): if isinstance(x, Quantity): - raise ValueError(f'Expected booleans, got {x}') + raise ValueError(f'Expected arrays, got {x}') elif isinstance(x, (jax.Array, np.ndarray)): return func(x, *args, **kwargs) else: @@ -51,61 +52,107 @@ def logic_func_unary(func, x, *args, **kwargs): def all( x: Union[Quantity, jax.typing.ArrayLike], axis: Optional[int] = None, - out: Optional[Array] = None, keepdims: bool = False, where: Optional[Array] = None ) -> Union[bool, Array]: """ Test whether all array elements along a given axis evaluate to True. - Args: - a: array_like - axis: int, optional - out: array, optional - keepdims: bool, optional - where: array_like of bool, optional - - Returns: - Union[bool, jax.Array]: bool or array + Parameters + ---------- + x : array_like, Quantity + Input array or object that can be converted to an array. + axis : None or int or tuple of ints, optional + Axis or axes along which a logical AND reduction is performed. + The default (``axis=None``) is to perform a logical AND over all + the dimensions of the input array. `axis` may be negative, in + which case it counts from the last to the first axis. + + If this is a tuple of ints, a reduction is performed on multiple + axes, instead of a single axis or all the axes as before. + keepdims : bool, optional + If this is set to True, the axes which are reduced are left + in the result as dimensions with size one. With this option, + the result will broadcast correctly against the input array. + + If the default value is passed, then `keepdims` will not be + passed through to the `all` method of sub-classes of + `ndarray`, however any non-default value will be. If the + sub-class' method does not implement `keepdims` any + exceptions will be raised. + where : array_like of bool, optional + Elements to include in checking for all `True` values. + + Returns + ------- + all : ndarray, bool + A new boolean or array is returned unless `out` is specified, + in which case a reference to `out` is returned. """ - return logic_func_unary(jnp.all, x, axis=axis, out=out, keepdims=keepdims, where=where) + return logic_func_unary(jnp.all, x, axis=axis, keepdims=keepdims, where=where) @set_module_as('brainunit.math') def any( x: Union[Quantity, jax.typing.ArrayLike], axis: Optional[int] = None, - out: Optional[Array] = None, keepdims: bool = False, where: Optional[Array] = None ) -> Union[bool, Array]: """ Test whether any array element along a given axis evaluates to True. - Args: - a: array_like - axis: int, optional - out: array, optional - keepdims: bool, optional - where: array_like of bool, optional - - Returns: - Union[bool, jax.Array]: bool or array + Parameters + ---------- + x : array_like, Quantity + Input array or object that can be converted to an array. + axis : None or int or tuple of ints, optional + Axis or axes along which a logical AND reduction is performed. + The default (``axis=None``) is to perform a logical AND over all + the dimensions of the input array. `axis` may be negative, in + which case it counts from the last to the first axis. + + If this is a tuple of ints, a reduction is performed on multiple + axes, instead of a single axis or all the axes as before. + keepdims : bool, optional + If this is set to True, the axes which are reduced are left + in the result as dimensions with size one. With this option, + the result will broadcast correctly against the input array. + + If the default value is passed, then `keepdims` will not be + passed through to the `all` method of sub-classes of + `ndarray`, however any non-default value will be. If the + sub-class' method does not implement `keepdims` any + exceptions will be raised. + where : array_like of bool, optional + Elements to include in checking for all `True` values. + + Returns + ------- + any : ndarray, bool + A new boolean or array is returned unless `out` is specified, + in which case a reference to `out` is returned. """ - return logic_func_unary(jnp.any, x, axis=axis, out=out, keepdims=keepdims, where=where) + return logic_func_unary(jnp.any, x, axis=axis, keepdims=keepdims, where=where) @set_module_as('brainunit.math') -def logical_not(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[bool, Array]: +def logical_not( + x: Union[Quantity, jax.typing.ArrayLike], +) -> Union[bool, Array]: """ Compute the truth value of NOT x element-wise. - Args: - x: array_like - out: array, optional + Parameters + ---------- + x : array_like, Quantity + Input array or object that can be converted to an array. - Returns: - Union[bool, jax.Array]: bool or array + Returns + ------- + logical_not : ndarray, bool + A new boolean or array is returned unless `out` is specified, + in which case a reference to `out` is returned. """ return logic_func_unary(jnp.logical_not, x) @@ -131,130 +178,298 @@ def logic_func_binary(func, x, y, *args, **kwargs): @set_module_as('brainunit.math') def equal( x: Union[Quantity, jax.typing.ArrayLike], - y: Union[Quantity, jax.typing.ArrayLike] + y: Union[Quantity, jax.typing.ArrayLike], + *args, + **kwargs ) -> Union[bool, Array]: """ - Return (x == y) element-wise and have the same unit if x and y are Quantity. - - Args: - x: array_like, Quantity - y: array_like, Quantity - - Returns: - Union[bool, jax.Array]: bool or array + equal(x, y, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj]) + + Return (x == y) element-wise. + + Parameters + ---------- + x, y : array_like, Quantity + Input arrays. + If ``x.shape != y.shape``, they must be broadcastable to a common + shape (which becomes the shape of the output). + out : ndarray, None, or tuple of ndarray and None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or None, + a freshly-allocated array is returned. A tuple (possible only as a + keyword argument) must have length equal to the number of outputs. + where : array_like, optional + This condition is broadcast over the input. At locations where the + condition is True, the `out` array will be set to the ufunc result. + Elsewhere, the `out` array will retain its original value. + Note that if an uninitialized `out` array is created via the default + ``out=None``, locations within it where the condition is False will + remain uninitialized. + **kwargs + For other keyword-only arguments, see the + :ref:`ufunc docs `. + + Returns + ------- + out : ndarray or scalar + Output array, element-wise comparison of `x` and `y`. + Typically of type bool, unless ``dtype=object`` is passed. + This is a scalar if both `x` and `y` are scalars. """ - return logic_func_binary(jnp.equal, x, y) + return logic_func_binary(jnp.equal, x, y, *args, **kwargs) @set_module_as('brainunit.math') def not_equal( x: Union[Quantity, jax.typing.ArrayLike], - y: Union[Quantity, jax.typing.ArrayLike] + y: Union[Quantity, jax.typing.ArrayLike], + *args, + **kwargs ) -> Union[bool, Array]: """ - Return (x != y) element-wise and have the same unit if x and y are Quantity. - - Args: - x: array_like, Quantity - y: array_like, Quantity - - Returns: - Union[bool, jax.Array]: bool or array + not_equal(x, y, /, out=None, *, where=True, casting='same_kind', + order='K', dtype=None, subok=True[, signature, extobj]) + + Return (x != y) element-wise. + + Parameters + ---------- + x, y : array_like, Quantity + Input arrays. + If ``x.shape != y.shape``, they must be broadcastable to a common + shape (which becomes the shape of the output). + out : ndarray, None, or tuple of ndarray and None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or None, + a freshly-allocated array is returned. A tuple (possible only as a + keyword argument) must have length equal to the number of outputs. + where : array_like, optional + This condition is broadcast over the input. At locations where the + condition is True, the `out` array will be set to the ufunc result. + Elsewhere, the `out` array will retain its original value. + Note that if an uninitialized `out` array is created via the default + ``out=None``, locations within it where the condition is False will + remain uninitialized. + **kwargs + For other keyword-only arguments, see the + :ref:`ufunc docs `. + + Returns + ------- + out : ndarray or scalar + Output array, element-wise comparison of `x` and `y`. + Typically of type bool, unless ``dtype=object`` is passed. + This is a scalar if both `x` and `y` are scalars. """ - return logic_func_binary(jnp.not_equal, x, y) + return logic_func_binary(jnp.not_equal, x, y, *args, **kwargs) @set_module_as('brainunit.math') def greater( x: Union[Quantity, jax.typing.ArrayLike], - y: Union[Quantity, jax.typing.ArrayLike] + y: Union[Quantity, jax.typing.ArrayLike], + *args, + **kwargs ) -> Union[bool, Array]: """ - Return (x > y) element-wise and have the same unit if x and y are Quantity. - - Args: - x: array_like, Quantity - y: array_like, Quantity - - Returns: - Union[bool, jax.Array]: bool or array + greater(x, y, /, out=None, *, where=True, casting='same_kind', + order='K', dtype=None, subok=True[, signature, extobj]) + + Return the truth value of (x > y) element-wise. + + Parameters + ---------- + x, y : array_like, Quantity + Input arrays. + If ``x.shape != y.shape``, they must be broadcastable to a common + shape (which becomes the shape of the output). + out : ndarray, None, or tuple of ndarray and None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or None, + a freshly-allocated array is returned. A tuple (possible only as a + keyword argument) must have length equal to the number of outputs. + where : array_like, optional + This condition is broadcast over the input. At locations where the + condition is True, the `out` array will be set to the ufunc result. + Elsewhere, the `out` array will retain its original value. + Note that if an uninitialized `out` array is created via the default + ``out=None``, locations within it where the condition is False will + remain uninitialized. + **kwargs + For other keyword-only arguments, see the + :ref:`ufunc docs `. + + Returns + ------- + out : ndarray or scalar + Output array, element-wise comparison of `x` and `y`. + Typically of type bool, unless ``dtype=object`` is passed. + This is a scalar if both `x` and `y` are scalars. """ - return logic_func_binary(jnp.greater, x, y) + return logic_func_binary(jnp.greater, x, y, *args, **kwargs) @set_module_as('brainunit.math') def greater_equal( x: Union[Quantity, jax.typing.ArrayLike], - y: Union[Quantity, jax.typing.ArrayLike] + y: Union[Quantity, jax.typing.ArrayLike], + *args, + **kwargs ) -> Union[ bool, Array]: """ - Return (x >= y) element-wise and have the same unit if x and y are Quantity. - - Args: - x: array_like, Quantity - y: array_like, Quantity - - Returns: - Union[bool, jax.Array]: bool or array + greater_equal(x, y, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj]) + + Return the truth value of (x >= y) element-wise. + + Parameters + ---------- + x, y : array_like, Quantity + Input arrays. + If ``x.shape != y.shape``, they must be broadcastable to a common + shape (which becomes the shape of the output). + out : ndarray, None, or tuple of ndarray and None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or None, + a freshly-allocated array is returned. A tuple (possible only as a + keyword argument) must have length equal to the number of outputs. + where : array_like, optional + This condition is broadcast over the input. At locations where the + condition is True, the `out` array will be set to the ufunc result. + Elsewhere, the `out` array will retain its original value. + Note that if an uninitialized `out` array is created via the default + ``out=None``, locations within it where the condition is False will + remain uninitialized. + **kwargs + For other keyword-only arguments, see the + :ref:`ufunc docs `. + + Returns + ------- + out : bool or ndarray of bool + Output array, element-wise comparison of `x` and `y`. + Typically of type bool, unless ``dtype=object`` is passed. + This is a scalar if both `x` and `y` are scalars. """ - return logic_func_binary(jnp.greater_equal, x, y) + return logic_func_binary(jnp.greater_equal, x, y, *args, **kwargs) @set_module_as('brainunit.math') def less( x: Union[Quantity, jax.typing.ArrayLike], - y: Union[Quantity, jax.typing.ArrayLike] + y: Union[Quantity, jax.typing.ArrayLike], + *args, + **kwargs ) -> Union[bool, Array]: """ - Return (x < y) element-wise and have the same unit if x and y are Quantity. - - Args: - x: array_like, Quantity - y: array_like, Quantity - - Returns: - Union[bool, jax.Array]: bool or array + less(x, y, /, out=None, *, where=True, casting='same_kind', + order='K', dtype=None, subok=True[, signature, extobj]) + + Return the truth value of (x < y) element-wise. + + Parameters + ---------- + x, y : array_like, Quantity + Input arrays. + If ``x1.shape != y.shape``, they must be broadcastable to a common + shape (which becomes the shape of the output). + out : ndarray, None, or tuple of ndarray and None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or None, + a freshly-allocated array is returned. A tuple (possible only as a + keyword argument) must have length equal to the number of outputs. + where : array_like, optional + This condition is broadcast over the input. At locations where the + condition is True, the `out` array will be set to the ufunc result. + Elsewhere, the `out` array will retain its original value. + Note that if an uninitialized `out` array is created via the default + ``out=None``, locations within it where the condition is False will + remain uninitialized. + **kwargs + For other keyword-only arguments, see the + :ref:`ufunc docs `. + + Returns + ------- + out : ndarray or scalar + Output array, element-wise comparison of `x` and `y`. + Typically of type bool, unless ``dtype=object`` is passed. + This is a scalar if both `x` and `y` are scalars. """ - return logic_func_binary(jnp.less, x, y) + return logic_func_binary(jnp.less, x, y, *args, **kwargs) @set_module_as('brainunit.math') def less_equal( x: Union[Quantity, jax.typing.ArrayLike], - y: Union[Quantity, jax.typing.ArrayLike] + y: Union[Quantity, jax.typing.ArrayLike], + *args, + **kwargs ) -> Union[ bool, Array]: """ - Return (x <= y) element-wise and have the same unit if x and y are Quantity. - - Args: - x: array_like, Quantity - y: array_like, Quantity - - Returns: - Union[bool, jax.Array]: bool or array + less_equal(x, y, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj]) + + Return the truth value of (x <= y) element-wise. + + Parameters + ---------- + x, y : array_like, Quantity + Input arrays. + If ``x.shape != y.shape``, they must be broadcastable to a common + shape (which becomes the shape of the output). + out : ndarray, None, or tuple of ndarray and None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or None, + a freshly-allocated array is returned. A tuple (possible only as a + keyword argument) must have length equal to the number of outputs. + where : array_like, optional + This condition is broadcast over the input. At locations where the + condition is True, the `out` array will be set to the ufunc result. + Elsewhere, the `out` array will retain its original value. + Note that if an uninitialized `out` array is created via the default + ``out=None``, locations within it where the condition is False will + remain uninitialized. + **kwargs + For other keyword-only arguments, see the + :ref:`ufunc docs `. + + Returns + ------- + out : ndarray or scalar + Output array, element-wise comparison of `x` and `y`. + Typically of type bool, unless ``dtype=object`` is passed. + This is a scalar if both `x` and `y` are scalars. """ - return logic_func_binary(jnp.less_equal, x, y) + return logic_func_binary(jnp.less_equal, x, y, *args, **kwargs) @set_module_as('brainunit.math') def array_equal( x: Union[Quantity, jax.typing.ArrayLike], - y: Union[Quantity, jax.typing.ArrayLike] + y: Union[Quantity, jax.typing.ArrayLike], + *args, + **kwargs ) -> Union[ bool, Array]: """ - Return True if two arrays have the same shape, elements, and units (if they are Quantity), False otherwise. - - Args: - x1: array_like, Quantity - x2: array_like, Quantity - - Returns: - Union[bool, jax.Array]: bool or array + True if two arrays have the same shape and elements, False otherwise. + + Parameters + ---------- + x, y : array_like, Quantity + Input arrays. + equal_nan : bool + Whether to compare NaN's as equal. If the dtype of a1 and a2 is + complex, values will be considered equal if either the real or the + imaginary component of a given value is ``nan``. + + Returns + ------- + b : bool + Returns True if the arrays are equal. """ - return logic_func_binary(jnp.array_equal, x, y) + return logic_func_binary(jnp.array_equal, x, y, *args, **kwargs) @set_module_as('brainunit.math') @@ -266,100 +481,222 @@ def isclose( equal_nan: bool = False ) -> Union[bool, Array]: """ - Returns a boolean array where two arrays are element-wise equal within a tolerance and have the same unit if they are Quantity. - - Args: - a: array_like, Quantity - b: array_like, Quantity - rtol: float, optional - atol: float, optional - equal_nan: bool, optional - - Returns: - Union[bool, jax.Array]: bool or array + Returns a boolean array where two arrays are element-wise equal within a + tolerance. + + The tolerance values are positive, typically very small numbers. The + relative difference (`rtol` * abs(`b`)) and the absolute difference + `atol` are added together to compare against the absolute difference + between `a` and `b`. + + Parameters + ---------- + x, y : array_like, Quantity + Input arrays to compare. + rtol : float + The relative tolerance parameter (see Notes). + atol : float + The absolute tolerance parameter (see Notes). + equal_nan : bool + Whether to compare NaN's as equal. If True, NaN's in `a` will be + considered equal to NaN's in `b` in the output array. + + Returns + ------- + out : array_like + Returns a boolean array of where `a` and `b` are equal within the + given tolerance. If both `a` and `b` are scalars, returns a single + boolean value. """ - return logic_func_binary(jnp.isclose, x, y, rtol=rtol, atol=atol, equal_nan=equal_nan) + if isinstance(x, Quantity) and isinstance(y, Quantity): + fail_for_dimension_mismatch(x, y) + if rtol is None: + rtol = (1e-05 / x.dim.value[0]) + if atol is None: + atol = (1e-08 / x.dim.value[0]) + return jnp.isclose(x.value, y.value, rtol=rtol, atol=atol, equal_nan=equal_nan) + return jnp.isclose(x, y, rtol=rtol, atol=atol, equal_nan=equal_nan) @set_module_as('brainunit.math') def allclose( x: Union[Quantity, jax.typing.ArrayLike], y: Union[Quantity, jax.typing.ArrayLike], - rtol: float = 1e-05, - atol: float = 1e-08, + rtol: float | Quantity = 1e-05, + atol: float | Quantity = 1e-08, equal_nan: bool = False ) -> Union[bool, Array]: """ - Returns True if the two arrays are equal within the given tolerance and have the same unit if they are Quantity; False otherwise. - - Args: - a: array_like, Quantity - b: array_like, Quantity - rtol: float, optional - atol: float, optional - equal_nan: bool, optional - - Returns: - bool: boolean result + Returns True if two arrays are element-wise equal within a tolerance. + + The tolerance values are positive, typically very small numbers. The + relative difference (`rtol` * abs(`b`)) and the absolute difference + `atol` are added together to compare against the absolute difference + between `a` and `b`. + + NaNs are treated as equal if they are in the same place and if + ``equal_nan=True``. Infs are treated as equal if they are in the same + place and of the same sign in both arrays. + + Parameters + ---------- + x, y : array_like, Quantity + Input arrays to compare. + rtol : float + The relative tolerance parameter (see Notes). + atol : float + The absolute tolerance parameter (see Notes). + equal_nan : bool + Whether to compare NaN's as equal. If True, NaN's in `a` will be + considered equal to NaN's in `b` in the output array. + + Returns + ------- + allclose : bool + Returns True if the two arrays are equal within the given + tolerance; False otherwise. """ - return logic_func_binary(jnp.allclose, x, y, rtol=rtol, atol=atol, equal_nan=equal_nan) + if isinstance(x, Quantity) and isinstance(y, Quantity): + fail_for_dimension_mismatch(x, y) + if rtol is None: + rtol = (1e-05 / x.dim.value[0]) + if atol is None: + atol = (1e-08 / x.dim.value[0]) + return jnp.allclose(x.value, y.value, rtol=rtol, atol=atol, equal_nan=equal_nan) + else: + return jnp.allclose(x, y, rtol=rtol, atol=atol, equal_nan=equal_nan) @set_module_as('brainunit.math') def logical_and( x: Union[Quantity, jax.typing.ArrayLike], - y: Union[Quantity, jax.typing.ArrayLike] + y: Union[Quantity, jax.typing.ArrayLike], + *args, + **kwargs ) -> Union[ bool, Array]: """ - Compute the truth value of x AND y element-wise and have the same unit if x and y are Quantity. - - Args: - x: array_like - y: array_like - out: array, optional - - Returns: - Union[bool, jax.Array]: bool or array + logical_and(x, y, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj]) + + Compute the truth value of x AND y element-wise. + + Parameters + ---------- + x, y : array_like, Quantity + Input arrays. + If ``x.shape != y.shape``, they must be broadcastable to a common + shape (which becomes the shape of the output). + out : ndarray, None, or tuple of ndarray and None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or None, + a freshly-allocated array is returned. A tuple (possible only as a + keyword argument) must have length equal to the number of outputs. + where : array_like, optional + This condition is broadcast over the input. At locations where the + condition is True, the `out` array will be set to the ufunc result. + Elsewhere, the `out` array will retain its original value. + Note that if an uninitialized `out` array is created via the default + ``out=None``, locations within it where the condition is False will + remain uninitialized. + **kwargs + For other keyword-only arguments, see the + :ref:`ufunc docs `. + + Returns + ------- + out : ndarray or bool + Boolean result of the logical AND operation applied to the elements + of `x` and `y`; the shape is determined by broadcasting. + This is a scalar if both `x` and `y` are scalars. """ - return logic_func_binary(jnp.logical_and, x, y) + return logic_func_binary(jnp.logical_and, x, y, *args, **kwargs) @set_module_as('brainunit.math') def logical_or( x: Union[Quantity, jax.typing.ArrayLike], - y: Union[Quantity, jax.typing.ArrayLike] + y: Union[Quantity, jax.typing.ArrayLike], + *args, + **kwargs ) -> Union[ bool, Array]: """ - Compute the truth value of x OR y element-wise and have the same unit if x and y are Quantity. - - Args: - x: array_like - y: array_like - out: array, optional - - Returns: - Union[bool, jax.Array]: bool or array + logical_or(x, y, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj]) + + Compute the truth value of x OR y element-wise. + + Parameters + ---------- + x, y : array_like, Quantity + Logical OR is applied to the elements of `x` and `y`. + If ``x.shape != y.shape``, they must be broadcastable to a common + shape (which becomes the shape of the output). + out : ndarray, None, or tuple of ndarray and None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or None, + a freshly-allocated array is returned. A tuple (possible only as a + keyword argument) must have length equal to the number of outputs. + where : array_like, optional + This condition is broadcast over the input. At locations where the + condition is True, the `out` array will be set to the ufunc result. + Elsewhere, the `out` array will retain its original value. + Note that if an uninitialized `out` array is created via the default + ``out=None``, locations within it where the condition is False will + remain uninitialized. + **kwargs + For other keyword-only arguments, see the + :ref:`ufunc docs `. + + Returns + ------- + out : ndarray or bool + Boolean result of the logical OR operation applied to the elements + of `x` and `y`; the shape is determined by broadcasting. + This is a scalar if both `x` and `y` are scalars. """ - return logic_func_binary(jnp.logical_or, x, y) + return logic_func_binary(jnp.logical_or, x, y, *args, **kwargs) @set_module_as('brainunit.math') def logical_xor( x: Union[Quantity, jax.typing.ArrayLike], - y: Union[Quantity, jax.typing.ArrayLike] + y: Union[Quantity, jax.typing.ArrayLike], + *args, + **kwargs ) -> Union[ bool, Array]: """ - Compute the truth value of x XOR y element-wise and have the same unit if x and y are Quantity. - - Args: - x: array_like - y: array_like - out: array, optional - - Returns: - Union[bool, jax.Array]: bool or array + logical_xor(x, y, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj]) + + Compute the truth value of x XOR y, element-wise. + + Parameters + ---------- + x, y : array_like, Quantity + Logical XOR is applied to the elements of `x` and `y`. + If ``x.shape != y.shape``, they must be broadcastable to a common + shape (which becomes the shape of the output). + out : ndarray, None, or tuple of ndarray and None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or None, + a freshly-allocated array is returned. A tuple (possible only as a + keyword argument) must have length equal to the number of outputs. + where : array_like, optional + This condition is broadcast over the input. At locations where the + condition is True, the `out` array will be set to the ufunc result. + Elsewhere, the `out` array will retain its original value. + Note that if an uninitialized `out` array is created via the default + ``out=None``, locations within it where the condition is False will + remain uninitialized. + **kwargs + For other keyword-only arguments, see the + :ref:`ufunc docs `. + + Returns + ------- + out : bool or ndarray of bool + Boolean result of the logical XOR operation applied to the elements + of `x` and `y`; the shape is determined by broadcasting. + This is a scalar if both `x` and `y` are scalars. """ - return logic_func_binary(jnp.logical_xor, x, y) + return logic_func_binary(jnp.logical_xor, x, y, *args, **kwargs) diff --git a/brainunit/math/_compat_numpy_funcs_match_unit.py b/brainunit/math/_compat_numpy_funcs_match_unit.py index e74b480..deafa01 100644 --- a/brainunit/math/_compat_numpy_funcs_match_unit.py +++ b/brainunit/math/_compat_numpy_funcs_match_unit.py @@ -19,9 +19,9 @@ import numpy as np from jax import Array -from brainunit._misc import set_module_as from .._base import (Quantity, fail_for_dimension_mismatch, ) +from .._misc import set_module_as __all__ = [ # math funcs match unit (binary) @@ -56,52 +56,132 @@ def funcs_match_unit_binary(func, x, y, *args, **kwargs): @set_module_as('brainunit.math') def add( x: Union[Quantity, Array], - y: Union[Quantity, Array] + y: Union[Quantity, Array], + *args, + **kwargs ) -> Union[Quantity, Array]: """ - Add arguments element-wise. + add(x1, x2, /, out=None, *, where=True, casting='same_kind', + order='K', dtype=None, subok=True[, signature, extobj]) - Args: - x: array_like, Quantity - y: array_like, Quantity + Add arguments element-wise. - Returns: - Union[jax.Array, Quantity]: Quantity if `x` and `y` are Quantities that have the same unit, else an array. + Parameters + ---------- + x, y : array_like, Quantity + The arrays to be added. + If ``x.shape != y.shape``, they must be broadcastable to a common + shape (which becomes the shape of the output). + out : ndarray, None, or tuple of ndarray and None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or None, + a freshly-allocated array is returned. A tuple (possible only as a + keyword argument) must have length equal to the number of outputs. + where : array_like, optional + This condition is broadcast over the input. At locations where the + condition is True, the `out` array will be set to the ufunc result. + Elsewhere, the `out` array will retain its original value. + Note that if an uninitialized `out` array is created via the default + ``out=None``, locations within it where the condition is False will + remain uninitialized. + **kwargs + For other keyword-only arguments, see the + :ref:`ufunc docs `. + + Returns + ------- + add : ndarray or scalar + The sum of `x` and `y`, element-wise. + This is a scalar if both `x` and `y` are scalars. """ - return funcs_match_unit_binary(jnp.add, x, y) + return funcs_match_unit_binary(jnp.add, x, y, *args, **kwargs) @set_module_as('brainunit.math') def subtract( x: Union[Quantity, Array], - y: Union[Quantity, Array] + y: Union[Quantity, Array], + *args, + **kwargs ) -> Union[Quantity, Array]: """ - Subtract arguments element-wise. - - Args: - x: array_like, Quantity - y: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` and `y` are Quantities that have the same unit, else an array. + subtract(x1, x2, /, out=None, *, where=True, casting='same_kind', + order='K', dtype=None, subok=True[, signature, extobj]) + + Subtract arguments, element-wise. + + Parameters + ---------- + x, y : array_like + The arrays to be subtracted from each other. + If ``x.shape != y.shape``, they must be broadcastable to a common + shape (which becomes the shape of the output). + out : ndarray, None, or tuple of ndarray and None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or None, + a freshly-allocated array is returned. A tuple (possible only as a + keyword argument) must have length equal to the number of outputs. + where : array_like, optional + This condition is broadcast over the input. At locations where the + condition is True, the `out` array will be set to the ufunc result. + Elsewhere, the `out` array will retain its original value. + Note that if an uninitialized `out` array is created via the default + ``out=None``, locations within it where the condition is False will + remain uninitialized. + **kwargs + For other keyword-only arguments, see the + :ref:`ufunc docs `. + + Returns + ------- + subtract : ndarray + The difference of `x` and `y`, element-wise. + This is a scalar if both `x` and `y` are scalars. """ - return funcs_match_unit_binary(jnp.subtract, x, y) + return funcs_match_unit_binary(jnp.subtract, x, y, *args, **kwargs) @set_module_as('brainunit.math') def nextafter( x: Union[Quantity, Array], - y: Union[Quantity, Array] + y: Union[Quantity, Array], + *args, + **kwargs ) -> Union[Quantity, Array]: """ - Return the next floating-point value after `x1` towards `x2`. - - Args: - x1: array_like, Quantity - x2: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. + nextafter(x, y, /, out=None, *, where=True, casting='same_kind', + order='K', dtype=None, subok=True[, signature, extobj]) + + Return the next floating-point value after x1 towards x2, element-wise. + + Parameters + ---------- + x : array_like, Quantity + Values to find the next representable value of. + y : array_like, Quantity + The direction where to look for the next representable value of `x`. + If ``x.shape != y.shape``, they must be broadcastable to a common + shape (which becomes the shape of the output). + out : ndarray, None, or tuple of ndarray and None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or None, + a freshly-allocated array is returned. A tuple (possible only as a + keyword argument) must have length equal to the number of outputs. + where : array_like, optional + This condition is broadcast over the input. At locations where the + condition is True, the `out` array will be set to the ufunc result. + Elsewhere, the `out` array will retain its original value. + Note that if an uninitialized `out` array is created via the default + ``out=None``, locations within it where the condition is False will + remain uninitialized. + **kwargs + For other keyword-only arguments, see the + :ref:`ufunc docs `. + + Returns + ------- + out : ndarray or scalar + The next representable values of `x` in the direction of `y`. + This is a scalar if both `x` and `y` are scalars. """ - return funcs_match_unit_binary(jnp.nextafter, x, y) + return funcs_match_unit_binary(jnp.nextafter, x, y, *args, **kwargs) diff --git a/brainunit/math/_compat_numpy_funcs_remove_unit.py b/brainunit/math/_compat_numpy_funcs_remove_unit.py index 034d721..dc27e68 100644 --- a/brainunit/math/_compat_numpy_funcs_remove_unit.py +++ b/brainunit/math/_compat_numpy_funcs_remove_unit.py @@ -12,13 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -from typing import (Union, Optional) +from __future__ import annotations +from typing import (Union, Optional, Sequence, Any) + +import jax import jax.numpy as jnp from jax import Array -from brainunit._misc import set_module_as -from .._base import Quantity +from .._base import Quantity, fail_for_dimension_mismatch +from .._misc import set_module_as __all__ = [ @@ -45,11 +48,16 @@ def signbit(x: Union[Array, Quantity]) -> Array: """ Returns element-wise True where signbit is set (less than zero). - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + The input value(s). - Returns: - jax.Array: an array + Returns + ------- + result : ndarray of bool + Output array, or reference to `out` if that was supplied. + This is a scalar if `x` is a scalar. """ return funcs_remove_unit_unary(jnp.signbit, x) @@ -59,111 +67,319 @@ def sign(x: Union[Array, Quantity]) -> Array: """ Returns the sign of each element in the input array. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input values. - Returns: - jax.Array: an array + Returns + ------- + y : ndarray + The sign of `x`. + This is a scalar if `x` is a scalar. """ return funcs_remove_unit_unary(jnp.sign, x) @set_module_as('brainunit.math') -def histogram(x: Union[Array, Quantity]) -> tuple[Array, Array]: +def histogram( + x: Union[Array, Quantity], + bins: jax.typing.ArrayLike = 10, + range: Optional[Sequence[jax.typing.ArrayLike | Quantity]] = None, + weights: Optional[jax.typing.ArrayLike] = None, + density: Optional[bool] = None +) -> tuple[Array, Array]: """ Compute the histogram of a set of data. - Args: - x: array_like, Quantity - - Returns: - tuple[jax.Array]: Tuple of arrays (hist, bin_edges) + Parameters + ---------- + x : array_like, Quantity + Input data. The histogram is computed over the flattened array. + bins : int or sequence of scalars or str, optional + If `bins` is an int, it defines the number of equal-width + bins in the given range (10, by default). If `bins` is a + sequence, it defines a monotonically increasing array of bin edges, + including the rightmost edge, allowing for non-uniform bin widths. + + If `bins` is a string, it defines the method used to calculate the + optimal bin width, as defined by `histogram_bin_edges`. + range : (float, float), (Quantity, Quantity) optional + The lower and upper range of the bins. If not provided, range + is simply ``(a.min(), a.max())``. Values outside the range are + ignored. The first element of the range must be less than or + equal to the second. `range` affects the automatic bin + computation as well. While bin width is computed to be optimal + based on the actual data within `range`, the bin count will fill + the entire range including portions containing no data. + weights : array_like, optional + An array of weights, of the same shape as `a`. Each value in + `a` only contributes its associated weight towards the bin count + (instead of 1). If `density` is True, the weights are + normalized, so that the integral of the density over the range + remains 1. + density : bool, optional + If ``False``, the result will contain the number of samples in + each bin. If ``True``, the result is the value of the + probability *density* function at the bin, normalized such that + the *integral* over the range is 1. Note that the sum of the + histogram values will not be equal to 1 unless bins of unity + width are chosen; it is not a probability *mass* function. + + Returns + ------- + hist : array + The values of the histogram. See `density` and `weights` for a + description of the possible semantics. + bin_edges : array of dtype float + Return the bin edges ``(length(hist)+1)``. """ - return funcs_remove_unit_unary(jnp.histogram, x) + if isinstance(x, Quantity): + if range is not None: + fail_for_dimension_mismatch(x, range[0]) + fail_for_dimension_mismatch(x, range[1]) + range = (range[0].value, range[1].value) + x = x.value + return jnp.histogram(x, bins, range, weights, density) @set_module_as('brainunit.math') -def bincount(x: Union[Array, Quantity]) -> Array: +def bincount( + x: Union[Array, Quantity], + weights: Optional[jax.typing.ArrayLike] = None, + minlength: int = 0, + *, + length: Optional[int] = None +) -> Array: """ - Count number of occurrences of each value in array of non-negative integers. - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array + Count number of occurrences of each value in array of non-negative ints. + + The number of bins (of size 1) is one larger than the largest value in + `x`. If `minlength` is specified, there will be at least this number + of bins in the output array (though it will be longer if necessary, + depending on the contents of `x`). + Each bin gives the number of occurrences of its index value in `x`. + If `weights` is specified the input array is weighted by it, i.e. if a + value ``n`` is found at position ``i``, ``out[n] += weight[i]`` instead + of ``out[n] += 1``. + + Parameters + ---------- + x : array_like, Quantity, 1 dimension, nonnegative ints + Input array. + weights : array_like, optional + Weights, array of the same shape as `x`. + minlength : int, optional + A minimum number of bins for the output array. + + Returns + ------- + out : ndarray of ints + The result of binning the input array. + The length of `out` is equal to ``bu.amax(x)+1``. """ - return funcs_remove_unit_unary(jnp.bincount, x) + return funcs_remove_unit_unary(jnp.bincount, x, weights=weights, minlength=minlength, length=length) # math funcs remove unit (binary) # ------------------------------- def funcs_remove_unit_binary(func, x, y, *args, **kwargs): - if isinstance(x, Quantity): - x_value = x.value - if isinstance(y, Quantity): - y_value = y.value if isinstance(x, Quantity) or isinstance(y, Quantity): - return func(jnp.array(x_value), jnp.array(y_value), *args, **kwargs) + fail_for_dimension_mismatch(x, y) + return func(x.value, y.value, *args, **kwargs) else: return func(x, y, *args, **kwargs) @set_module_as('brainunit.math') -def corrcoef(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array: +def corrcoef( + x: Union[Array, Quantity], + y: Union[Array, Quantity], + rowvar: bool = True +) -> Array: """ Return Pearson product-moment correlation coefficients. - Args: - x: array_like, Quantity - y: array_like, Quantity - - Returns: - jax.Array: an array + Please refer to the documentation for `cov` for more detail. The + relationship between the correlation coefficient matrix, `R`, and the + covariance matrix, `C`, is + + .. math:: R_{ij} = \\frac{ C_{ij} } { \\sqrt{ C_{ii} C_{jj} } } + + The values of `R` are between -1 and 1, inclusive. + + Parameters + ---------- + x : array_like, Quantity + A 1-D or 2-D array containing multiple variables and observations. + Each row of `x` represents a variable, and each column a single + observation of all those variables. Also see `rowvar` below. + y : array_like, Quantity, optional + An additional set of variables and observations. `y` has the same + shape as `x`. + rowvar : bool, optional + If `rowvar` is True (default), then each row represents a + variable, with observations in the columns. Otherwise, the relationship + is transposed: each column represents a variable, while the rows + contain observations. + + Returns + ------- + R : ndarray + The correlation coefficient matrix of the variables. """ - return funcs_remove_unit_binary(jnp.corrcoef, x, y) + return funcs_remove_unit_binary(jnp.corrcoef, x, y, rowvar=rowvar) @set_module_as('brainunit.math') -def correlate(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array: +def correlate( + a: Union[Array, Quantity], + v: Union[Array, Quantity], + mode: str = 'valid', + *, + precision: Any = None, + preferred_element_type: Optional[jax.typing.DTypeLike] = None +) -> Array: """ - Cross-correlation of two sequences. - - Args: - x: array_like, Quantity - y: array_like, Quantity - - Returns: - jax.Array: an array + Cross-correlation of two 1-dimensional sequences. + + This function computes the correlation as generally defined in signal + processing texts: + + .. math:: c_k = \sum_n a_{n+k} \cdot \overline{v}_n + + with a and v sequences being zero-padded where necessary and + :math:`\overline x` denoting complex conjugation. + + Parameters + ---------- + a, v : array_like, Quantity + Input sequences. + mode : {'valid', 'same', 'full'}, optional + Refer to the `convolve` docstring. Note that the default + is 'valid', unlike `convolve`, which uses 'full'. + precision : Optional. Either ``None``, which means the default precision for + the backend, a :class:`~jax.lax.Precision` enum value + (``Precision.DEFAULT``, ``Precision.HIGH`` or ``Precision.HIGHEST``), a + string (e.g. 'highest' or 'fastest', see the + ``jax.default_matmul_precision`` context manager), or a tuple of two + :class:`~jax.lax.Precision` enums or strings indicating precision of + ``lhs`` and ``rhs``. + preferred_element_type : Optional. Either ``None``, which means the default + accumulation type for the input types, or a datatype, indicating to + accumulate results to and return a result with that datatype. + + Returns + ------- + out : ndarray + Discrete cross-correlation of `a` and `v`. """ - return funcs_remove_unit_binary(jnp.correlate, x, y) + return funcs_remove_unit_binary(jnp.correlate, a, v, mode=mode, precision=precision, + preferred_element_type=preferred_element_type) @set_module_as('brainunit.math') -def cov(x: Union[Array, Quantity], y: Optional[Union[Array, Quantity]] = None) -> Array: +def cov( + m: Union[Array, Quantity], + y: Optional[Union[Array, Quantity]] = None, + rowvar: bool = True, + bias: bool = False, + ddof: Optional[int] = None, + fweights: Optional[jax.typing.ArrayLike] = None, + aweights: Optional[jax.typing.ArrayLike] = None +) -> Array: """ - Covariance matrix. - - Args: - x: array_like, Quantity - y: array_like, Quantity (optional, if not provided, x is assumed to be a 2D array) - - Returns: - jax.Array: an array + Estimate a covariance matrix, given data and weights. + + Covariance indicates the level to which two variables vary together. + If we examine N-dimensional samples, :math:`X = [x_1, x_2, ... x_N]^T`, + then the covariance matrix element :math:`C_{ij}` is the covariance of + :math:`x_i` and :math:`x_j`. The element :math:`C_{ii}` is the variance + of :math:`x_i`. + + See the notes for an outline of the algorithm. + + Parameters + ---------- + m : array_like, Quantity + A 1-D or 2-D array containing multiple variables and observations. + Each row of `m` represents a variable, and each column a single + observation of all those variables. Also see `rowvar` below. + y : array_like, Quantity or optional + An additional set of variables and observations. `y` has the same form + as that of `m`. + rowvar : bool, optional + If `rowvar` is True (default), then each row represents a + variable, with observations in the columns. Otherwise, the relationship + is transposed: each column represents a variable, while the rows + contain observations. + bias : bool, optional + Default normalization (False) is by ``(N - 1)``, where ``N`` is the + number of observations given (unbiased estimate). If `bias` is True, + then normalization is by ``N``. These values can be overridden by using + the keyword ``ddof`` in numpy versions >= 1.5. + ddof : int, optional + If not ``None`` the default value implied by `bias` is overridden. + Note that ``ddof=1`` will return the unbiased estimate, even if both + `fweights` and `aweights` are specified, and ``ddof=0`` will return + the simple average. See the notes for the details. The default value + is ``None``. + fweights : array_like, int, optional + 1-D array of integer frequency weights; the number of times each + observation vector should be repeated. + aweights : array_like, optional + 1-D array of observation vector weights. These relative weights are + typically large for observations considered "important" and smaller for + observations considered less "important". If ``ddof=0`` the array of + weights can be used to assign probabilities to observation vectors. + + Returns + ------- + out : ndarray + The covariance matrix of the variables. """ - return funcs_remove_unit_binary(jnp.cov, x, y) + return funcs_remove_unit_binary(jnp.cov, m, y, rowvar=rowvar, bias=bias, ddof=ddof, fweights=fweights, + aweights=aweights) @set_module_as('brainunit.math') -def digitize(x: Union[Array, Quantity], bins: Union[Array, Quantity]) -> Array: +def digitize( + x: Union[Array, Quantity], + bins: Union[Array, Quantity], + right: bool = False +) -> Array: """ Return the indices of the bins to which each value in input array belongs. - Args: - x: array_like, Quantity - bins: array_like, Quantity - - Returns: - jax.Array: an array + ========= ============= ============================ + `right` order of bins returned index `i` satisfies + ========= ============= ============================ + ``False`` increasing ``bins[i-1] <= x < bins[i]`` + ``True`` increasing ``bins[i-1] < x <= bins[i]`` + ``False`` decreasing ``bins[i-1] > x >= bins[i]`` + ``True`` decreasing ``bins[i-1] >= x > bins[i]`` + ========= ============= ============================ + + If values in `x` are beyond the bounds of `bins`, 0 or ``len(bins)`` is + returned as appropriate. + + Parameters + ---------- + x : array_like, Quantity + Input array to be binned. Prior to NumPy 1.10.0, this array had to + be 1-dimensional, but can now have any shape. + bins : array_like, Quantity + Array of bins. It has to be 1-dimensional and monotonic. + right : bool, optional + Indicating whether the intervals include the right or the left bin + edge. Default behavior is (right==False) indicating that the interval + does not include the right edge. The left bin end is open in this + case, i.e., bins[i-1] <= x < bins[i] is the default behavior for + monotonically increasing bins. + + Returns + ------- + indices : ndarray of ints + Output array of indices, of same shape as `x`. """ - return funcs_remove_unit_binary(jnp.digitize, x, bins) + return funcs_remove_unit_binary(jnp.digitize, x, bins, right=right) diff --git a/brainunit/math/_compat_numpy_funcs_window.py b/brainunit/math/_compat_numpy_funcs_window.py index ae4df30..b79fa01 100644 --- a/brainunit/math/_compat_numpy_funcs_window.py +++ b/brainunit/math/_compat_numpy_funcs_window.py @@ -16,10 +16,9 @@ import jax.numpy as jnp from jax import Array -from brainunit._misc import set_module_as +from .._misc import set_module_as __all__ = [ - # window funcs 'bartlett', 'blackman', 'hamming', 'hanning', 'kaiser', ] diff --git a/brainunit/math/_compat_numpy_get_attribute.py b/brainunit/math/_compat_numpy_get_attribute.py index 1595191..10fc71a 100644 --- a/brainunit/math/_compat_numpy_get_attribute.py +++ b/brainunit/math/_compat_numpy_get_attribute.py @@ -18,8 +18,8 @@ import jax.numpy as jnp import numpy as np -from brainunit._misc import set_module_as from .._base import Quantity +from .._misc import set_module_as __all__ = [ # getting attribute funcs diff --git a/brainunit/math/_compat_numpy_linear_algebra.py b/brainunit/math/_compat_numpy_linear_algebra.py index a1ced94..05ad3c8 100644 --- a/brainunit/math/_compat_numpy_linear_algebra.py +++ b/brainunit/math/_compat_numpy_linear_algebra.py @@ -12,15 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -from typing import (Union) +from typing import (Union, Any, Optional) +import jax import jax.numpy as jnp from jax import Array -from brainunit._misc import set_module_as from ._compat_numpy_funcs_change_unit import funcs_change_unit_binary from ._compat_numpy_funcs_keep_unit import funcs_keep_unit_unary from .._base import Quantity +from .._misc import set_module_as __all__ = [ @@ -34,84 +35,179 @@ # -------------- @set_module_as('brainunit.math') -def dot(a: Union[Array, Quantity], b: Union[Array, Quantity]) -> Union[Array, Quantity]: +def dot( + a: Union[Array, Quantity], + b: Union[Array, Quantity], + *, + precision: Any = None, + preferred_element_type: Optional[jax.typing.DTypeLike] = None +) -> Union[Array, Quantity]: """ Dot product of two arrays or quantities. - Args: - a: array_like, Quantity - b: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. + Parameters + ---------- + a : array_like, Quantity + First argument. + b : array_like, Quantity + Second argument. + precision : either ``None`` (default), + which means the default precision for the backend, a :class:`~jax.lax.Precision` + enum value (``Precision.DEFAULT``, ``Precision.HIGH`` or ``Precision.HIGHEST``) + or a tuple of two such values indicating precision of ``a`` and ``b``. + preferred_element_type : either ``None`` (default) + which means the default accumulation type for the input types, or a datatype, + indicating to accumulate results to and return a result with that datatype. + + Returns + ------- + output : ndarray, Quantity + array containing the dot product of the inputs, with batch dimensions of + ``a`` and ``b`` stacked rather than broadcast. + + This is a Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. """ return funcs_change_unit_binary(jnp.dot, lambda x, y: x * y, - a, b) + a, b, + precision=precision, + preferred_element_type=preferred_element_type) @set_module_as('brainunit.math') -def vdot(a: Union[Array, Quantity], b: Union[Array, Quantity]) -> Union[Array, Quantity]: +def vdot( + a: Union[Array, Quantity], + b: Union[Array, Quantity], + *, + precision: Any = None, + preferred_element_type: Optional[jax.typing.DTypeLike] = None +) -> Union[Array, Quantity]: """ - Return the dot product of two vectors or quantities. - - Args: - a: array_like, Quantity - b: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. + Perform a conjugate multiplication of two 1D vectors. + + Parameters + ---------- + a : array_like, Quantity + First argument. + b : array_like, Quantity + Second argument. + precision : either ``None`` (default), + which means the default precision for the backend, a :class:`~jax.lax.Precision` + enum value (``Precision.DEFAULT``, ``Precision.HIGH`` or ``Precision.HIGHEST``) + or a tuple of two such values indicating precision of ``a`` and ``b``. + preferred_element_type : either ``None`` (default) + which means the default accumulation type for the input types, or a datatype, + indicating to accumulate results to and return a result with that datatype. + + Returns + ------- + output : ndarray, Quantity + array containing the dot product of the inputs, with batch dimensions of + ``a`` and ``b`` stacked rather than broadcast. + + This is a Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. """ return funcs_change_unit_binary(jnp.vdot, lambda x, y: x * y, - a, b) + a, b, + precision=precision, + preferred_element_type=preferred_element_type) @set_module_as('brainunit.math') -def inner(a: Union[Array, Quantity], b: Union[Array, Quantity]) -> Union[Array, Quantity]: +def inner( + a: Union[Array, Quantity], + b: Union[Array, Quantity], + *, + precision: Any = None, + preferred_element_type: Optional[jax.typing.DTypeLike] = None +) -> Union[Array, Quantity]: """ Inner product of two arrays or quantities. - Args: - a: array_like, Quantity - b: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. + Parameters + ---------- + a : array_like, Quantity + First argument. + b : array_like, Quantity + Second argument. + precision : either ``None`` (default), + which means the default precision for the backend, a :class:`~jax.lax.Precision` + enum value (``Precision.DEFAULT``, ``Precision.HIGH`` or ``Precision.HIGHEST``) + or a tuple of two such values indicating precision of ``a`` and ``b``. + preferred_element_type : either ``None`` (default) + which means the default accumulation type for the input types, or a datatype, + indicating to accumulate results to and return a result with that datatype. + + Returns + ------- + output : ndarray, Quantity + array containing the inner product of the inputs, with batch dimensions of + ``a`` and ``b`` stacked rather than broadcast. + + This is a Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. """ return funcs_change_unit_binary(jnp.inner, lambda x, y: x * y, - a, b) + a, b, + precision=precision, + preferred_element_type=preferred_element_type) @set_module_as('brainunit.math') -def outer(a: Union[Array, Quantity], b: Union[Array, Quantity]) -> Union[Array, Quantity]: +def outer( + a: Union[Array, Quantity], + b: Union[Array, Quantity], + out: Optional[Any] = None +) -> Union[Array, Quantity]: """ Compute the outer product of two vectors or quantities. - Args: - a: array_like, Quantity - b: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. + Parameters + ---------- + a : array_like, Quantity + First argument. + b : array_like, Quantity + Second argument. + out : ndarray, optional + A location into which the result is stored. If provided, it must have a shape that the inputs broadcast to. + If not provided or None, a freshly-allocated array is returned. + + Returns + ------- + output : ndarray, Quantity + array containing the outer product of the inputs, with batch dimensions of + ``a`` and ``b`` stacked rather than broadcast. + + This is a Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. """ return funcs_change_unit_binary(jnp.outer, lambda x, y: x * y, - a, b) + a, b, + out=out) @set_module_as('brainunit.math') -def kron(a: Union[Array, Quantity], b: Union[Array, Quantity]) -> Union[Array, Quantity]: +def kron( + a: Union[Array, Quantity], + b: Union[Array, Quantity] +) -> Union[Array, Quantity]: """ Compute the Kronecker product of two arrays or quantities. - Args: - a: array_like, Quantity - b: array_like, Quantity + Parameters + ---------- + a : array_like, Quantity + First input. + b : array_like, Quantity + Second input. - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. + Returns + ------- + output : ndarray, Quantity + Kronecker product of `a` and `b`. + + This is a Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. """ return funcs_change_unit_binary(jnp.kron, lambda x, y: x * y, @@ -119,32 +215,88 @@ def kron(a: Union[Array, Quantity], b: Union[Array, Quantity]) -> Union[Array, Q @set_module_as('brainunit.math') -def matmul(a: Union[Array, Quantity], b: Union[Array, Quantity]) -> Union[Array, Quantity]: +def matmul( + a: Union[Array, Quantity], + b: Union[Array, Quantity], + *, + precision: Any = None, + preferred_element_type: Optional[jax.typing.DTypeLike] = None +) -> Union[Array, Quantity]: """ Matrix product of two arrays or quantities. - Args: - a: array_like, Quantity - b: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. + Parameters + ---------- + a : array_like, Quantity + First argument. + b : array_like, Quantity + Second argument. + precision : either ``None`` (default), + which means the default precision for the backend, a :class:`~jax.lax.Precision` + enum value (``Precision.DEFAULT``, ``Precision.HIGH`` or ``Precision.HIGHEST``) + or a tuple of two such values indicating precision of ``a`` and ``b``. + preferred_element_type : either ``None`` (default) + which means the default accumulation type for the input types, or a datatype, + indicating to accumulate results to and return a result with that datatype. + + Returns + ------- + output : ndarray, Quantity + array containing the matrix product of the inputs, with batch dimensions of + ``a`` and ``b`` stacked rather than broadcast. + + This is a Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. """ return funcs_change_unit_binary(jnp.matmul, lambda x, y: x * y, - a, b) + a, b, + precision=precision, + preferred_element_type=preferred_element_type) @set_module_as('brainunit.math') -def trace(a: Union[Array, Quantity]) -> Union[Array, Quantity]: +def trace( + a: Union[Array, Quantity], + offset: int = 0, + axis1: int = 0, + axis2: int = 1, + dtype: Optional[jax.typing.DTypeLike] = None, +) -> Union[Array, Quantity]: """ - Return the sum of the diagonal elements of a matrix or quantity. - - Args: - a: array_like, Quantity - offset: int, optional - - Returns: - Union[jax.Array, Quantity]: Quantity if the input is a Quantity, else an array. + Return the sum along diagonals of the array. + + If `a` is 2-D, the sum along its diagonal with the given offset + is returned, i.e., the sum of elements ``a[i,i+offset]`` for all i. + + If `a` has more than two dimensions, then the axes specified by axis1 and + axis2 are used to determine the 2-D sub-arrays whose traces are returned. + The shape of the resulting array is the same as that of `a` with `axis1` + and `axis2` removed. + + Parameters + ---------- + a : array_like, Quantity + Input array, from which the diagonals are taken. + offset : int, optional + Offset of the diagonal from the main diagonal. Can be both positive + and negative. Defaults to 0. + axis1, axis2 : int, optional + Axes to be used as the first and second axis of the 2-D sub-arrays + from which the diagonals should be taken. Defaults are the first two + axes of `a`. + dtype : dtype, optional + Determines the data-type of the returned array and of the accumulator + where the elements are summed. If dtype has the value None and `a` is + of integer type of precision less than the default integer + precision, then the default integer precision is used. Otherwise, + the precision is the same as that of `a`. + + Returns + ------- + sum_along_diagonals : ndarray + If `a` is 2-D, the sum along the diagonal is returned. If `a` has + larger dimensions, then an array of sums along diagonals is returned. + + This is a Quantity if `a` is a Quantity, else an array. """ - return funcs_keep_unit_unary(jnp.trace, a) + return funcs_keep_unit_unary(jnp.trace, a, offset=offset, axis1=axis1, axis2=axis2, dtype=dtype) diff --git a/brainunit/math/_compat_numpy_misc.py b/brainunit/math/_compat_numpy_misc.py index 400dc5a..f261a4a 100644 --- a/brainunit/math/_compat_numpy_misc.py +++ b/brainunit/math/_compat_numpy_misc.py @@ -12,8 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== + +from __future__ import annotations + from collections.abc import Sequence -from typing import (Callable, Union, Tuple) +from typing import (Callable, Union, Tuple, Any, Optional) import jax import jax.numpy as jnp @@ -22,7 +25,6 @@ from jax import Array from jax._src.numpy.lax_numpy import _einsum -from brainunit._misc import set_module_as from ._compat_numpy_array_manipulation import func_array_manipulation from ._compat_numpy_funcs_change_unit import funcs_change_unit_binary from ._compat_numpy_funcs_keep_unit import funcs_keep_unit_unary @@ -30,7 +32,8 @@ Quantity, fail_for_dimension_mismatch, is_unitless, - get_unit, ) + get_dim, ) +from .._misc import set_module_as __all__ = [ @@ -43,7 +46,7 @@ # more 'broadcast_arrays', 'broadcast_shapes', 'einsum', 'gradient', 'intersect1d', 'nan_to_num', 'nanargmax', 'nanargmin', - 'rot90', 'tensordot', + 'rot90', 'tensordot', 'frexp', ] # constants @@ -77,16 +80,9 @@ def iinfo(a: Union[Quantity, jax.typing.ArrayLike]) -> jnp.iinfo: # ---- @set_module_as('brainunit.math') def broadcast_arrays(*args: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, list[Array]]: - from builtins import all as origin_all - from builtins import any as origin_any - if origin_all(isinstance(arg, Quantity) for arg in args): - if origin_any(arg.dim != args[0].dim for arg in args): - raise ValueError("All arguments must have the same unit") - return Quantity(jnp.broadcast_arrays(*[arg.value for arg in args]), dim=args[0].dim) - elif origin_all(isinstance(arg, (jax.Array, np.ndarray)) for arg in args): - return jnp.broadcast_arrays(*args) - else: - raise ValueError(f"Unsupported types : {type(args)} for broadcast_arrays") + leaves, tree = jax.tree.flatten(args) + leaves = jnp.broadcast_arrays(*leaves) + return jax.tree.unflatten(tree, leaves) broadcast_shapes = jnp.broadcast_shapes @@ -100,30 +96,39 @@ def einsum( out: None = None, optimize: Union[str, bool] = "optimal", precision: jax.lax.PrecisionLike = None, - preferred_element_type: Union[jax.typing.DTypeLike, None] = None, + preferred_element_type: Optional[jax.typing.DTypeLike] = None, _dot_general: Callable[..., jax.Array] = jax.lax.dot_general, ) -> Union[jax.Array, Quantity]: """ Evaluates the Einstein summation convention on the operands. - Args: - subscripts: string containing axes names separated by commas. - *operands: sequence of one or more arrays or quantities corresponding to the subscripts. - optimize: determine whether to optimize the order of computation. In JAX - this defaults to ``"optimize"`` which produces optimized expressions via - the opt_einsum_ package. - precision: either ``None`` (default), which means the default precision for - the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``, - ``Precision.HIGH`` or ``Precision.HIGHEST``). - preferred_element_type: either ``None`` (default), which means the default - accumulation type for the input types, or a datatype, indicating to - accumulate results to and return a result with that datatype. - out: unsupported by JAX - _dot_general: optionally override the ``dot_general`` callable used by ``einsum``. - This parameter is experimental, and may be removed without warning at any time. - - Returns: - array containing the result of the einstein summation. + Parameters + ---------- + subscripts : str + string containing axes names separated by commas. + *operands : array_like, Quantity, optional + sequence of one or more arrays or quantities corresponding to the subscripts. + optimize : {False, True, 'optimal'}, optional + determine whether to optimize the order of computation. In JAX + this defaults to ``"optimize"`` which produces optimized expressions via + the opt_einsum_ package. + precision : either ``None`` (default), + which means the default precision for the backend + a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``, + ``Precision.HIGH`` or ``Precision.HIGHEST``). + preferred_element_type : either ``None`` (default) + which means the default accumulation type for the input types, + or a datatype, indicating to accumulate results to and return a result with that datatype. + out : {None}, optional + This parameter is not supported in JAX. + _dot_general : callable, optional + optionally override the ``dot_general`` callable used by ``einsum``. + This parameter is experimental, and may be removed without warning at any time. + + Returns + ------- + output : Quantity or jax.Array + The calculation based on the Einstein summation convention. """ operands = (subscripts, *operands) if out is not None: @@ -142,13 +147,11 @@ def einsum( from jax._src.numpy.lax_numpy import _default_poly_einsum_handler contract_path = _default_poly_einsum_handler - operands, contractions = contract_path( - *operands, einsum_call=True, use_blas=True, optimize=optimize) + operands, contractions = contract_path(*operands, einsum_call=True, use_blas=True, optimize=optimize) unit = None for i in range(len(contractions) - 1): if contractions[i][4] == 'False': - fail_for_dimension_mismatch( Quantity([], dim=unit), operands[i + 1], 'einsum' ) @@ -191,14 +194,46 @@ def gradient( """ Computes the gradient of a scalar field. - Args: - f: input array. - *varargs: list of scalar fields to compute the gradient. - axis: axis or axes along which to compute the gradient. The default is to compute the gradient along all axes. - edge_order: order of the edge used for the finite difference computation. The default is 1. - - Returns: - array containing the gradient of the scalar field. + Return the gradient of an N-dimensional array. + + The gradient is computed using second order accurate central differences + in the interior points and either first or second order accurate one-sides + (forward or backwards) differences at the boundaries. + The returned gradient hence has the same shape as the input array. + + Parameters + ---------- + f : array_like, Quantity + An N-dimensional array containing samples of a scalar function. + varargs : list of scalar or array, optional + Spacing between f values. Default unitary spacing for all dimensions. + Spacing can be specified using: + + 1. single scalar to specify a sample distance for all dimensions. + 2. N scalars to specify a constant sample distance for each dimension. + i.e. `dx`, `dy`, `dz`, ... + 3. N arrays to specify the coordinates of the values along each + dimension of F. The length of the array must match the size of + the corresponding dimension + 4. Any combination of N scalars/arrays with the meaning of 2. and 3. + + If `axis` is given, the number of varargs must equal the number of axes. + Default: 1. + edge_order : {1, 2}, optional + Gradient is calculated using N-th order accurate differences + at the boundaries. Default: 1. + axis : None or int or tuple of ints, optional + Gradient is calculated only along the given axis or axes + The default (axis = None) is to calculate the gradient for all the axes + of the input array. axis may be negative, in which case it counts from + the last to the first axis. + + Returns + ------- + gradient : ndarray or list of ndarray or Quantity + A list of ndarrays (or a single ndarray if there is only one dimension) + corresponding to the derivatives of f with respect to each dimension. + Each derivative has the same shape as f. """ if edge_order is not None: raise NotImplementedError("The 'edge_order' argument to jnp.gradient is not supported.") @@ -209,13 +244,13 @@ def gradient( else: return jnp.gradient(f) elif len(varargs) == 1: - unit = get_unit(f) / get_unit(varargs[0]) + unit = get_dim(f) / get_dim(varargs[0]) if unit is None or unit == DIMENSIONLESS: return jnp.gradient(f, varargs[0], axis=axis) else: return [Quantity(r, dim=unit) for r in jnp.gradient(f.value, varargs[0].value, axis=axis)] else: - unit_list = [get_unit(f) / get_unit(v) for v in varargs] + unit_list = [get_dim(f) / get_dim(v) for v in varargs] f = f.value if isinstance(f, Quantity) else f varargs = [v.value if isinstance(v, Quantity) else v for v in varargs] result_list = jnp.gradient(f, *varargs, axis=axis) @@ -224,22 +259,40 @@ def gradient( @set_module_as('brainunit.math') def intersect1d( - ar1: Union[jax.typing.ArrayLike], - ar2: Union[jax.typing.ArrayLike], + ar1: Union[jax.typing.ArrayLike, Quantity], + ar2: Union[jax.typing.ArrayLike, Quantity], assume_unique: bool = False, return_indices: bool = False ) -> Union[jax.Array, Quantity, tuple[Union[jax.Array, Quantity], jax.Array, jax.Array]]: """ Find the intersection of two arrays. - Args: - ar1: input array. - ar2: input array. - assume_unique: if True, the input arrays are both assumed to be unique. - return_indices: if True, the indices which correspond to the intersection of the two arrays are returned. - - Returns: - array containing the intersection of the two arrays. + Return the sorted, unique values that are in both of the input arrays. + + Parameters + ---------- + ar1, ar2 : array_like, Quantity + Input arrays. Will be flattened if not already 1D. + assume_unique : bool + If True, the input arrays are both assumed to be unique, which + can speed up the calculation. If True but ``ar1`` or ``ar2`` are not + unique, incorrect results and out-of-bounds indices could result. + Default is False. + return_indices : bool + If True, the indices which correspond to the intersection of the two + arrays are returned. The first instance of a value is used if there are + multiple. Default is False. + + Returns + ------- + intersect1d : ndarray, Quantity + Sorted 1D array of common and unique elements. + comm1 : ndarray + The indices of the first occurrences of the common values in `ar1`. + Only provided if `return_indices` is True. + comm2 : ndarray + The indices of the first occurrences of the common values in `ar2`. + Only provided if `return_indices` is True. """ fail_for_dimension_mismatch(ar1, ar2, 'intersect1d') unit = None @@ -250,7 +303,7 @@ def intersect1d( result = jnp.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices) if return_indices: if unit is not None: - return (Quantity(result[0], dim=unit), result[1], result[2]) + return Quantity(result[0], dim=unit), result[1], result[2] else: return result else: @@ -263,23 +316,59 @@ def intersect1d( @set_module_as('brainunit.math') def nan_to_num( x: Union[jax.typing.ArrayLike, Quantity], - nan: float = 0.0, - posinf: float = jnp.inf, - neginf: float = -jnp.inf + nan: float | Quantity = None, + posinf: float | Quantity = None, + neginf: float | Quantity = None ) -> Union[jax.Array, Quantity]: """ - Replace NaN with zero and infinity with large finite numbers (default behaviour) or with the numbers defined by the user using the `nan`, `posinf` and `neginf` arguments. - - Args: - x: input array. - nan: value to replace NaNs with. - posinf: value to replace positive infinity with. - neginf: value to replace negative infinity with. - - Returns: - array with NaNs replaced by zero and infinities replaced by large finite numbers. + Replace NaN with zero and infinity with large finite numbers (default + behaviour) or with the numbers defined by the user using the `nan`, + `posinf` and/or `neginf` keywords. + + If `x` is inexact, NaN is replaced by zero or by the user defined value in + `nan` keyword, infinity is replaced by the largest finite floating point + values representable by ``x.dtype`` or by the user defined value in + `posinf` keyword and -infinity is replaced by the most negative finite + floating point values representable by ``x.dtype`` or by the user defined + value in `neginf` keyword. + + For complex dtypes, the above is applied to each of the real and + imaginary components of `x` separately. + + If `x` is not inexact, then no replacements are made. + + Parameters + ---------- + x : scalar, array_like or Quantity + Input data. + nan : int, float, optional + Value to be used to fill NaN values. If no value is passed + then NaN values will be replaced with 0.0. + posinf : int, float, optional + Value to be used to fill positive infinity values. If no value is + passed then positive infinity values will be replaced with a very + large number. + neginf : int, float, optional + Value to be used to fill negative infinity values. If no value is + passed then negative infinity values will be replaced with a very + small (or negative) number. + + Returns + ------- + out : ndarray, Quantity + `x`, with the non-finite values replaced. If `copy` is False, this may + be `x` itself. """ - return funcs_keep_unit_unary(jnp.nan_to_num, x, nan=nan, posinf=posinf, neginf=neginf) + if isinstance(x, Quantity): + nan = Quantity(0.0, dim=x.dim) if nan is None else nan + posinf = Quantity(jnp.finfo(x.dtype).max, dim=x.dim) if posinf is None else posinf + neginf = Quantity(jnp.finfo(x.dtype).min, dim=x.dim) if neginf is None else neginf + return Quantity(jnp.nan_to_num(x.value, nan=nan.value, posinf=posinf.value, neginf=neginf.value), dim=x.dim) + else: + nan = 0.0 if nan is None else nan + posinf = jnp.inf if posinf is None else posinf + neginf = -jnp.inf if neginf is None else neginf + return jnp.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf) @set_module_as('brainunit.math') @@ -290,16 +379,26 @@ def rot90( ) -> Union[ jax.Array, Quantity]: """ - Return the index of the maximum value in an array, ignoring NaNs. + Rotate an array by 90 degrees in the plane specified by axes. - Args: - a: array like, Quantity. - axis: axis along which to operate. The default is to compute the index of the maximum over all the dimensions of the input array. - out: output array, optional. - keepdims: if True, the result is broadcast to the input array with the same number of dimensions. + Rotation direction is from the first towards the second axis. - Returns: - index of the maximum value in the array. + Parameters + ---------- + m : array_like, Quantity + Array of two or more dimensions. + k : integer + Number of times the array is rotated by 90 degrees. + axes : (2,) array_like + The array is rotated in the plane defined by the axes. + Axes must be different. + + Returns + ------- + y : ndarray, Quantity + A rotated view of `m`. + + This is a quantity if `m` is a quantity. """ return funcs_keep_unit_unary(jnp.rot90, m, k=k, axes=axes) @@ -308,56 +407,152 @@ def rot90( def tensordot( a: Union[jax.typing.ArrayLike, Quantity], b: Union[jax.typing.ArrayLike, Quantity], - axes: Union[int, Tuple[int, int]] = 2 + axes: Union[int, Tuple[int, int]] = 2, + precision: Any = None, + preferred_element_type: Optional[jax.typing.DTypeLike] = None ) -> Union[jax.Array, Quantity]: """ - Return the index of the minimum value in an array, ignoring NaNs. - - Args: - a: array like, Quantity. - axis: axis along which to operate. The default is to compute the index of the minimum over all the dimensions of the input array. - out: output array, optional. - keepdims: if True, the result is broadcast to the input array with the same number of dimensions. - - Returns: - index of the minimum value in the array. + Compute tensor dot product along specified axes. + + Given two tensors, `a` and `b`, and an array_like object containing + two array_like objects, ``(a_axes, b_axes)``, sum the products of + `a`'s and `b`'s elements (components) over the axes specified by + ``a_axes`` and ``b_axes``. The third argument can be a single non-negative + integer_like scalar, ``N``; if it is such, then the last ``N`` dimensions + of `a` and the first ``N`` dimensions of `b` are summed over. + + Parameters + ---------- + a, b : array_like, Quantity + Tensors to "dot". + + axes : int or (2,) array_like + * integer_like + If an int N, sum over the last N axes of `a` and the first N axes + of `b` in order. The sizes of the corresponding axes must match. + * (2,) array_like + Or, a list of axes to be summed over, first sequence applying to `a`, + second to `b`. Both elements array_like must be of the same length. + precision : Optional. Either ``None``, which means the default precision for + the backend, a :class:`~jax.lax.Precision` enum value + (``Precision.DEFAULT``, ``Precision.HIGH`` or ``Precision.HIGHEST``), a + string (e.g. 'highest' or 'fastest', see the + ``jax.default_matmul_precision`` context manager), or a tuple of two + :class:`~jax.lax.Precision` enums or strings indicating precision of + ``lhs`` and ``rhs``. + preferred_element_type : Optional. Either ``None``, which means the default + accumulation type for the input types, or a datatype, indicating to + accumulate results to and return a result with that datatype. + + Returns + ------- + output : ndarray, Quantity + The tensor dot product of the input. + + This is a quantity if the product of the units of `a` and `b` is not dimensionless. """ - return funcs_change_unit_binary(jnp.tensordot, lambda x, y: x * y, a, b, axes=axes) + return funcs_change_unit_binary(jnp.tensordot, + lambda x, y: x * y, + a, b, + axes=axes, + precision=precision, + preferred_element_type=preferred_element_type) @set_module_as('brainunit.math') def nanargmax( a: Union[jax.typing.ArrayLike, Quantity], - axis: int = None + axis: int = None, + keepdims: bool = False ) -> jax.Array: """ - Rotate an array by 90 degrees in the plane specified by axes. - - Args: - m: array like, Quantity. - k: number of times the array is rotated by 90 degrees. - axes: plane of rotation. Default is the last two axes. - - Returns: - rotated array. + Return the indices of the maximum values in the specified axis ignoring + NaNs. For all-NaN slices ``ValueError`` is raised. Warning: the + results cannot be trusted if a slice contains only NaNs and -Infs. + + + Parameters + ---------- + a : array_like, Quantity + Input data. + axis : int, optional + Axis along which to operate. By default flattened input is used. + keepdims : bool, optional + If this is set to True, the axes which are reduced are left + in the result as dimensions with size one. With this option, + the result will broadcast correctly against the array. + + Returns + ------- + index_array : ndarray + An array of indices or a single index value. """ - return func_array_manipulation(jnp.nanargmax, a, return_quantity=False, axis=axis) + return func_array_manipulation(jnp.nanargmax, + a, + return_quantity=False, + axis=axis, + keepdims=keepdims) @set_module_as('brainunit.math') def nanargmin( a: Union[jax.typing.ArrayLike, Quantity], - axis: int = None + axis: int = None, + keepdims: bool = False ) -> jax.Array: """ - Compute tensor dot product along specified axes for arrays. + Return the indices of the minimum values in the specified axis ignoring + NaNs. For all-NaN slices ``ValueError`` is raised. Warning: the results + cannot be trusted if a slice contains only NaNs and Infs. + + Parameters + ---------- + a : array_like, Quantity + Input data. + axis : int, optional + Axis along which to operate. By default flattened input is used. + keepdims : bool, optional + If this is set to True, the axes which are reduced are left + in the result as dimensions with size one. With this option, + the result will broadcast correctly against the array. + + Returns + ------- + index_array : ndarray + An array of indices or a single index value. + """ + return func_array_manipulation(jnp.nanargmin, + a, + return_quantity=False, + axis=axis, + keepdims=keepdims) - Args: - a: array like, Quantity. - b: array like, Quantity. - axes: axes along which to compute the tensor dot product. - Returns: - tensor dot product of the two arrays. +@set_module_as('brainunit.math') +def frexp( + x: Union[Quantity, jax.typing.ArrayLike] +) -> Tuple[jax.Array, jax.Array]: + """ + Decompose the elements of x into mantissa and twos exponent. + + Returns (`mantissa`, `exponent`), where ``x = mantissa * 2**exponent``. + The mantissa lies in the open interval(-1, 1), while the twos + exponent is a signed integer. + + Parameters + ---------- + x : array_like, Quantity + Array of numbers to be decomposed. + + Returns + ------- + mantissa : ndarray + Floating values between -1 and 1. + This is a scalar if `x` is a scalar. + exponent : ndarray + Integer exponents of 2. + This is a scalar if `x` is a scalar. """ - return func_array_manipulation(jnp.nanargmin, a, return_quantity=False, axis=axis) + assert not isinstance(x, Quantity) or is_unitless(x), "Input must be unitless" + x = x.value if isinstance(x, Quantity) else x + return jnp.frexp(x) diff --git a/brainunit/math/_compat_numpy_test.py b/brainunit/math/_compat_numpy_test.py index 3c74dfb..85799ff 100644 --- a/brainunit/math/_compat_numpy_test.py +++ b/brainunit/math/_compat_numpy_test.py @@ -1659,7 +1659,7 @@ def test_compress(self): result = bu.math.compress(jnp.array([0, 1, 1, 0]), array) self.assertTrue(jnp.all(result == jnp.compress(jnp.array([0, 1, 1, 0]), array))) - q = [1, 2, 3, 4] * bu.second + q = jnp.array([1, 2, 3, 4]) a = [0, 1, 1, 0] * bu.second result_q = bu.math.compress(q, a) expected_q = jnp.compress(jnp.array([1, 2, 3, 4]), jnp.array([0, 1, 1, 0])) @@ -1785,10 +1785,10 @@ def test_extract(self): result = bu.math.extract(array > 1, array) self.assertTrue(jnp.all(result == jnp.extract(array > 1, array))) - q = [1, 2, 3] * bu.second + q = jnp.array([1, 2, 3]) a = array * bu.second - result_q = bu.math.extract(q > 1 * bu.second, a) - expected_q = jnp.extract(jnp.array([0, 1, 2]), jnp.array([1, 2, 3])) + result_q = bu.math.extract(q > 1, a) + expected_q = jnp.extract(jnp.array([1, 2, 3]) > 1, jnp.array([1, 2, 3])) * bu.second assert jnp.all(result_q == expected_q) def test_count_nonzero(self): @@ -2251,7 +2251,8 @@ def test_broadcast_arrays(self): q2 = [[4], [5]] * bu.second result_q = bu.math.broadcast_arrays(q1, q2) expected_q = jnp.broadcast_arrays(jnp.array([1, 2, 3]), jnp.array([[4], [5]])) - assert_quantity(result_q, expected_q, bu.second) + for r, e in zip(result_q, expected_q): + assert_quantity(r, e, bu.second) def test_broadcast_shapes(self): a = jnp.array([1, 2, 3]) @@ -2277,6 +2278,13 @@ def test_einsum(self): expected_q = jnp.einsum('i,i->i', jnp.array([1, 2, 3]), jnp.array([1, 2, 3])) assert_quantity(result_q, expected_q, bu.second) + q1 = [1, 2, 3] * bu.second + q2 = [1, 2, 3] * bu.volt + q3 = [1, 2, 3] * bu.ampere + result_q = bu.math.einsum('i,i,i->i', q1, q2, q3) + expected_q = jnp.einsum('i,i,i->i', jnp.array([1, 2, 3]), jnp.array([1, 2, 3]), jnp.array([1, 2, 3])) + assert_quantity(result_q, expected_q, bu.second * bu.volt * bu.ampere) + def test_gradient(self): f = jnp.array([1, 2, 4, 7, 11, 16], dtype=float) result = bu.math.gradient(f) diff --git a/docs/apis/brainunit.math.rst b/docs/apis/brainunit.math.rst index 7d3601d..60ec0a1 100644 --- a/docs/apis/brainunit.math.rst +++ b/docs/apis/brainunit.math.rst @@ -174,7 +174,6 @@ Functions Changing Unit nanvar cbrt square - frexp sqrt multiply divide @@ -390,6 +389,7 @@ More Functions nanargmin rot90 tensordot + frexp dtype e pi