From 72ccd90a20537051b2ba8f292951650d41c56b21 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Mon, 10 Jun 2024 23:39:46 +0800 Subject: [PATCH] Update _compat_numpy.py --- brainunit/math/_compat_numpy.py | 1756 ++++++++++++++++++++++++++----- 1 file changed, 1500 insertions(+), 256 deletions(-) diff --git a/brainunit/math/_compat_numpy.py b/brainunit/math/_compat_numpy.py index 09dda8f..0dae459 100644 --- a/brainunit/math/_compat_numpy.py +++ b/brainunit/math/_compat_numpy.py @@ -15,7 +15,7 @@ import functools from collections.abc import Sequence from functools import wraps -from typing import (Callable, Union, Optional, Any) +from typing import (Callable, Union, Optional, Any, List) import brainstate as bst import jax @@ -23,9 +23,11 @@ import numpy as np import opt_einsum from brainstate._utils import set_module_as +from jax import Array from jax._src.numpy.lax_numpy import _einsum from ._utils import _compatible_with_quantity +from .. import Quantity from .._base import (DIMENSIONLESS, Quantity, Unit, @@ -165,8 +167,8 @@ def f(*args, unit: Unit = None, **kwargs): # docs for full, eye, identity, tri, empty, ones, zeros full.__doc__ = """ -Returns a Quantity of `shape` and `unit`, filled with `fill_value` if `unit` is provided. -else return an array of `shape` filled with `fill_value`. + Returns a Quantity of `shape` and `unit`, filled with `fill_value` if `unit` is provided. + else return an array of `shape` filled with `fill_value`. Args: shape: sequence of integers, describing the shape of the output array. @@ -183,8 +185,8 @@ def f(*args, unit: Unit = None, **kwargs): """ eye.__doc__ = """ -Returns a Quantity of `shape` and `unit`, representing an identity matrix if `unit` is provided. -else return an identity matrix of `shape`. + Returns a Quantity of `shape` and `unit`, representing an identity matrix if `unit` is provided. + else return an identity matrix of `shape`. Args: n: the number of rows (and columns) in the output array. @@ -203,8 +205,8 @@ def f(*args, unit: Unit = None, **kwargs): """ identity.__doc__ = """ -Returns a Quantity of `shape` and `unit`, representing an identity matrix if `unit` is provided. -else return an identity matrix of `shape`. + Returns a Quantity of `shape` and `unit`, representing an identity matrix if `unit` is provided. + else return an identity matrix of `shape`. Args: n: the number of rows (and columns) in the output array. @@ -220,8 +222,8 @@ def f(*args, unit: Unit = None, **kwargs): """ tri.__doc__ = """ -Returns a Quantity of `shape` and `unit`, representing a triangular matrix if `unit` is provided. -else return a triangular matrix of `shape`. + Returns a Quantity of `shape` and `unit`, representing a triangular matrix if `unit` is provided. + else return a triangular matrix of `shape`. Args: n: the number of rows in the output array. @@ -242,8 +244,8 @@ def f(*args, unit: Unit = None, **kwargs): # empty empty.__doc__ = """ -Returns a Quantity of `shape` and `unit`, with uninitialized values if `unit` is provided. -else return an array of `shape` with uninitialized values. + Returns a Quantity of `shape` and `unit`, with uninitialized values if `unit` is provided. + else return an array of `shape` with uninitialized values. Args: shape: sequence of integers, describing the shape of the output array. @@ -260,8 +262,8 @@ def f(*args, unit: Unit = None, **kwargs): # ones ones.__doc__ = """ -Returns a Quantity of `shape` and `unit`, filled with 1 if `unit` is provided. -else return an array of `shape` filled with 1. + Returns a Quantity of `shape` and `unit`, filled with 1 if `unit` is provided. + else return an array of `shape` filled with 1. Args: shape: sequence of integers, describing the shape of the output array. @@ -278,8 +280,8 @@ def f(*args, unit: Unit = None, **kwargs): # zeros zeros.__doc__ = """ -Returns a Quantity of `shape` and `unit`, filled with 0 if `unit` is provided. -else return an array of `shape` filled with 0. + Returns a Quantity of `shape` and `unit`, filled with 0 if `unit` is provided. + else return an array of `shape` filled with 0. Args: shape: sequence of integers, describing the shape of the output array. @@ -296,8 +298,8 @@ def f(*args, unit: Unit = None, **kwargs): @set_module_as('brainunit.math') -def full_like(a: Union[Quantity, jax.Array, np.ndarray], - fill_value: Union[jax.Array, np.ndarray], +def full_like(a: Union[Quantity, bst.typing.ArrayLike], + fill_value: Union[bst.typing.ArrayLike], unit: Unit = None, dtype: Optional[bst.typing.DTypeLike] = None, shape: Any = None) -> Union[Quantity, jax.Array]: @@ -326,7 +328,7 @@ def full_like(a: Union[Quantity, jax.Array, np.ndarray], @set_module_as('brainunit.math') -def diag(a: Union[Quantity, jax.Array, np.ndarray], +def diag(a: Union[Quantity, bst.typing.ArrayLike], k: int = 0, unit: Unit = None) -> Union[Quantity, jax.Array]: ''' @@ -351,7 +353,7 @@ def diag(a: Union[Quantity, jax.Array, np.ndarray], @set_module_as('brainunit.math') -def tril(a: Union[Quantity, jax.Array, np.ndarray], +def tril(a: Union[Quantity, bst.typing.ArrayLike], k: int = 0, unit: Unit = None) -> Union[Quantity, jax.Array]: ''' @@ -376,7 +378,7 @@ def tril(a: Union[Quantity, jax.Array, np.ndarray], @set_module_as('brainunit.math') -def triu(a: Union[Quantity, jax.Array, np.ndarray], +def triu(a: Union[Quantity, bst.typing.ArrayLike], k: int = 0, unit: Unit = None) -> Union[Quantity, jax.Array]: ''' @@ -401,7 +403,7 @@ def triu(a: Union[Quantity, jax.Array, np.ndarray], @set_module_as('brainunit.math') -def empty_like(a: Union[Quantity, jax.Array, np.ndarray], +def empty_like(a: Union[Quantity, bst.typing.ArrayLike], dtype: Optional[bst.typing.DTypeLike] = None, shape: Any = None, unit: Unit = None) -> Union[Quantity, jax.Array]: @@ -429,7 +431,7 @@ def empty_like(a: Union[Quantity, jax.Array, np.ndarray], @set_module_as('brainunit.math') -def ones_like(a: Union[Quantity, jax.Array, np.ndarray], +def ones_like(a: Union[Quantity, bst.typing.ArrayLike], dtype: Optional[bst.typing.DTypeLike] = None, shape: Any = None, unit: Unit = None) -> Union[Quantity, jax.Array]: @@ -457,7 +459,7 @@ def ones_like(a: Union[Quantity, jax.Array, np.ndarray], @set_module_as('brainunit.math') -def zeros_like(a: Union[Quantity, jax.Array, np.ndarray], +def zeros_like(a: Union[Quantity, bst.typing.ArrayLike], dtype: Optional[bst.typing.DTypeLike] = None, shape: Any = None, unit: Unit = None) -> Union[Quantity, jax.Array]: @@ -486,7 +488,7 @@ def zeros_like(a: Union[Quantity, jax.Array, np.ndarray], @set_module_as('brainunit.math') def asarray( - a: Union[Quantity, jax.Array, np.ndarray, Sequence[Quantity]], + a: Union[Quantity, bst.typing.ArrayLike, Sequence[Quantity]], dtype: Optional[bst.typing.DTypeLike] = None, order: Optional[str] = None, unit: Optional[Unit] = None, @@ -606,12 +608,12 @@ def arange(*args, **kwargs): @set_module_as('brainunit.math') -def linspace(start: Union[Quantity, jax.Array, np.ndarray], - stop: Union[Quantity, jax.Array, np.ndarray], +def linspace(start: Union[Quantity, bst.typing.ArrayLike], + stop: Union[Quantity, bst.typing.ArrayLike], num: int = 50, - endpoint: bool = True, - retstep: bool = False, - dtype: bst.typing.DTypeLike = None) -> Union[Quantity, jax.Array]: + endpoint: Optional[bool] = True, + retstep: Optional[bool] = False, + dtype: Optional[bst.typing.DTypeLike] = None) -> Union[Quantity, jax.Array]: ''' Return a Quantity of `linspace` and `unit`, with uninitialized values if `unit` is provided. @@ -642,12 +644,12 @@ def linspace(start: Union[Quantity, jax.Array, np.ndarray], @set_module_as('brainunit.math') -def logspace(start: Union[Quantity, jax.Array, np.ndarray], - stop: Union[Quantity, jax.Array, np.ndarray], - num: int = 50, - endpoint: bool = True, - base: float = 10.0, - dtype: bst.typing.DTypeLike = None): +def logspace(start: Union[Quantity, bst.typing.ArrayLike], + stop: Union[Quantity, bst.typing.ArrayLike], + num: Optional[int] = 50, + endpoint: Optional[bool] = True, + base: Optional[float] = 10.0, + dtype: Optional[bst.typing.DTypeLike] = None): ''' Return a Quantity of `logspace` and `unit`, with uninitialized values if `unit` is provided. @@ -678,10 +680,10 @@ def logspace(start: Union[Quantity, jax.Array, np.ndarray], @set_module_as('brainunit.math') -def fill_diagonal(a: Union[Quantity, jax.Array, np.ndarray], - val: Union[Quantity, jax.Array, np.ndarray], - wrap: bool = False, - inplace: bool = True) -> Union[Quantity, jax.Array]: +def fill_diagonal(a: Union[Quantity, bst.typing.ArrayLike], + val: Union[Quantity, bst.typing.ArrayLike], + wrap: Optional[bool] = False, + inplace: Optional[bool] = True) -> Union[Quantity, jax.Array]: ''' Fill the main diagonal of the given array of `a` with `val`. @@ -706,9 +708,9 @@ def fill_diagonal(a: Union[Quantity, jax.Array, np.ndarray], @set_module_as('brainunit.math') -def array_split(ary: Union[Quantity, jax.Array, np.ndarray], - indices_or_sections: Union[int, jax.Array, np.ndarray], - axis: int = 0) -> Union[Quantity, jax.Array]: +def array_split(ary: Union[Quantity, bst.typing.ArrayLike], + indices_or_sections: Union[int, bst.typing.ArrayLike], + axis: Optional[int] = 0) -> list[Quantity] | list[Array]: ''' Split an array into multiple sub-arrays. @@ -721,18 +723,18 @@ def array_split(ary: Union[Quantity, jax.Array, np.ndarray], out: Quantity if `ary` is a Quantity, else an array. ''' if isinstance(ary, Quantity): - return Quantity(jnp.array_split(ary.value, indices_or_sections, axis), unit=ary.unit) - elif isinstance(ary, (jax.Array, np.ndarray)): + return [Quantity(x, unit=ary.unit) for x in jnp.array_split(ary.value, indices_or_sections, axis)] + elif isinstance(ary, (bst.typing.ArrayLike)): return jnp.array_split(ary, indices_or_sections, axis) else: raise ValueError(f'Unsupported type: {type(ary)} for array_split') @set_module_as('brainunit.math') -def meshgrid(*xi: Union[Quantity, jax.Array, np.ndarray], - copy: bool = True, - sparse: bool = False, - indexing: str = 'xy'): +def meshgrid(*xi: Union[Quantity, bst.typing.ArrayLike], + copy: Optional[bool] = True, + sparse: Optional[bool] = False, + indexing: Optional[str] = 'xy'): ''' Return coordinate matrices from coordinate vectors. @@ -756,9 +758,9 @@ def meshgrid(*xi: Union[Quantity, jax.Array, np.ndarray], @set_module_as('brainunit.math') -def vander(x: Union[Quantity, jax.Array, np.ndarray], - N: bool=None, - increasing: bool=False) -> Union[Quantity, jax.Array]: +def vander(x: Union[Quantity, bst.typing.ArrayLike], + N: Optional[bool] = None, + increasing: Optional[bool] = False) -> Union[Quantity, jax.Array]: ''' Generate a Vandermonde matrix. @@ -782,7 +784,16 @@ def vander(x: Union[Quantity, jax.Array, np.ndarray], # ----------------------- @set_module_as('brainunit.math') -def ndim(a): +def ndim(a: Union[Quantity, bst.typing.ArrayLike]) -> int: + ''' + Return the number of dimensions of an array. + + Args: + a: array_like, Quantity + + Returns: + out: int + ''' if isinstance(a, Quantity): return a.ndim else: @@ -790,7 +801,16 @@ def ndim(a): @set_module_as('brainunit.math') -def isreal(a): +def isreal(a: Union[Quantity, bst.typing.ArrayLike]) -> jax.Array: + ''' + Return True if the input array is real. + + Args: + a: array_like, Quantity + + Returns: + out: boolean array + ''' if isinstance(a, Quantity): return a.isreal else: @@ -798,7 +818,16 @@ def isreal(a): @set_module_as('brainunit.math') -def isscalar(a): +def isscalar(a: Union[Quantity, bst.typing.ArrayLike]) -> bool: + ''' + Return True if the input is a scalar. + + Args: + a: array_like, Quantity + + Returns: + out: boolean array + ''' if isinstance(a, Quantity): return a.isscalar else: @@ -806,7 +835,16 @@ def isscalar(a): @set_module_as('brainunit.math') -def isfinite(a): +def isfinite(a: Union[Quantity, bst.typing.ArrayLike]) -> jax.Array: + ''' + Return each element of the array is finite or not. + + Args: + a: array_like, Quantity + + Returns: + out: boolean array + ''' if isinstance(a, Quantity): return a.isfinite else: @@ -814,7 +852,16 @@ def isfinite(a): @set_module_as('brainunit.math') -def isinf(a): +def isinf(a: Union[Quantity, bst.typing.ArrayLike]) -> jax.Array: + ''' + Return each element of the array is infinite or not. + + Args: + a: array_like, Quantity + + Returns: + out: boolean array + ''' if isinstance(a, Quantity): return a.isinf else: @@ -822,7 +869,16 @@ def isinf(a): @set_module_as('brainunit.math') -def isnan(a): +def isnan(a: Union[Quantity, bst.typing.ArrayLike]) -> jax.Array: + ''' + Return each element of the array is NaN or not. + + Args: + a: array_like, Quantity + + Returns: + out: boolean array + ''' if isinstance(a, Quantity): return a.isnan else: @@ -830,7 +886,7 @@ def isnan(a): @set_module_as('brainunit.math') -def shape(a): +def shape(a: Union[Quantity, bst.typing.ArrayLike]) -> tuple[int, ...]: """ Return the shape of an array. @@ -870,7 +926,7 @@ def shape(a): @set_module_as('brainunit.math') -def size(a, axis=None): +def size(a: Union[Quantity, bst.typing.ArrayLike], axis: int = None) -> int: """ Return the number of elements along a given axis. @@ -963,276 +1019,1042 @@ def f(x, *args, **kwargs): diff = wrap_math_funcs_keep_unit_unary(jnp.diff) modf = wrap_math_funcs_keep_unit_unary(jnp.modf) +# docs for the functions above +real.__doc__ = ''' + Return the real part of the complex argument. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' -# math funcs keep unit (binary) -# ----------------------------- +imag.__doc__ = ''' + Return the imaginary part of the complex argument. -def wrap_math_funcs_keep_unit_binary(func): - def f(x1, x2, *args, **kwargs): - if isinstance(x1, Quantity) and isinstance(x2, Quantity): - return Quantity(func(x1.value, x2.value, *args, **kwargs), unit=x1.unit) - elif isinstance(x1, (jax.Array, np.ndarray)) and isinstance(x2, (jax.Array, np.ndarray)): - return func(x1, x2, *args, **kwargs) - else: - raise ValueError(f'Unsupported type: {type(x1)} and {type(x2)} for {func.__name__}') + Args: + x: array_like, Quantity - f.__module__ = 'brainunit.math' - return f + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' +conj.__doc__ = ''' + Return the complex conjugate of the argument. -fmod = wrap_math_funcs_keep_unit_binary(jnp.fmod) -mod = wrap_math_funcs_keep_unit_binary(jnp.mod) -copysign = wrap_math_funcs_keep_unit_binary(jnp.copysign) -heaviside = wrap_math_funcs_keep_unit_binary(jnp.heaviside) -maximum = wrap_math_funcs_keep_unit_binary(jnp.maximum) -minimum = wrap_math_funcs_keep_unit_binary(jnp.minimum) -fmax = wrap_math_funcs_keep_unit_binary(jnp.fmax) -fmin = wrap_math_funcs_keep_unit_binary(jnp.fmin) -lcm = wrap_math_funcs_keep_unit_binary(jnp.lcm) -gcd = wrap_math_funcs_keep_unit_binary(jnp.gcd) + Args: + x: array_like, Quantity + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' -# math funcs keep unit (n-ary) -# ---------------------------- -@set_module_as('brainunit.math') -def interp(x, xp, fp, left=None, right=None, period=None): - unit = None - if isinstance(x, Quantity) or isinstance(xp, Quantity) or isinstance(fp, Quantity): - unit = x.unit if isinstance(x, Quantity) else xp.unit if isinstance(xp, Quantity) else fp.unit - if isinstance(x, Quantity): - x_value = x.value - else: - x_value = x - if isinstance(xp, Quantity): - xp_value = xp.value - else: - xp_value = xp - if isinstance(fp, Quantity): - fp_value = fp.value - else: - fp_value = fp - result = jnp.interp(x_value, xp_value, fp_value, left=left, right=right, period=period) - if unit is not None: - return Quantity(result, unit=unit) - else: - return result +conjugate.__doc__ = ''' + Return the complex conjugate of the argument. + Args: + x: array_like, Quantity -@set_module_as('brainunit.math') -def clip(a, a_min, a_max): - unit = None - if isinstance(a, Quantity) or isinstance(a_min, Quantity) or isinstance(a_max, Quantity): - unit = a.unit if isinstance(a, Quantity) else a_min.unit if isinstance(a_min, Quantity) else a_max.unit - if isinstance(a, Quantity): - a_value = a.value - else: - a_value = a - if isinstance(a_min, Quantity): - a_min_value = a_min.value - else: - a_min_value = a_min - if isinstance(a_max, Quantity): - a_max_value = a_max.value - else: - a_max_value = a_max - result = jnp.clip(a_value, a_min_value, a_max_value) - if unit is not None: - return Quantity(result, unit=unit) - else: - return result + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' +negative.__doc__ = ''' + Return the negative of the argument. -# math funcs match unit (binary) -# ------------------------------ + Args: + x: array_like, Quantity -def wrap_math_funcs_match_unit_binary(func): - def f(x, y, *args, **kwargs): - if isinstance(x, Quantity) and isinstance(y, Quantity): - fail_for_dimension_mismatch(x, y) - return Quantity(func(x.value, y.value, *args, **kwargs), unit=x.unit) - elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): - return func(x, y, *args, **kwargs) - elif isinstance(x, Quantity): - if x.is_unitless: - return Quantity(func(x.value, y, *args, **kwargs), unit=x.unit) - else: - raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') - elif isinstance(y, Quantity): - if y.is_unitless: - return Quantity(func(x, y.value, *args, **kwargs), unit=y.unit) - else: - raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') - else: - raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' - f.__module__ = 'brainunit.math' - return f +positive.__doc__ = ''' + Return the positive of the argument. + Args: + x: array_like, Quantity -add = wrap_math_funcs_match_unit_binary(jnp.add) -subtract = wrap_math_funcs_match_unit_binary(jnp.subtract) -nextafter = wrap_math_funcs_match_unit_binary(jnp.nextafter) + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' +abs.__doc__ = ''' + Return the absolute value of the argument. -# math funcs change unit (unary) -# ------------------------------ + Args: + x: array_like, Quantity -def wrap_math_funcs_change_unit_unary(func, change_unit_func): - def f(x, *args, **kwargs): - if isinstance(x, Quantity): - return _return_check_unitless(Quantity(func(x.value, *args, **kwargs), unit=change_unit_func(x.unit))) - elif isinstance(x, (jax.Array, np.ndarray)): - return func(x, *args, **kwargs) - else: - raise ValueError(f'Unsupported type: {type(x)} for {func.__name__}') + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' - f.__module__ = 'brainunit.math' - return f +round_.__doc__ = ''' + Round an array to the nearest integer. + Args: + x: array_like, Quantity -reciprocal = wrap_math_funcs_change_unit_unary(jnp.reciprocal, lambda x: x ** -1) + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' +around.__doc__ = ''' + Round an array to the nearest integer. -@set_module_as('brainunit.math') -def prod(x, axis=None, dtype=None, out=None, keepdims=False, initial=None): - if isinstance(x, Quantity): - return x.prod(axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial) - else: - return jnp.prod(x, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial) + Args: + x: array_like, Quantity + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' -@set_module_as('brainunit.math') -def nanprod(x, axis=None, dtype=None, out=None, keepdims=False, initial=None): - if isinstance(x, Quantity): - return x.nanprod(axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial) - else: - return jnp.nanprod(x, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial) +round.__doc__ = ''' + Round an array to the nearest integer. + Args: + x: array_like, Quantity -product = prod + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' +rint.__doc__ = ''' + Round an array to the nearest integer. -@set_module_as('brainunit.math') -def cumprod(x, axis=None, dtype=None, out=None): - if isinstance(x, Quantity): - return x.cumprod(axis=axis, dtype=dtype, out=out) - else: - return jnp.cumprod(x, axis=axis, dtype=dtype, out=out) + Args: + x: array_like, Quantity + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' -@set_module_as('brainunit.math') -def nancumprod(x, axis=None, dtype=None, out=None): - if isinstance(x, Quantity): - return x.nancumprod(axis=axis, dtype=dtype, out=out) - else: - return jnp.nancumprod(x, axis=axis, dtype=dtype, out=out) +floor.__doc__ = ''' + Return the floor of the argument. + Args: + x: array_like, Quantity -cumproduct = cumprod + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' -var = wrap_math_funcs_change_unit_unary(jnp.var, lambda x: x ** 2) -nanvar = wrap_math_funcs_change_unit_unary(jnp.nanvar, lambda x: x ** 2) -frexp = wrap_math_funcs_change_unit_unary(jnp.frexp, lambda x, y: x * 2 ** y) -sqrt = wrap_math_funcs_change_unit_unary(jnp.sqrt, lambda x: x ** 0.5) -cbrt = wrap_math_funcs_change_unit_unary(jnp.cbrt, lambda x: x ** (1 / 3)) -square = wrap_math_funcs_change_unit_unary(jnp.square, lambda x: x ** 2) +ceil.__doc__ = ''' + Return the ceiling of the argument. + Args: + x: array_like, Quantity -# math funcs change unit (binary) -# ------------------------------- + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' -def wrap_math_funcs_change_unit_binary(func, change_unit_func): - def f(x, y, *args, **kwargs): - if isinstance(x, Quantity) and isinstance(y, Quantity): - return _return_check_unitless( - Quantity(func(x.value, y.value, *args, **kwargs), unit=change_unit_func(x.unit, y.unit)) - ) - elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): - return func(x, y, *args, **kwargs) - elif isinstance(x, Quantity): - return _return_check_unitless( - Quantity(func(x.value, y, *args, **kwargs), unit=change_unit_func(x.unit, DIMENSIONLESS))) - elif isinstance(y, Quantity): - return _return_check_unitless( - Quantity(func(x, y.value, *args, **kwargs), unit=change_unit_func(DIMENSIONLESS, y.unit))) - else: - raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') +trunc.__doc__ = ''' + Return the truncated value of the argument. - f.__module__ = 'brainunit.math' - return f + Args: + x: array_like, Quantity + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' -multiply = wrap_math_funcs_change_unit_binary(jnp.multiply, lambda x, y: x * y) -divide = wrap_math_funcs_change_unit_binary(jnp.divide, lambda x, y: x / y) +fix.__doc__ = ''' + Return the nearest integer towards zero. + Args: + x: array_like, Quantity -@set_module_as('brainunit.math') -def power(x, y, *args, **kwargs): - if isinstance(x, Quantity) and isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.power(x.value, y.value, *args, **kwargs), unit=x.unit ** y.unit)) - elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): - return jnp.power(x, y, *args, **kwargs) - elif isinstance(x, Quantity): - return _return_check_unitless(Quantity(jnp.power(x.value, y, *args, **kwargs), unit=x.unit ** y)) - elif isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.power(x, y.value, *args, **kwargs), unit=x ** y.unit)) + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' + +sum.__doc__ = ''' + Return the sum of the array elements. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' + +nancumsum.__doc__ = ''' + Return the cumulative sum of the array elements, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' + +nansum.__doc__ = ''' + Return the sum of the array elements, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' + +cumsum.__doc__ = ''' + Return the cumulative sum of the array elements. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' + +ediff1d.__doc__ = ''' + Return the differences between consecutive elements of the array. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' + +absolute.__doc__ = ''' + Return the absolute value of the argument. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' + +fabs.__doc__ = ''' + Return the absolute value of the argument. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' + +median.__doc__ = ''' + Return the median of the array elements. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' + +nanmin.__doc__ = ''' + Return the minimum of the array elements, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' + +nanmax.__doc__ = ''' + Return the maximum of the array elements, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' + +ptp.__doc__ = ''' + Return the range of the array elements (maximum - minimum). + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' + +average.__doc__ = ''' + Return the weighted average of the array elements. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' + +mean.__doc__ = ''' + Return the mean of the array elements. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' + +std.__doc__ = ''' + Return the standard deviation of the array elements. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' + +nanmedian.__doc__ = ''' + Return the median of the array elements, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' + +nanmean.__doc__ = ''' + Return the mean of the array elements, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' + +nanstd.__doc__ = ''' + Return the standard deviation of the array elements, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' + +diff.__doc__ = ''' + Return the differences between consecutive elements of the array. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' + +modf.__doc__ = ''' + Return the fractional and integer parts of the array elements. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity tuple if `x` is a Quantity, else an array tuple. +''' + + +# math funcs keep unit (binary) +# ----------------------------- + +def wrap_math_funcs_keep_unit_binary(func): + def f(x1, x2, *args, **kwargs): + if isinstance(x1, Quantity) and isinstance(x2, Quantity): + return Quantity(func(x1.value, x2.value, *args, **kwargs), unit=x1.unit) + elif isinstance(x1, (jax.Array, np.ndarray)) and isinstance(x2, (jax.Array, np.ndarray)): + return func(x1, x2, *args, **kwargs) + else: + raise ValueError(f'Unsupported type: {type(x1)} and {type(x2)} for {func.__name__}') + + f.__module__ = 'brainunit.math' + return f + + +fmod = wrap_math_funcs_keep_unit_binary(jnp.fmod) +mod = wrap_math_funcs_keep_unit_binary(jnp.mod) +copysign = wrap_math_funcs_keep_unit_binary(jnp.copysign) +heaviside = wrap_math_funcs_keep_unit_binary(jnp.heaviside) +maximum = wrap_math_funcs_keep_unit_binary(jnp.maximum) +minimum = wrap_math_funcs_keep_unit_binary(jnp.minimum) +fmax = wrap_math_funcs_keep_unit_binary(jnp.fmax) +fmin = wrap_math_funcs_keep_unit_binary(jnp.fmin) +lcm = wrap_math_funcs_keep_unit_binary(jnp.lcm) +gcd = wrap_math_funcs_keep_unit_binary(jnp.gcd) + +# docs for the functions above +fmod.__doc__ = ''' + Return the element-wise remainder of division. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + out: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +mod.__doc__ = ''' + Return the element-wise modulus of division. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + out: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +copysign.__doc__ = ''' + Return a copy of the first array elements with the sign of the second array. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + out: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +heaviside.__doc__ = ''' + Compute the Heaviside step function. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + out: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +maximum.__doc__ = ''' + Element-wise maximum of array elements. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + out: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +minimum.__doc__ = ''' + Element-wise minimum of array elements. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + out: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +fmax.__doc__ = ''' + Element-wise maximum of array elements ignoring NaNs. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + out: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +fmin.__doc__ = ''' + Element-wise minimum of array elements ignoring NaNs. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + out: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +lcm.__doc__ = ''' + Return the least common multiple of `x1` and `x2`. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + out: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +gcd.__doc__ = ''' + Return the greatest common divisor of `x1` and `x2`. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + out: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + + +# math funcs keep unit (n-ary) +# ---------------------------- +@set_module_as('brainunit.math') +def interp(x: Union[Quantity, bst.typing.ArrayLike], + xp: Union[Quantity, bst.typing.ArrayLike], + fp: Union[Quantity, bst.typing.ArrayLike], + left: Union[Quantity, bst.typing.ArrayLike] = None, + right: Union[Quantity, bst.typing.ArrayLike] = None, + period: Union[Quantity, bst.typing.ArrayLike] = None) -> Union[Quantity, jax.Array]: + ''' + One-dimensional linear interpolation. + + Args: + x: array_like, Quantity + xp: array_like, Quantity + fp: array_like, Quantity + left: array_like, Quantity, optional + right: array_like, Quantity, optional + period: array_like, Quantity, optional + + Returns: + out: Quantity if `x`, `xp`, and `fp` are Quantities that have the same unit, else an array. + ''' + unit = None + if isinstance(x, Quantity) or isinstance(xp, Quantity) or isinstance(fp, Quantity): + unit = x.unit if isinstance(x, Quantity) else xp.unit if isinstance(xp, Quantity) else fp.unit + if isinstance(x, Quantity): + x_value = x.value + else: + x_value = x + if isinstance(xp, Quantity): + xp_value = xp.value + else: + xp_value = xp + if isinstance(fp, Quantity): + fp_value = fp.value + else: + fp_value = fp + result = jnp.interp(x_value, xp_value, fp_value, left=left, right=right, period=period) + if unit is not None: + return Quantity(result, unit=unit) + else: + return result + + +@set_module_as('brainunit.math') +def clip(a: Union[Quantity, bst.typing.ArrayLike], + a_min: Union[Quantity, bst.typing.ArrayLike], + a_max: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' + Clip (limit) the values in an array. + + Args: + a: array_like, Quantity + a_min: array_like, Quantity + a_max: array_like, Quantity + + Returns: + out: Quantity if `a`, `a_min`, and `a_max` are Quantities that have the same unit, else an array. + ''' + unit = None + if isinstance(a, Quantity) or isinstance(a_min, Quantity) or isinstance(a_max, Quantity): + unit = a.unit if isinstance(a, Quantity) else a_min.unit if isinstance(a_min, Quantity) else a_max.unit + if isinstance(a, Quantity): + a_value = a.value + else: + a_value = a + if isinstance(a_min, Quantity): + a_min_value = a_min.value + else: + a_min_value = a_min + if isinstance(a_max, Quantity): + a_max_value = a_max.value + else: + a_max_value = a_max + result = jnp.clip(a_value, a_min_value, a_max_value) + if unit is not None: + return Quantity(result, unit=unit) + else: + return result + + +# math funcs match unit (binary) +# ------------------------------ + +def wrap_math_funcs_match_unit_binary(func): + def f(x, y, *args, **kwargs): + if isinstance(x, Quantity) and isinstance(y, Quantity): + fail_for_dimension_mismatch(x, y) + return Quantity(func(x.value, y.value, *args, **kwargs), unit=x.unit) + elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): + return func(x, y, *args, **kwargs) + elif isinstance(x, Quantity): + if x.is_unitless: + return Quantity(func(x.value, y, *args, **kwargs), unit=x.unit) + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') + elif isinstance(y, Quantity): + if y.is_unitless: + return Quantity(func(x, y.value, *args, **kwargs), unit=y.unit) + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') + + f.__module__ = 'brainunit.math' + return f + + +add = wrap_math_funcs_match_unit_binary(jnp.add) +subtract = wrap_math_funcs_match_unit_binary(jnp.subtract) +nextafter = wrap_math_funcs_match_unit_binary(jnp.nextafter) + +# docs for the functions above +add.__doc__ = ''' + Add arguments element-wise. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + out: Quantity if `x` and `y` are Quantities that have the same unit, else an array. +''' + +subtract.__doc__ = ''' + Subtract arguments element-wise. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + out: Quantity if `x` and `y` are Quantities that have the same unit, else an array. +''' + +nextafter.__doc__ = ''' + Return the next floating-point value after `x1` towards `x2`. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + out: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + + +# math funcs change unit (unary) +# ------------------------------ + +def wrap_math_funcs_change_unit_unary(func, change_unit_func): + def f(x, *args, **kwargs): + if isinstance(x, Quantity): + return _return_check_unitless(Quantity(func(x.value, *args, **kwargs), unit=change_unit_func(x.unit))) + elif isinstance(x, (jax.Array, np.ndarray)): + return func(x, *args, **kwargs) + else: + raise ValueError(f'Unsupported type: {type(x)} for {func.__name__}') + + f.__module__ = 'brainunit.math' + return f + + +reciprocal = wrap_math_funcs_change_unit_unary(jnp.reciprocal, lambda x: x ** -1) +reciprocal.__doc__ = ''' + Return the reciprocal of the argument. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' + + +@set_module_as('brainunit.math') +def prod(x: Union[Quantity, bst.typing.ArrayLike], + axis: Optional[int] = None, + dtype: Optional[bst.typing.DTypeLike] = None, + out: Optional[...] = None, + keepdims: Optional[bool] = False, + initial: Union[Quantity, bst.typing.ArrayLike] = None, + where: Union[Quantity, bst.typing.ArrayLike] = None, + promote_integers: bool = True) -> Union[Quantity, jax.Array]: + ''' + Return the product of array elements over a given axis. + + Args: + x: array_like, Quantity + axis: int, optional + dtype: dtype, optional + out: array, optional + keepdims: bool, optional + initial: array_like, Quantity, optional + where: array_like, Quantity, optional + promote_integers: bool, optional + + Returns: + out: Quantity if `x` is a Quantity, else an array. + ''' + if isinstance(x, Quantity): + return x.prod(axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where, + promote_integers=promote_integers) else: - raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.power.__name__}') + return jnp.prod(x, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where, + promote_integers=promote_integers) + + +@set_module_as('brainunit.math') +def nanprod(x: Union[Quantity, bst.typing.ArrayLike], + axis: Optional[int] = None, + dtype: Optional[bst.typing.DTypeLike] = None, + out: Optional[...] = None, + keepdims: Optional[...] = False, + initial: Union[Quantity, bst.typing.ArrayLike] = None, + where: Union[Quantity, bst.typing.ArrayLike] = None): + ''' + Return the product of array elements over a given axis treating Not a Numbers (NaNs) as one. + + Args: + x: array_like, Quantity + axis: int, optional + dtype: dtype, optional + out: array, optional + keepdims: bool, optional + initial: array_like, Quantity, optional + where: array_like, Quantity, optional + + Returns: + out: Quantity if `x` is a Quantity, else an array. + ''' + if isinstance(x, Quantity): + return x.nanprod(axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where) + else: + return jnp.nanprod(x, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where) + + +product = prod + + +@set_module_as('brainunit.math') +def cumprod(x: Union[Quantity, bst.typing.ArrayLike], + axis: Optional[int] = None, + dtype: Optional[bst.typing.DTypeLike] = None, + out: Optional[...] = None) -> Union[Quantity, bst.typing.ArrayLike]: + ''' + Return the cumulative product of elements along a given axis. + + Args: + x: array_like, Quantity + axis: int, optional + dtype: dtype, optional + out: array, optional + + Returns: + out: Quantity if `x` is a Quantity, else an array. + ''' + if isinstance(x, Quantity): + return x.cumprod(axis=axis, dtype=dtype, out=out) + else: + return jnp.cumprod(x, axis=axis, dtype=dtype, out=out) + + +@set_module_as('brainunit.math') +def nancumprod(x: Union[Quantity, bst.typing.ArrayLike], + axis: Optional[int] = None, + dtype: Optional[bst.typing.DTypeLike] = None, + out: Optional[...] = None) -> Union[Quantity, bst.typing.ArrayLike]: + ''' + Return the cumulative product of elements along a given axis treating Not a Numbers (NaNs) as one. + + Args: + x: array_like, Quantity + axis: int, optional + dtype: dtype, optional + out: array, optional + + Returns: + out: Quantity if `x` is a Quantity, else an array. + ''' + if isinstance(x, Quantity): + return x.nancumprod(axis=axis, dtype=dtype, out=out) + else: + return jnp.nancumprod(x, axis=axis, dtype=dtype, out=out) + + +cumproduct = cumprod + +var = wrap_math_funcs_change_unit_unary(jnp.var, lambda x: x ** 2) +nanvar = wrap_math_funcs_change_unit_unary(jnp.nanvar, lambda x: x ** 2) +frexp = wrap_math_funcs_change_unit_unary(jnp.frexp, lambda x, y: x * 2 ** y) +sqrt = wrap_math_funcs_change_unit_unary(jnp.sqrt, lambda x: x ** 0.5) +cbrt = wrap_math_funcs_change_unit_unary(jnp.cbrt, lambda x: x ** (1 / 3)) +square = wrap_math_funcs_change_unit_unary(jnp.square, lambda x: x ** 2) + +# docs for the functions above +var.__doc__ = ''' + Compute the variance along the specified axis. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if the final unit is the square of the unit of `x`, else an array. +''' + +nanvar.__doc__ = ''' + Compute the variance along the specified axis, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if the final unit is the square of the unit of `x`, else an array. +''' + +frexp.__doc__ = ''' + Decompose a floating-point number into its mantissa and exponent. + + Args: + x: array_like, Quantity + + Returns: + out: Tuple of Quantity if the final unit is the product of the unit of `x` and 2 raised to the power of the exponent, else a tuple of arrays. +''' + +sqrt.__doc__ = ''' + Compute the square root of each element. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if the final unit is the square root of the unit of `x`, else an array. +''' + +cbrt.__doc__ = ''' + Compute the cube root of each element. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if the final unit is the cube root of the unit of `x`, else an array. +''' + +square.__doc__ = ''' + Compute the square of each element. + Args: + x: array_like, Quantity + + Returns: + out: Quantity if the final unit is the square of the unit of `x`, else an array. +''' + + +# math funcs change unit (binary) +# ------------------------------- + +def wrap_math_funcs_change_unit_binary(func, change_unit_func): + def f(x, y, *args, **kwargs): + if isinstance(x, Quantity) and isinstance(y, Quantity): + return _return_check_unitless( + Quantity(func(x.value, y.value, *args, **kwargs), unit=change_unit_func(x.unit, y.unit)) + ) + elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): + return func(x, y, *args, **kwargs) + elif isinstance(x, Quantity): + return _return_check_unitless( + Quantity(func(x.value, y, *args, **kwargs), unit=change_unit_func(x.unit, DIMENSIONLESS))) + elif isinstance(y, Quantity): + return _return_check_unitless( + Quantity(func(x, y.value, *args, **kwargs), unit=change_unit_func(DIMENSIONLESS, y.unit))) + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') + + f.__module__ = 'brainunit.math' + return f + +multiply = wrap_math_funcs_change_unit_binary(jnp.multiply, lambda x, y: x * y) +divide = wrap_math_funcs_change_unit_binary(jnp.divide, lambda x, y: x / y) cross = wrap_math_funcs_change_unit_binary(jnp.cross, lambda x, y: x * y) ldexp = wrap_math_funcs_change_unit_binary(jnp.ldexp, lambda x, y: x * 2 ** y) true_divide = wrap_math_funcs_change_unit_binary(jnp.true_divide, lambda x, y: x / y) +divmod = wrap_math_funcs_change_unit_binary(jnp.divmod, lambda x, y: x / y) +convolve = wrap_math_funcs_change_unit_binary(jnp.convolve, lambda x, y: x * y) + +# docs for the functions above +multiply.__doc__ = ''' + Multiply arguments element-wise. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + out: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. +''' + +divide.__doc__ = ''' + Divide arguments element-wise. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. +''' + +cross.__doc__ = ''' + Return the cross product of two (arrays of) vectors. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + out: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. +''' + +ldexp.__doc__ = ''' + Return x1 * 2**x2, element-wise. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + out: Quantity if the final unit is the product of the unit of `x` and 2 raised to the power of the unit of `y`, else an array. +''' + +true_divide.__doc__ = ''' + Returns a true division of the inputs, element-wise. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + out: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. +''' + +divmod.__doc__ = ''' + Return element-wise quotient and remainder simultaneously. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + out: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. +''' + +convolve.__doc__ = ''' + Returns the discrete, linear convolution of two one-dimensional sequences. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + out: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. +''' @set_module_as('brainunit.math') -def floor_divide(x, y, *args, **kwargs): +def power(x: Union[Quantity, bst.typing.ArrayLike], + y: Union[Quantity, bst.typing.ArrayLike], ) -> Union[Quantity, jax.Array]: + ''' + First array elements raised to powers from second array, element-wise. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + out: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. + ''' if isinstance(x, Quantity) and isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.floor_divide(x.value, y.value, *args, **kwargs), unit=x.unit / y.unit)) + return _return_check_unitless(Quantity(jnp.power(x.value, y.value), unit=x.unit ** y.unit)) elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): - return jnp.floor_divide(x, y, *args, **kwargs) + return jnp.power(x, y) elif isinstance(x, Quantity): - return _return_check_unitless(Quantity(jnp.floor_divide(x.value, y, *args, **kwargs), unit=x.unit / y)) + return _return_check_unitless(Quantity(jnp.power(x.value, y), unit=x.unit ** y)) elif isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.floor_divide(x, y.value, *args, **kwargs), unit=x / y.unit)) + return _return_check_unitless(Quantity(jnp.power(x, y.value), unit=x ** y.unit)) else: - raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.floor_divide.__name__}') + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.power.__name__}') @set_module_as('brainunit.math') -def float_power(x, y, *args, **kwargs): +def floor_divide(x: Union[Quantity, bst.typing.ArrayLike], + y: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' + Return the largest integer smaller or equal to the division of the inputs. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + out: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. + ''' if isinstance(x, Quantity) and isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.float_power(x.value, y.value, *args, **kwargs), unit=x.unit ** y.unit)) + return _return_check_unitless(Quantity(jnp.floor_divide(x.value, y.value), unit=x.unit / y.unit)) elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): - return jnp.float_power(x, y, *args, **kwargs) + return jnp.floor_divide(x, y) elif isinstance(x, Quantity): - return _return_check_unitless(Quantity(jnp.float_power(x.value, y, *args, **kwargs), unit=x.unit ** y)) + return _return_check_unitless(Quantity(jnp.floor_divide(x.value, y), unit=x.unit / y)) elif isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.float_power(x, y.value, *args, **kwargs), unit=x ** y.unit)) + return _return_check_unitless(Quantity(jnp.floor_divide(x, y.value), unit=x / y.unit)) else: - raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.float_power.__name__}') + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.floor_divide.__name__}') + + +@set_module_as('brainunit.math') +def float_power(x: Union[Quantity, bst.typing.ArrayLike], + y: bst.typing.ArrayLike) -> Union[Quantity, jax.Array]: + ''' + First array elements raised to powers from second array, element-wise. + Args: + x: array_like, Quantity + y: array_like -divmod = wrap_math_funcs_change_unit_binary(jnp.divmod, lambda x, y: x / y) + Returns: + out: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. + ''' + assert isscalar(y), f'{jnp.float_power.__name__} only supports scalar exponent' + if isinstance(x, Quantity): + return _return_check_unitless(Quantity(jnp.float_power(x.value, y), unit=x.unit ** y.unit)) + elif isinstance(x, (jax.Array, np.ndarray)): + return jnp.float_power(x, y) + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.float_power.__name__}') @set_module_as('brainunit.math') -def remainder(x, y, *args, **kwargs): +def remainder(x: Union[Quantity, bst.typing.ArrayLike], + y: Union[Quantity, bst.typing.ArrayLike]): if isinstance(x, Quantity) and isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.remainder(x.value, y.value, *args, **kwargs), unit=x.unit / y.unit)) + return _return_check_unitless(Quantity(jnp.remainder(x.value, y.value), unit=x.unit / y.unit)) elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): - return jnp.remainder(x, y, *args, **kwargs) + return jnp.remainder(x, y) elif isinstance(x, Quantity): - return _return_check_unitless(Quantity(jnp.remainder(x.value, y, *args, **kwargs), unit=x.unit % y)) + return _return_check_unitless(Quantity(jnp.remainder(x.value, y), unit=x.unit % y)) elif isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.remainder(x, y.value, *args, **kwargs), unit=x % y.unit)) + return _return_check_unitless(Quantity(jnp.remainder(x, y.value), unit=x % y.unit)) else: raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.remainder.__name__}') -convolve = wrap_math_funcs_change_unit_binary(jnp.convolve, lambda x, y: x * y) - - # math funcs only accept unitless (unary) # --------------------------------------- @@ -1282,6 +2104,297 @@ def f(x, *args, **kwargs): quantile = wrap_math_funcs_only_accept_unitless_unary(jnp.quantile) nanquantile = wrap_math_funcs_only_accept_unitless_unary(jnp.nanquantile) +# docs for the functions above +exp.__doc__ = ''' + Calculate the exponential of all elements in the input array. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +exp2.__doc__ = ''' + Calculate 2 raised to the power of the input elements. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +expm1.__doc__ = ''' + Calculate the exponential of the input elements minus 1. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +log.__doc__ = ''' + Natural logarithm, element-wise. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +log10.__doc__ = ''' + Base-10 logarithm of the input elements. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +log1p.__doc__ = ''' + Natural logarithm of 1 + the input elements. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +log2.__doc__ = ''' + Base-2 logarithm of the input elements. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +arccos.__doc__ = ''' + Compute the arccosine of the input elements. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +arccosh.__doc__ = ''' + Compute the hyperbolic arccosine of the input elements. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +arcsin.__doc__ = ''' + Compute the arcsine of the input elements. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +arcsinh.__doc__ = ''' + Compute the hyperbolic arcsine of the input elements. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +arctan.__doc__ = ''' + Compute the arctangent of the input elements. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +arctanh.__doc__ = ''' + Compute the hyperbolic arctangent of the input elements. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +cos.__doc__ = ''' + Compute the cosine of the input elements. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +cosh.__doc__ = ''' + Compute the hyperbolic cosine of the input elements. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +sin.__doc__ = ''' + Compute the sine of the input elements. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +sinc.__doc__ = ''' + Compute the sinc function of the input elements. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +sinh.__doc__ = ''' + Compute the hyperbolic sine of the input elements. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +tan.__doc__ = ''' + Compute the tangent of the input elements. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +tanh.__doc__ = ''' + Compute the hyperbolic tangent of the input elements. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +deg2rad.__doc__ = ''' + Convert angles from degrees to radians. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +rad2deg.__doc__ = ''' + Convert angles from radians to degrees. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +degrees.__doc__ = ''' + Convert angles from radians to degrees. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +radians.__doc__ = ''' + Convert angles from degrees to radians. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +angle.__doc__ = ''' + Return the angle of the complex argument. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +percentile.__doc__ = ''' + Compute the nth percentile of the input array along the specified axis. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +nanpercentile.__doc__ = ''' + Compute the nth percentile of the input array along the specified axis, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +quantile.__doc__ = ''' + Compute the qth quantile of the input array along the specified axis. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +nanquantile.__doc__ = ''' + Compute the qth quantile of the input array along the specified axis, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + # math funcs only accept unitless (binary) # ---------------------------------------- @@ -1316,6 +2429,51 @@ def f(x, y, *args, **kwargs): logaddexp = wrap_math_funcs_only_accept_unitless_binary(jnp.logaddexp) logaddexp2 = wrap_math_funcs_only_accept_unitless_binary(jnp.logaddexp2) +# docs for the functions above +hypot.__doc__ = ''' + Given the “legs” of a right triangle, return its hypotenuse. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + out: an array +''' + +arctan2.__doc__ = ''' + Element-wise arc tangent of `x1/x2` choosing the quadrant correctly. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + out: an array +''' + +logaddexp.__doc__ = ''' + Logarithm of the sum of exponentiations of the inputs. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + out: an array +''' + +logaddexp2.__doc__ = ''' + Logarithm of the sum of exponentiations of the inputs in base-2. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + out: an array +''' + # math funcs remove unit (unary) # ------------------------------ @@ -1335,6 +2493,47 @@ def f(x, *args, **kwargs): histogram = wrap_math_funcs_remove_unit_unary(jnp.histogram) bincount = wrap_math_funcs_remove_unit_unary(jnp.bincount) +# docs for the functions above +signbit.__doc__ = ''' + Returns element-wise True where signbit is set (less than zero). + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +sign.__doc__ = ''' + Returns the sign of each element in the input array. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +histogram.__doc__ = ''' + Compute the histogram of a set of data. + + Args: + x: array_like, Quantity + + Returns: + out: Tuple of arrays (hist, bin_edges) +''' + +bincount.__doc__ = ''' + Count number of occurrences of each value in array of non-negative integers. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + # math funcs remove unit (binary) # ------------------------------- @@ -1358,6 +2557,51 @@ def f(x, y, *args, **kwargs): cov = wrap_math_funcs_remove_unit_binary(jnp.cov) digitize = wrap_math_funcs_remove_unit_binary(jnp.digitize) +# docs for the functions above +corrcoef.__doc__ = ''' + Return Pearson product-moment correlation coefficients. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + out: an array +''' + +correlate.__doc__ = ''' + Cross-correlation of two sequences. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + out: an array +''' + +cov.__doc__ = ''' + Covariance matrix. + + Args: + x: array_like, Quantity + y: array_like, Quantity (optional, if not provided, x is assumed to be a 2D array) + + Returns: + out: an array +''' + +digitize.__doc__ = ''' + Return the indices of the bins to which each value in input array belongs. + + Args: + x: array_like, Quantity + bins: array_like, Quantity + + Returns: + out: an array +''' + # array manipulation # ------------------ @@ -1751,8 +2995,8 @@ def einsum( @set_module_as('brainunit.math') def gradient( - f: Union[jax.Array, np.ndarray, Quantity], - *varargs: Union[jax.Array, np.ndarray, Quantity], + f: Union[bst.typing.ArrayLike, Quantity], + *varargs: Union[bst.typing.ArrayLike, Quantity], axis: Union[int, Sequence[int], None] = None, edge_order: Union[int, None] = None, ) -> Union[jax.Array, list[jax.Array], Quantity, list[Quantity]]: @@ -1780,8 +3024,8 @@ def gradient( @set_module_as('brainunit.math') def intersect1d( - ar1: Union[jax.Array, np.ndarray], - ar2: Union[jax.Array, np.ndarray], + ar1: Union[bst.typing.ArrayLike], + ar2: Union[bst.typing.ArrayLike], assume_unique: bool = False, return_indices: bool = False ) -> Union[jax.Array, Quantity, tuple[Union[jax.Array, Quantity], jax.Array, jax.Array]]: