From 19261936df2fd09f9bab2e9cbf1c2e35c5dbabd8 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Thu, 13 Jun 2024 16:10:41 +0800 Subject: [PATCH] Update _compat_numpy_funcs_accept_unitless.py --- .../math/_compat_numpy_array_manipulation.py | 14 +- .../_compat_numpy_funcs_accept_unitless.py | 553 ++++++++++++------ 2 files changed, 382 insertions(+), 185 deletions(-) diff --git a/brainunit/math/_compat_numpy_array_manipulation.py b/brainunit/math/_compat_numpy_array_manipulation.py index 5650635..e482bd0 100644 --- a/brainunit/math/_compat_numpy_array_manipulation.py +++ b/brainunit/math/_compat_numpy_array_manipulation.py @@ -257,7 +257,7 @@ def concatenate( def stack( arrays: Union[Sequence[Array], Sequence[Quantity]], axis: int = 0, - out: Optional[Quantity, jax.typing.ArrayLike] = None, + out: Optional[Union[Quantity, jax.typing.ArrayLike]] = None, dtype: Optional[Any] = None ) -> Union[Array, Quantity]: """ @@ -905,7 +905,7 @@ def argsort( def max( a: Union[Array, Quantity], axis: Optional[int] = None, - out: Optional[Quantity, jax.typing.ArrayLike] = None, + out: Optional[Union[Quantity, jax.typing.ArrayLike]] = None, keepdims: bool = False, initial: Optional[Union[int, float]] = None, where: Optional[Array] = None, @@ -945,7 +945,7 @@ def max( def min( a: Union[Array, Quantity], axis: Optional[int] = None, - out: Optional[Quantity, jax.typing.ArrayLike] = None, + out: Optional[Union[Quantity, jax.typing.ArrayLike]] = None, keepdims: bool = False, initial: Optional[Union[int, float]] = None, where: Optional[Array] = None, @@ -985,7 +985,7 @@ def min( def choose( a: Union[Array, Quantity], choices: Sequence[Union[Array, Quantity]], - out: Optional[Quantity, jax.typing.ArrayLike] = None, + out: Optional[Union[Quantity, jax.typing.ArrayLike]] = None, mode: str = 'raise', ) -> Union[Array, Quantity]: """ @@ -1043,7 +1043,7 @@ def compress( *, size: Optional[int] = None, fill_value: Optional[jax.typing.ArrayLike] = None, - out: Optional[Quantity, jax.typing.ArrayLike] = None, + out: Optional[Union[Quantity, jax.typing.ArrayLike]] = None, ) -> Union[Array, Quantity]: """ Return selected slices of a quantity or an array along given axis. @@ -1103,7 +1103,7 @@ def diagflat( def argmax( a: Union[Array, Quantity], axis: Optional[int] = None, - out: Optional[Quantity, jax.typing.ArrayLike] = None, + out: Optional[Union[Quantity, jax.typing.ArrayLike]] = None, keepdims: Optional[bool] = None ) -> Array: """ @@ -1133,7 +1133,7 @@ def argmax( def argmin( a: Union[Array, Quantity], axis: Optional[int] = None, - out: Optional[Quantity, jax.typing.ArrayLike] = None, + out: Optional[Union[Quantity, jax.typing.ArrayLike]] = None, keepdims: Optional[bool] = None ) -> Array: """ diff --git a/brainunit/math/_compat_numpy_funcs_accept_unitless.py b/brainunit/math/_compat_numpy_funcs_accept_unitless.py index 4bb03e3..9256e89 100644 --- a/brainunit/math/_compat_numpy_funcs_accept_unitless.py +++ b/brainunit/math/_compat_numpy_funcs_accept_unitless.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== -from typing import (Union) +from typing import (Union, Optional, Tuple) import jax import jax.numpy as jnp @@ -51,29 +51,37 @@ def funcs_only_accept_unitless_unary(func, x, *args, **kwargs): @set_module_as('brainunit.math') -def exp(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Array, Quantity]: +def exp(x: Union[Quantity, jax.typing.ArrayLike]) -> Array: """ - Calculate the exponential of all elements in the input array. + Calculate the exponential of all elements in the input quantity or array. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.exp, x) @set_module_as('brainunit.math') -def exp2(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Array, Quantity]: +def exp2(x: Union[Quantity, jax.typing.ArrayLike]) -> Array: """ - Calculate 2 raised to the power of the input elements. + Calculate 2**p for all p in the input quantity or array. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.exp2, x) @@ -83,11 +91,15 @@ def expm1(x: Union[Array, Quantity]) -> Array: """ Calculate the exponential of the input elements minus 1. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.expm1, x) @@ -97,11 +109,15 @@ def log(x: Union[Array, Quantity]) -> Array: """ Natural logarithm, element-wise. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.log, x) @@ -111,11 +127,15 @@ def log10(x: Union[Array, Quantity]) -> Array: """ Base-10 logarithm of the input elements. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.log10, x) @@ -125,11 +145,15 @@ def log1p(x: Union[Array, Quantity]) -> Array: """ Natural logarithm of 1 + the input elements. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.log1p, x) @@ -139,11 +163,15 @@ def log2(x: Union[Array, Quantity]) -> Array: """ Base-2 logarithm of the input elements. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.log2, x) @@ -153,11 +181,15 @@ def arccos(x: Union[Array, Quantity]) -> Array: """ Compute the arccosine of the input elements. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.arccos, x) @@ -167,11 +199,15 @@ def arccosh(x: Union[Array, Quantity]) -> Array: """ Compute the hyperbolic arccosine of the input elements. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.arccosh, x) @@ -181,11 +217,15 @@ def arcsin(x: Union[Array, Quantity]) -> Array: """ Compute the arcsine of the input elements. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.arcsin, x) @@ -195,11 +235,15 @@ def arcsinh(x: Union[Array, Quantity]) -> Array: """ Compute the hyperbolic arcsine of the input elements. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.arcsinh, x) @@ -209,11 +253,15 @@ def arctan(x: Union[Array, Quantity]) -> Array: """ Compute the arctangent of the input elements. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.arctan, x) @@ -223,11 +271,15 @@ def arctanh(x: Union[Array, Quantity]) -> Array: """ Compute the hyperbolic arctangent of the input elements. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.arctanh, x) @@ -237,11 +289,15 @@ def cos(x: Union[Array, Quantity]) -> Array: """ Compute the cosine of the input elements. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.cos, x) @@ -251,11 +307,15 @@ def cosh(x: Union[Array, Quantity]) -> Array: """ Compute the hyperbolic cosine of the input elements. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.cosh, x) @@ -265,11 +325,15 @@ def sin(x: Union[Array, Quantity]) -> Array: """ Compute the sine of the input elements. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.sin, x) @@ -279,11 +343,15 @@ def sinc(x: Union[Array, Quantity]) -> Array: """ Compute the sinc function of the input elements. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.sinc, x) @@ -293,11 +361,15 @@ def sinh(x: Union[Array, Quantity]) -> Array: """ Compute the hyperbolic sine of the input elements. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.sinh, x) @@ -307,11 +379,15 @@ def tan(x: Union[Array, Quantity]) -> Array: """ Compute the tangent of the input elements. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.tan, x) @@ -321,11 +397,15 @@ def tanh(x: Union[Array, Quantity]) -> Array: """ Compute the hyperbolic tangent of the input elements. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.tanh, x) @@ -335,11 +415,15 @@ def deg2rad(x: Union[Array, Quantity]) -> Array: """ Convert angles from degrees to radians. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.deg2rad, x) @@ -349,11 +433,15 @@ def rad2deg(x: Union[Array, Quantity]) -> Array: """ Convert angles from radians to degrees. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.rad2deg, x) @@ -363,11 +451,15 @@ def degrees(x: Union[Array, Quantity]) -> Array: """ Convert angles from radians to degrees. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.degrees, x) @@ -377,11 +469,15 @@ def radians(x: Union[Array, Quantity]) -> Array: """ Convert angles from degrees to radians. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.radians, x) @@ -391,11 +487,15 @@ def angle(x: Union[Array, Quantity]) -> Array: """ Return the angle of the complex argument. - Args: - x: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_unary(jnp.angle, x) @@ -430,12 +530,17 @@ def hypot(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array: """ Given the “legs” of a right triangle, return its hypotenuse. - Args: - x: array_like, Quantity - y: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. + y : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_binary(jnp.hypot, x, y) @@ -445,12 +550,17 @@ def arctan2(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array: """ Element-wise arc tangent of `x1/x2` choosing the quadrant correctly. - Args: - x1: array_like, Quantity - x2: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. + y : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_binary(jnp.arctan2, x, y) @@ -460,12 +570,17 @@ def logaddexp(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array: """ Logarithm of the sum of exponentiations of the inputs. - Args: - x1: array_like, Quantity - x2: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. + y : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_binary(jnp.logaddexp, x, y) @@ -475,67 +590,149 @@ def logaddexp2(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array: """ Logarithm of the sum of exponentiations of the inputs in base-2. - Args: - x1: array_like, Quantity - x2: array_like, Quantity + Parameters + ---------- + x : array_like, Quantity + Input array or Quantity. + y : array_like, Quantity + Input array or Quantity. - Returns: - jax.Array: an array + Returns + ------- + out : jax.Array + Output array. """ return funcs_only_accept_unitless_binary(jnp.logaddexp2, x, y) @set_module_as('brainunit.math') -def percentile(a: Union[Array, Quantity], q: Union[Array, Quantity], *args, **kwargs) -> Array: - """ - Compute the nth percentile of the input array along the specified axis. - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array - """ - return funcs_only_accept_unitless_binary(jnp.percentile, a, q, *args, **kwargs) - - -@set_module_as('brainunit.math') -def nanpercentile(a: Union[Array, Quantity], q: Union[Array, Quantity], *args, **kwargs) -> Array: - """ - Compute the nth percentile of the input array along the specified axis, ignoring NaNs. - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array - """ - return funcs_only_accept_unitless_binary(jnp.nanpercentile, a, q, *args, **kwargs) - - -@set_module_as('brainunit.math') -def quantile(a: Union[Array, Quantity], q: Union[Array, Quantity], *args, **kwargs) -> Array: - """ - Compute the qth quantile of the input array along the specified axis. - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array - """ - return funcs_only_accept_unitless_binary(jnp.quantile, a, q, *args, **kwargs) - - -@set_module_as('brainunit.math') -def nanquantile(a: Union[Array, Quantity], q: Union[Array, Quantity], *args, **kwargs) -> Array: - """ - Compute the qth quantile of the input array along the specified axis, ignoring NaNs. - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array - """ - return funcs_only_accept_unitless_binary(jnp.nanquantile, a, q, *args, **kwargs) +def percentile( + a: Union[Array, Quantity], + q: Union[Array, Quantity], + axis: Optional[Union[int, Tuple[int]]] = None, + out: Optional[Union[Quantity, jax.typing.ArrayLike]] = None, + overwrite_input: Optional[bool] = None, + method: str = 'linear', + keepdims: Optional[bool] = False, + interpolation: None = None, +) -> Array: + """ + Compute the q-th percentile of the data along the specified axis. + + Returns the q-th percentile(s) of the array elements. + + Parameters + ---------- + a : array_like, Quantity + Input array or Quantity. + q : array_like, Quantity + Percentile or sequence of percentiles to compute, which must be between 0 and 100 inclusive. + out : array_like, Quantity, optional + Alternative output array in which to place the result. + It must have the same shape and buffer length as the expected output but the type will be cast if necessary. + overwrite_input : bool, optional + If True, then allow the input array a to be modified by intermediate calculations, to save memory. + method : str, optional + This parameter specifies the method to use for estimating the + percentile. There are many different methods, some unique to NumPy. + See the notes for explanation. The options sorted by their R type + as summarized in the H&F paper [1]_ are: + + 1. 'inverted_cdf' + 2. 'averaged_inverted_cdf' + 3. 'closest_observation' + 4. 'interpolated_inverted_cdf' + 5. 'hazen' + 6. 'weibull' + 7. 'linear' (default) + 8. 'median_unbiased' + 9. 'normal_unbiased' + + The first three methods are discontinuous. NumPy further defines the + following discontinuous variations of the default 'linear' (7.) option: + + * 'lower' + * 'higher', + * 'midpoint' + * 'nearest' + keepdims : bool, optional + If this is set to True, the axes which are reduced are left in the result as dimensions with size one. + interpolation : str, optional + Deprecated name for the method keyword argument. + + Returns + ------- + out : jax.Array + Output array. + """ + return funcs_only_accept_unitless_binary(jnp.percentile, a, q, axis=axis, out=out, overwrite_input=overwrite_input, + method=method, keepdims=keepdims, interpolation=interpolation) + + +@set_module_as('brainunit.math') +def nanpercentile( + a: Union[Array, Quantity], + q: Union[Array, Quantity], + axis: Optional[Union[int, Tuple[int]]] = None, + out: Optional[Union[Quantity, jax.typing.ArrayLike]] = None, + overwrite_input: Optional[bool] = None, + method: str = 'linear', + keepdims: Optional[bool] = False, + interpolation: None = None, +) -> Array: + """ + Compute the q-th percentile of the data along the specified axis, while ignoring nan values. + + Returns the q-th percentile(s) of the array elements, while ignoring nan values. + + Parameters + ---------- + a : array_like, Quantity + Input array or Quantity. + q : array_like, Quantity + Percentile or sequence of percentiles to compute, which must be between 0 and 100 inclusive. + out : array_like, Quantity, optional + Alternative output array in which to place the result. + It must have the same shape and buffer length as the expected output but the type will be cast if necessary. + overwrite_input : bool, optional + If True, then allow the input array a to be modified by intermediate calculations, to save memory. + method : str, optional + This parameter specifies the method to use for estimating the + percentile. There are many different methods, some unique to NumPy. + See the notes for explanation. The options sorted by their R type + as summarized in the H&F paper [1]_ are: + + 1. 'inverted_cdf' + 2. 'averaged_inverted_cdf' + 3. 'closest_observation' + 4. 'interpolated_inverted_cdf' + 5. 'hazen' + 6. 'weibull' + 7. 'linear' (default) + 8. 'median_unbiased' + 9. 'normal_unbiased' + + The first three methods are discontinuous. NumPy further defines the + following discontinuous variations of the default 'linear' (7.) option: + + * 'lower' + * 'higher', + * 'midpoint' + * 'nearest' + keepdims : bool, optional + If this is set to True, the axes which are reduced are left in the result as dimensions with size one. + interpolation : str, optional + Deprecated name for the method keyword argument. + + Returns + ------- + out : jax.Array + Output array. + """ + return funcs_only_accept_unitless_binary(jnp.nanpercentile, a, q, axis=axis, out=out, overwrite_input=overwrite_input, + method=method, keepdims=keepdims, interpolation=interpolation) + + +quantile = percentile + +nanquantile = nanpercentile