diff --git a/brainunit/math/_compat_numpy_array_manipulation.py b/brainunit/math/_compat_numpy_array_manipulation.py index c4a7c26..4f9fdcf 100644 --- a/brainunit/math/_compat_numpy_array_manipulation.py +++ b/brainunit/math/_compat_numpy_array_manipulation.py @@ -14,7 +14,7 @@ # ============================================================================== from collections.abc import Sequence from functools import wraps -from typing import (Union, Optional, Tuple, List) +from typing import (Union, Optional, Tuple, List, Callable) import jax import jax.numpy as jnp @@ -41,255 +41,13 @@ # ------------------ -@_compatible_with_quantity() -def reshape(a: Union[Array, Quantity], shape: Union[int, Tuple[int, ...]], order: str = 'C') -> Union[Array, Quantity]: - return jnp.reshape(a, shape, order) - - -@_compatible_with_quantity() -def moveaxis(a: Union[Array, Quantity], source: Union[int, Tuple[int, ...]], - destination: Union[int, Tuple[int, ...]]) -> Union[Array, Quantity]: - return jnp.moveaxis(a, source, destination) - - -@_compatible_with_quantity() -def transpose(a: Union[Array, Quantity], axes: Optional[Union[int, Tuple[int, ...]]] = None) -> Union[Array, Quantity]: - return jnp.transpose(a, axes) - - -@_compatible_with_quantity() -def swapaxes(a: Union[Array, Quantity], axis1: int, axis2: int) -> Union[Array, Quantity]: - return jnp.swapaxes(a, axis1, axis2) - - -@_compatible_with_quantity() -def concatenate(arrays: Union[Sequence[Array], Sequence[Quantity]], axis: Optional[int] = None) -> Union[ - Array, Quantity]: - return jnp.concatenate(arrays, axis) - - -@_compatible_with_quantity() -def stack(arrays: Union[Sequence[Array], Sequence[Quantity]], axis: int = 0) -> Union[Array, Quantity]: - return jnp.stack(arrays, axis) - - -@_compatible_with_quantity() -def vstack(arrays: Union[Sequence[Array], Sequence[Quantity]]) -> Union[Array, Quantity]: - return jnp.vstack(arrays) - - -row_stack = vstack - - -@_compatible_with_quantity() -def hstack(arrays: Union[Sequence[Array], Sequence[Quantity]]) -> Union[Array, Quantity]: - return jnp.hstack(arrays) - - -@_compatible_with_quantity() -def dstack(arrays: Union[Sequence[Array], Sequence[Quantity]]) -> Union[Array, Quantity]: - return jnp.dstack(arrays) - - -@_compatible_with_quantity() -def column_stack(arrays: Union[Sequence[Array], Sequence[Quantity]]) -> Union[Array, Quantity]: - return jnp.column_stack(arrays) - - -@_compatible_with_quantity() -def split(a: Union[Array, Quantity], indices_or_sections: Union[int, Sequence[int]], axis: int = 0) -> Union[ - List[Array], List[Quantity]]: - return jnp.split(a, indices_or_sections, axis) - - -@_compatible_with_quantity() -def dsplit(a: Union[Array, Quantity], indices_or_sections: Union[int, Sequence[int]]) -> Union[ - List[Array], List[Quantity]]: - return jnp.dsplit(a, indices_or_sections) - - -@_compatible_with_quantity() -def hsplit(a: Union[Array, Quantity], indices_or_sections: Union[int, Sequence[int]]) -> Union[ - List[Array], List[Quantity]]: - return jnp.hsplit(a, indices_or_sections) - - -@_compatible_with_quantity() -def vsplit(a: Union[Array, Quantity], indices_or_sections: Union[int, Sequence[int]]) -> Union[ - List[Array], List[Quantity]]: - return jnp.vsplit(a, indices_or_sections) - - -@_compatible_with_quantity() -def tile(A: Union[Array, Quantity], reps: Union[int, Tuple[int, ...]]) -> Union[Array, Quantity]: - return jnp.tile(A, reps) - - -@_compatible_with_quantity() -def repeat(a: Union[Array, Quantity], repeats: Union[int, Tuple[int, ...]], axis: Optional[int] = None) -> Union[ - Array, Quantity]: - return jnp.repeat(a, repeats, axis) - - -@_compatible_with_quantity() -def unique(a: Union[Array, Quantity], return_index: bool = False, return_inverse: bool = False, - return_counts: bool = False, axis: Optional[int] = None) -> Union[Array, Quantity]: - return jnp.unique(a, return_index, return_inverse, return_counts, axis) - - -@_compatible_with_quantity() -def append(arr: Union[Array, Quantity], values: Union[Array, Quantity], axis: Optional[int] = None) -> Union[ - Array, Quantity]: - return jnp.append(arr, values, axis) - - -@_compatible_with_quantity() -def flip(m: Union[Array, Quantity], axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Union[Array, Quantity]: - return jnp.flip(m, axis) - - -@_compatible_with_quantity() -def fliplr(m: Union[Array, Quantity]) -> Union[Array, Quantity]: - return jnp.fliplr(m) - - -@_compatible_with_quantity() -def flipud(m: Union[Array, Quantity]) -> Union[Array, Quantity]: - return jnp.flipud(m) - - -@_compatible_with_quantity() -def roll(a: Union[Array, Quantity], shift: Union[int, Tuple[int, ...]], - axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Union[Array, Quantity]: - return jnp.roll(a, shift, axis) - - -@_compatible_with_quantity() -def atleast_1d(*arys: Union[Array, Quantity]) -> Union[Array, Quantity]: - return jnp.atleast_1d(*arys) - - -@_compatible_with_quantity() -def atleast_2d(*arys: Union[Array, Quantity]) -> Union[Array, Quantity]: - return jnp.atleast_2d(*arys) - - -@_compatible_with_quantity() -def atleast_3d(*arys: Union[Array, Quantity]) -> Union[Array, Quantity]: - return jnp.atleast_3d(*arys) - - -@_compatible_with_quantity() -def expand_dims(a: Union[Array, Quantity], axis: int) -> Union[Array, Quantity]: - return jnp.expand_dims(a, axis) - - -@_compatible_with_quantity() -def squeeze(a: Union[Array, Quantity], axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Union[Array, Quantity]: - return jnp.squeeze(a, axis) - - -@_compatible_with_quantity() -def sort(a: Union[Array, Quantity], - axis: Optional[int] = -1, - kind: None = None, - order: None = None, - stable: bool = True, - descending: bool = False, ) -> Union[Array, Quantity]: - return jnp.sort(a, axis, kind=kind, order=order, stable=stable, descending=descending) - - -@_compatible_with_quantity() -def argsort(a: Union[Array, Quantity], - axis: Optional[int] = -1, - kind: None = None, - order: None = None, - stable: bool = True, - descending: bool = False, ) -> Array: - return jnp.argsort(a, axis, kind=kind, order=order, stable=stable, descending=descending) - - -@_compatible_with_quantity() -def max(a: Union[Array, Quantity], axis: Optional[int] = None, out: Optional[Array] = None, - keepdims: bool = False) -> Union[Array, Quantity]: - return jnp.max(a, axis, out, keepdims) - - -@_compatible_with_quantity() -def min(a: Union[Array, Quantity], axis: Optional[int] = None, out: Optional[Array] = None, - keepdims: bool = False) -> Union[Array, Quantity]: - return jnp.min(a, axis, out, keepdims) - - -@_compatible_with_quantity() -def choose(a: Union[Array, Quantity], choices: Sequence[Union[Array, Quantity]]) -> Union[Array, Quantity]: - return jnp.choose(a, choices) - - -@_compatible_with_quantity() -def block(arrays: Sequence[Union[Array, Quantity]]) -> Union[Array, Quantity]: - return jnp.block(arrays) - - -@_compatible_with_quantity() -def compress(condition: Union[Array, Quantity], a: Union[Array, Quantity], axis: Optional[int] = None) -> Union[ - Array, Quantity]: - return jnp.compress(condition, a, axis) - - -@_compatible_with_quantity() -def diagflat(v: Union[Array, Quantity], k: int = 0) -> Union[Array, Quantity]: - return jnp.diagflat(v, k) - - -# return jax.numpy.Array, not Quantity - -@_compatible_with_quantity(return_quantity=False) -def argmax(a: Union[Array, Quantity], axis: Optional[int] = None, out: Optional[Array] = None) -> Array: - return jnp.argmax(a, axis, out) - - -@_compatible_with_quantity(return_quantity=False) -def argmin(a: Union[Array, Quantity], axis: Optional[int] = None, out: Optional[Array] = None) -> Array: - return jnp.argmin(a, axis, out) - - -@_compatible_with_quantity(return_quantity=False) -def argwhere(a: Union[Array, Quantity]) -> Array: - return jnp.argwhere(a) - - -@_compatible_with_quantity(return_quantity=False) -def nonzero(a: Union[Array, Quantity]) -> Tuple[Array, ...]: - return jnp.nonzero(a) - - -@_compatible_with_quantity(return_quantity=False) -def flatnonzero(a: Union[Array, Quantity]) -> Array: - return jnp.flatnonzero(a) - - -@_compatible_with_quantity(return_quantity=False) -def searchsorted(a: Union[Array, Quantity], v: Union[Array, Quantity], side: str = 'left', - sorter: Optional[Array] = None) -> Array: - return jnp.searchsorted(a, v, side, sorter) - - -@_compatible_with_quantity(return_quantity=False) -def extract(condition: Union[Array, Quantity], arr: Union[Array, Quantity]) -> Array: - return jnp.extract(condition, arr) - - -@_compatible_with_quantity(return_quantity=False) -def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Array: - return jnp.count_nonzero(a, axis) - - -amax = max -amin = min - -# docs for the functions above -reshape.__doc__ = ''' +@_compatible_with_quantity(jnp.reshape) +def reshape( + a: Union[Array, Quantity], + shape: Union[int, Tuple[int, ...]], + order: str = 'C' +) -> Union[Array, Quantity]: + ''' Return a reshaped copy of an array or a Quantity. Args: @@ -303,9 +61,17 @@ def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Arra Returns: reshaped copy of input array with the specified shape. -''' + ''' + ... + -moveaxis.__doc__ = ''' +@_compatible_with_quantity(jnp.moveaxis) +def moveaxis( + a: Union[Array, Quantity], + source: Union[int, Tuple[int, ...]], + destination: Union[int, Tuple[int, ...]] +) -> Union[Array, Quantity]: + ''' Moves axes of an array to new positions. Other axes remain in their original order. Args: @@ -315,9 +81,16 @@ def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Arra Returns: Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array -''' + ''' + ... -transpose.__doc__ = ''' + +@_compatible_with_quantity(jnp.transpose) +def transpose( + a: Union[Array, Quantity], + axes: Optional[Union[int, Tuple[int, ...]]] = None +) -> Union[Array, Quantity]: + ''' Returns a view of the array with axes transposed. Args: @@ -326,9 +99,15 @@ def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Arra Returns: Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array -''' + ''' + ... + -swapaxes.__doc__ = ''' +@_compatible_with_quantity(jnp.swapaxes) +def swapaxes( + a: Union[Array, Quantity], axis1: int, axis2: int +) -> Union[Array, Quantity]: + ''' Interchanges two axes of an array. Args: @@ -338,9 +117,16 @@ def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Arra Returns: Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array -''' + ''' + ... -concatenate.__doc__ = ''' + +@_compatible_with_quantity(jnp.concatenate) +def concatenate( + arrays: Union[Sequence[Array], + Sequence[Quantity]], axis: Optional[int] = None +) -> Union[Array, Quantity]: + ''' Join a sequence of arrays along an existing axis. Args: @@ -349,9 +135,16 @@ def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Arra Returns: Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array -''' + ''' + ... + -stack.__doc__ = ''' +@_compatible_with_quantity(jnp.stack) +def stack( + arrays: Union[Sequence[Array], + Sequence[Quantity]], axis: int = 0 +) -> Union[Array, Quantity]: + ''' Join a sequence of arrays along a new axis. Args: @@ -360,9 +153,16 @@ def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Arra Returns: Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array -''' + ''' + ... + -vstack.__doc__ = ''' +@_compatible_with_quantity(jnp.vstack) +def vstack( + arrays: Union[Sequence[Array], + Sequence[Quantity]] +) -> Union[Array, Quantity]: + ''' Stack arrays in sequence vertically (row wise). Args: @@ -370,9 +170,19 @@ def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Arra Returns: Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.numpy.Array -''' + ''' + ... + + +row_stack = vstack -hstack.__doc__ = ''' + +@_compatible_with_quantity(jnp.hstack) +def hstack( + arrays: Union[Sequence[Array], + Sequence[Quantity]] +) -> Union[Array, Quantity]: + ''' Stack arrays in sequence horizontally (column wise). Args: @@ -380,9 +190,16 @@ def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Arra Returns: Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array -''' + ''' + ... + -dstack.__doc__ = ''' +@_compatible_with_quantity(jnp.dstack) +def dstack( + arrays: Union[Sequence[Array], + Sequence[Quantity]] +) -> Union[Array, Quantity]: + ''' Stack arrays in sequence depth wise (along third axis). Args: @@ -390,9 +207,16 @@ def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Arra Returns: Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array -''' + ''' + ... + -column_stack.__doc__ = ''' +@_compatible_with_quantity(jnp.column_stack) +def column_stack( + arrays: Union[Sequence[Array], + Sequence[Quantity]] +) -> Union[Array, Quantity]: + ''' Stack 1-D arrays as columns into a 2-D array. Args: @@ -400,9 +224,17 @@ def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Arra Returns: Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array -''' + ''' + ... -split.__doc__ = ''' + +@_compatible_with_quantity(jnp.split) +def split( + a: Union[Array, Quantity], + indices_or_sections: Union[int, Sequence[int]], + axis: int = 0 +) -> Union[List[Array], List[Quantity]]: + ''' Split an array into multiple sub-arrays. Args: @@ -412,9 +244,16 @@ def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Arra Returns: Union[jax.Array, Quantity] a list of Quantity if a is a Quantity, otherwise a list of jax.Array -''' + ''' + ... + -dsplit.__doc__ = ''' +@_compatible_with_quantity(jnp.dsplit) +def dsplit( + a: Union[Array, Quantity], + indices_or_sections: Union[int, Sequence[int]] +) -> Union[List[Array], List[Quantity]]: + ''' Split array along third axis (depth). Args: @@ -423,9 +262,16 @@ def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Arra Returns: Union[jax.Array, Quantity] a list of Quantity if a is a Quantity, otherwise a list of jax.Array -''' + ''' + ... -hsplit.__doc__ = ''' + +@_compatible_with_quantity(jnp.hsplit) +def hsplit( + a: Union[Array, Quantity], + indices_or_sections: Union[int, Sequence[int]] +) -> Union[List[Array], List[Quantity]]: + ''' Split an array into multiple sub-arrays horizontally (column-wise). Args: @@ -434,9 +280,16 @@ def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Arra Returns: Union[jax.Array, Quantity] a list of Quantity if a is a Quantity, otherwise a list of jax.Array -''' + ''' + ... + -vsplit.__doc__ = ''' +@_compatible_with_quantity(jnp.vsplit) +def vsplit( + a: Union[Array, Quantity], + indices_or_sections: Union[int, Sequence[int]] +) -> Union[List[Array], List[Quantity]]: + ''' Split an array into multiple sub-arrays vertically (row-wise). Args: @@ -445,9 +298,16 @@ def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Arra Returns: Union[jax.Array, Quantity] a list of Quantity if a is a Quantity, otherwise a list of jax.Array -''' + ''' + ... + -tile.__doc__ = ''' +@_compatible_with_quantity(jnp.tile) +def tile( + A: Union[Array, Quantity], + reps: Union[int, Tuple[int, ...]] +) -> Union[Array, Quantity]: + ''' Construct an array by repeating A the number of times given by reps. Args: @@ -456,9 +316,17 @@ def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Arra Returns: Union[jax.Array, Quantity] a Quantity if A is a Quantity, otherwise a jax.Array -''' + ''' + ... -repeat.__doc__ = ''' + +@_compatible_with_quantity(jnp.repeat) +def repeat( + a: Union[Array, Quantity], + repeats: Union[int, Tuple[int, ...]], + axis: Optional[int] = None +) -> Union[Array, Quantity]: + ''' Repeat elements of an array. Args: @@ -468,9 +336,19 @@ def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Arra Returns: Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array -''' - -unique.__doc__ = ''' + ''' + ... + + +@_compatible_with_quantity(jnp.unique) +def unique( + a: Union[Array, Quantity], + return_index: bool = False, + return_inverse: bool = False, + return_counts: bool = False, + axis: Optional[int] = None +) -> Union[Array, Quantity]: + ''' Find the unique elements of an array. Args: @@ -482,9 +360,17 @@ def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Arra Returns: Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array -''' + ''' + ... + -append.__doc__ = ''' +@_compatible_with_quantity(jnp.append) +def append( + arr: Union[Array, Quantity], + values: Union[Array, Quantity], + axis: Optional[int] = None +) -> Union[Array, Quantity]: + ''' Append values to the end of an array. Args: @@ -494,9 +380,16 @@ def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Arra Returns: Union[jax.Array, Quantity] a Quantity if arr and values are Quantity, otherwise a jax.Array -''' + ''' + ... -flip.__doc__ = ''' + +@_compatible_with_quantity(jnp.flip) +def flip( + m: Union[Array, Quantity], + axis: Optional[Union[int, Tuple[int, ...]]] = None +) -> Union[Array, Quantity]: + ''' Reverse the order of elements in an array along the given axis. Args: @@ -505,9 +398,15 @@ def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Arra Returns: Union[jax.Array, Quantity] a Quantity if m is a Quantity, otherwise a jax.Array -''' + ''' + ... + -fliplr.__doc__ = ''' +@_compatible_with_quantity(jnp.fliplr) +def fliplr( + m: Union[Array, Quantity] +) -> Union[Array, Quantity]: + ''' Flip array in the left/right direction. Args: @@ -515,9 +414,15 @@ def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Arra Returns: Union[jax.Array, Quantity] a Quantity if m is a Quantity, otherwise a jax.Array -''' + ''' + ... + -flipud.__doc__ = ''' +@_compatible_with_quantity(jnp.flipud) +def flipud( + m: Union[Array, Quantity] +) -> Union[Array, Quantity]: + ''' Flip array in the up/down direction. Args: @@ -525,9 +430,17 @@ def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Arra Returns: Union[jax.Array, Quantity] a Quantity if m is a Quantity, otherwise a jax.Array -''' + ''' + ... -roll.__doc__ = ''' + +@_compatible_with_quantity(jnp.roll) +def roll( + a: Union[Array, Quantity], + shift: Union[int, Tuple[int, ...]], + axis: Optional[Union[int, Tuple[int, ...]]] = None +) -> Union[Array, Quantity]: + ''' Roll array elements along a given axis. Args: @@ -537,9 +450,15 @@ def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Arra Returns: Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array -''' + ''' + ... + -atleast_1d.__doc__ = ''' +@_compatible_with_quantity(jnp.atleast_1d) +def atleast_1d( + *arys: Union[Array, Quantity] +) -> Union[Array, Quantity]: + ''' View inputs as arrays with at least one dimension. Args: @@ -547,9 +466,15 @@ def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Arra Returns: Union[jax.Array, Quantity] a Quantity if any input is a Quantity, otherwise a jax.Array -''' + ''' + ... -atleast_2d.__doc__ = ''' + +@_compatible_with_quantity(jnp.atleast_2d) +def atleast_2d( + *arys: Union[Array, Quantity] +) -> Union[Array, Quantity]: + ''' View inputs as arrays with at least two dimensions. Args: @@ -557,9 +482,15 @@ def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Arra Returns: Union[jax.Array, Quantity] a Quantity if any input is a Quantity, otherwise a jax.Array -''' + ''' + ... + -atleast_3d.__doc__ = ''' +@_compatible_with_quantity(jnp.atleast_3d) +def atleast_3d( + *arys: Union[Array, Quantity] +) -> Union[Array, Quantity]: + ''' View inputs as arrays with at least three dimensions. Args: @@ -567,9 +498,16 @@ def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Arra Returns: Union[jax.Array, Quantity] a Quantity if any input is a Quantity, otherwise a jax.Array -''' + ''' + ... + -expand_dims.__doc__ = ''' +@_compatible_with_quantity(jnp.expand_dims) +def expand_dims( + a: Union[Array, Quantity], + axis: int +) -> Union[Array, Quantity]: + ''' Expand the shape of an array. Args: @@ -578,9 +516,16 @@ def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Arra Returns: Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array -''' + ''' + ... + -squeeze.__doc__ = ''' +@_compatible_with_quantity(jnp.squeeze) +def squeeze( + a: Union[Array, Quantity], + axis: Optional[Union[int, Tuple[int, ...]]] = None +) -> Union[Array, Quantity]: + ''' Remove single-dimensional entries from the shape of an array. Args: @@ -589,9 +534,20 @@ def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Arra Returns: Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array -''' - -sort.__doc__ = ''' + ''' + ... + + +@_compatible_with_quantity(jnp.sort) +def sort( + a: Union[Array, Quantity], + axis: Optional[int] = -1, + kind: None = None, + order: None = None, + stable: bool = True, + descending: bool = False, +) -> Union[Array, Quantity]: + ''' Return a sorted copy of an array. Args: @@ -599,11 +555,49 @@ def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Arra axis: int or None, optional kind: {'quicksort', 'mergesort', 'heapsort', 'stable'}, optional order: str or list of str, optional + stable: bool, optional + descending: bool, optional + + Returns: + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array + ''' + ... + + +@_compatible_with_quantity(jnp.argsort) +def argsort( + a: Union[Array, Quantity], + axis: Optional[int] = -1, + kind: None = None, + order: None = None, + stable: bool = True, + descending: bool = False, +) -> Array: + ''' + Returns the indices that would sort an array. + + Args: + a: array_like, Quantity + axis: int or None, optional + kind: {'quicksort', 'mergesort', 'heapsort', 'stable'}, optional + order: str or list of str, optional + stable: bool, optional + descending: bool, optional Returns: Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array -''' -max.__doc__ = ''' + ''' + ... + + +@_compatible_with_quantity(jnp.max) +def max( + a: Union[Array, Quantity], + axis: Optional[int] = None, + out: Optional[Array] = None, + keepdims: bool = False +) -> Union[Array, Quantity]: + ''' Return the maximum of an array or maximum along an axis. Args: @@ -613,9 +607,18 @@ def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Arra Returns: Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array -''' - -min.__doc__ = ''' + ''' + ... + + +@_compatible_with_quantity(jnp.min) +def min( + a: Union[Array, Quantity], + axis: Optional[int] = None, + out: Optional[Array] = None, + keepdims: bool = False +) -> Union[Array, Quantity]: + ''' Return the minimum of an array or minimum along an axis. Args: @@ -625,9 +628,16 @@ def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Arra Returns: Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array -''' + ''' + ... + -choose.__doc__ = ''' +@_compatible_with_quantity(jnp.choose) +def choose( + a: Union[Array, Quantity], + choices: Sequence[Union[Array, Quantity]] +) -> Union[Array, Quantity]: + ''' Use an index array to construct a new array from a set of choices. Args: @@ -636,9 +646,15 @@ def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Arra Returns: Union[jax.Array, Quantity] a Quantity if a and choices are Quantity, otherwise a jax.Array -''' + ''' + ... + -block.__doc__ = ''' +@_compatible_with_quantity(jnp.block) +def block( + arrays: Sequence[Union[Array, Quantity]] +) -> Union[Array, Quantity]: + ''' Assemble an nd-array from nested lists of blocks. Args: @@ -646,9 +662,17 @@ def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Arra Returns: Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array -''' + ''' + ... -compress.__doc__ = ''' + +@_compatible_with_quantity(jnp.compress) +def compress( + condition: Union[Array, Quantity], + a: Union[Array, Quantity], + axis: Optional[int] = None +) -> Union[Array, Quantity]: + ''' Return selected slices of an array along given axis. Args: @@ -658,33 +682,37 @@ def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Arra Returns: Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array -''' + ''' + ... + -diagflat.__doc__ = ''' +@_compatible_with_quantity(jnp.diagflat) +def diagflat( + v: Union[Array, Quantity], + k: int = 0 +) -> Union[Array, Quantity]: + ''' Create a two-dimensional array with the flattened input as a diagonal. Args: - a: array_like, Quantity - offset: int, optional + v: array_like, Quantity + k: int, optional Returns: Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array -''' + ''' + ... -argsort.__doc__ = ''' - Returns the indices that would sort an array. - Args: - a: array_like, Quantity - axis: int or None, optional - kind: {'quicksort', 'mergesort', 'heapsort'}, optional - order: str or list of str, optional - - Returns: - jax.Array jax.numpy.Array (does not return a Quantity) -''' +# return jax.numpy.Array, not Quantity -argmax.__doc__ = ''' +@_compatible_with_quantity(jnp.argmax, return_quantity=False) +def argmax( + a: Union[Array, Quantity], + axis: Optional[int] = None, + out: Optional[Array] = None +) -> Array: + ''' Returns indices of the max value along an axis. Args: @@ -694,9 +722,17 @@ def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Arra Returns: jax.Array: an array (does not return a Quantity) -''' + ''' + ... -argmin.__doc__ = ''' + +@_compatible_with_quantity(jnp.argmin, return_quantity=False) +def argmin( + a: Union[Array, Quantity], + axis: Optional[int] = None, + out: Optional[Array] = None +) -> Array: + ''' Returns indices of the min value along an axis. Args: @@ -706,9 +742,15 @@ def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Arra Returns: jax.Array: an array (does not return a Quantity) -''' + ''' + ... + -argwhere.__doc__ = ''' +@_compatible_with_quantity(jnp.argwhere, return_quantity=False) +def argwhere( + a: Union[Array, Quantity] +) -> Array: + ''' Find indices of non-zero elements. Args: @@ -716,9 +758,15 @@ def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Arra Returns: jax.Array: an array (does not return a Quantity) -''' + ''' + ... + -nonzero.__doc__ = ''' +@_compatible_with_quantity(jnp.nonzero, return_quantity=False) +def nonzero( + a: Union[Array, Quantity] +) -> Tuple[Array, ...]: + ''' Return the indices of the elements that are non-zero. Args: @@ -726,9 +774,15 @@ def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Arra Returns: jax.Array: an array (does not return a Quantity) -''' + ''' + ... -flatnonzero.__doc__ = ''' + +@_compatible_with_quantity(jnp.flatnonzero, return_quantity=False) +def flatnonzero( + a: Union[Array, Quantity] +) -> Array: + ''' Return indices that are non-zero in the flattened version of a. Args: @@ -736,9 +790,17 @@ def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Arra Returns: jax.Array: an array (does not return a Quantity) -''' + ''' + ... + -searchsorted.__doc__ = ''' +@_compatible_with_quantity(jnp.searchsorted, return_quantity=False) +def searchsorted( + a: Union[Array, Quantity], v: Union[Array, Quantity], + side: str = 'left', + sorter: Optional[Array] = None +) -> Array: + ''' Find indices where elements should be inserted to maintain order. Args: @@ -748,20 +810,33 @@ def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Arra Returns: jax.Array: an array (does not return a Quantity) -''' + ''' + ... + -extract.__doc__ = ''' +@_compatible_with_quantity(jnp.extract, return_quantity=False) +def extract( + condition: Union[Array, Quantity], + arr: Union[Array, Quantity] +) -> Array: + ''' Return the elements of an array that satisfy some condition. Args: condition: array_like, Quantity - a: array_like, Quantity + arr: array_like, Quantity Returns: jax.Array: an array (does not return a Quantity) -''' + ''' + ... + -count_nonzero.__doc__ = ''' +@_compatible_with_quantity(jnp.count_nonzero, return_quantity=False) +def count_nonzero( + a: Union[Array, Quantity], axis: Optional[int] = None +) -> Array: + ''' Counts the number of non-zero values in the array a. Args: @@ -770,33 +845,33 @@ def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Arra Returns: jax.Array: an array (does not return a Quantity) -''' + ''' + ... -def wrap_function_to_method(func): - @wraps(func) - def f(x, *args, **kwargs): - if isinstance(x, Quantity): - return Quantity(func(x.value, *args, **kwargs), dim=x.dim) - else: - return func(x, *args, **kwargs) - - f.__module__ = 'brainunit.math' - return f +amax = max +amin = min -@wrap_function_to_method -def diagonal(a: Union[jax.Array, Quantity], offset: int = 0, axis1: int = 0, axis2: int = 1) -> Union[ - jax.Array, Quantity]: - return jnp.diagonal(a, offset, axis1, axis2) +def wrap_function_to_method(func: Callable): + @wraps(func) + def decorator(*args, **kwargs) -> Callable: + def f(x, *args, **kwargs): + if isinstance(x, Quantity): + return Quantity(func(x.value, *args, **kwargs), dim=x.dim) + else: + return func(x, *args, **kwargs) + f.__module__ = 'brainunit.math' + return f -@wrap_function_to_method -def ravel(a: Union[jax.Array, Quantity], order: str = 'C') -> Union[jax.Array, Quantity]: - return jnp.ravel(a, order) + return decorator -diagonal.__doc__ = ''' +@wrap_function_to_method(jnp.diagonal) +def diagonal(a: Union[jax.Array, Quantity], offset: int = 0, axis1: int = 0, axis2: int = 1) -> Union[ + jax.Array, Quantity]: + ''' Return specified diagonals. Args: @@ -807,9 +882,13 @@ def ravel(a: Union[jax.Array, Quantity], order: str = 'C') -> Union[jax.Array, Q Returns: Union[jax.Array, Quantity]: a Quantity if a is a Quantity, otherwise a jax.numpy.Array -''' + ''' + ... -ravel.__doc__ = ''' + +@wrap_function_to_method(jnp.ravel) +def ravel(a: Union[jax.Array, Quantity], order: str = 'C') -> Union[jax.Array, Quantity]: + ''' Return a contiguous flattened array. Args: @@ -818,4 +897,5 @@ def ravel(a: Union[jax.Array, Quantity], order: str = 'C') -> Union[jax.Array, Q Returns: Union[jax.Array, Quantity]: a Quantity if a is a Quantity, otherwise a jax.numpy.Array -''' + ''' + ... diff --git a/brainunit/math/_compat_numpy_funcs_accept_unitless.py b/brainunit/math/_compat_numpy_funcs_accept_unitless.py index c87890a..316ccbd 100644 --- a/brainunit/math/_compat_numpy_funcs_accept_unitless.py +++ b/brainunit/math/_compat_numpy_funcs_accept_unitless.py @@ -41,148 +41,27 @@ def wrap_math_funcs_only_accept_unitless_unary(func): @wraps(func) - def f(x, *args, **kwargs): - if isinstance(x, Quantity): - fail_for_dimension_mismatch( - x, - error_message="%s expects a dimensionless argument but got {value}" % func.__name__, - value=x, - ) - return func(jnp.array(x.value), *args, **kwargs) - else: - return func(x, *args, **kwargs) - - f.__module__ = 'brainunit.math' - return f - - -@wrap_math_funcs_only_accept_unitless_unary -def exp(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Array, Quantity]: - return jnp.exp(x) - - -@wrap_math_funcs_only_accept_unitless_unary -def exp2(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Array, Quantity]: - return jnp.exp2(x) - - -@wrap_math_funcs_only_accept_unitless_unary -def expm1(x: Union[Array, Quantity]) -> Array: - return jnp.expm1(x) - - -@wrap_math_funcs_only_accept_unitless_unary -def log(x: Union[Array, Quantity]) -> Array: - return jnp.log(x) - - -@wrap_math_funcs_only_accept_unitless_unary -def log10(x: Union[Array, Quantity]) -> Array: - return jnp.log10(x) - - -@wrap_math_funcs_only_accept_unitless_unary -def log1p(x: Union[Array, Quantity]) -> Array: - return jnp.log1p(x) - - -@wrap_math_funcs_only_accept_unitless_unary -def log2(x: Union[Array, Quantity]) -> Array: - return jnp.log2(x) - + def decorator(*args, **kwargs): + def f(x, *args, **kwargs): + if isinstance(x, Quantity): + fail_for_dimension_mismatch( + x, + error_message="%s expects a dimensionless argument but got {value}" % func.__name__, + value=x, + ) + return func(jnp.array(x.value), *args, **kwargs) + else: + return func(x, *args, **kwargs) -@wrap_math_funcs_only_accept_unitless_unary -def arccos(x: Union[Array, Quantity]) -> Array: - return jnp.arccos(x) - - -@wrap_math_funcs_only_accept_unitless_unary -def arccosh(x: Union[Array, Quantity]) -> Array: - return jnp.arccosh(x) - - -@wrap_math_funcs_only_accept_unitless_unary -def arcsin(x: Union[Array, Quantity]) -> Array: - return jnp.arcsin(x) - - -@wrap_math_funcs_only_accept_unitless_unary -def arcsinh(x: Union[Array, Quantity]) -> Array: - return jnp.arcsinh(x) + f.__module__ = 'brainunit.math' + return f + return decorator -@wrap_math_funcs_only_accept_unitless_unary -def arctan(x: Union[Array, Quantity]) -> Array: - return jnp.arctan(x) - - -@wrap_math_funcs_only_accept_unitless_unary -def arctanh(x: Union[Array, Quantity]) -> Array: - return jnp.arctanh(x) - - -@wrap_math_funcs_only_accept_unitless_unary -def cos(x: Union[Array, Quantity]) -> Array: - return jnp.cos(x) - - -@wrap_math_funcs_only_accept_unitless_unary -def cosh(x: Union[Array, Quantity]) -> Array: - return jnp.cosh(x) - -@wrap_math_funcs_only_accept_unitless_unary -def sin(x: Union[Array, Quantity]) -> Array: - return jnp.sin(x) - - -@wrap_math_funcs_only_accept_unitless_unary -def sinc(x: Union[Array, Quantity]) -> Array: - return jnp.sinc(x) - - -@wrap_math_funcs_only_accept_unitless_unary -def sinh(x: Union[Array, Quantity]) -> Array: - return jnp.sinh(x) - - -@wrap_math_funcs_only_accept_unitless_unary -def tan(x: Union[Array, Quantity]) -> Array: - return jnp.tan(x) - - -@wrap_math_funcs_only_accept_unitless_unary -def tanh(x: Union[Array, Quantity]) -> Array: - return jnp.tanh(x) - - -@wrap_math_funcs_only_accept_unitless_unary -def deg2rad(x: Union[Array, Quantity]) -> Array: - return jnp.deg2rad(x) - - -@wrap_math_funcs_only_accept_unitless_unary -def rad2deg(x: Union[Array, Quantity]) -> Array: - return jnp.rad2deg(x) - - -@wrap_math_funcs_only_accept_unitless_unary -def degrees(x: Union[Array, Quantity]) -> Array: - return jnp.degrees(x) - - -@wrap_math_funcs_only_accept_unitless_unary -def radians(x: Union[Array, Quantity]) -> Array: - return jnp.radians(x) - - -@wrap_math_funcs_only_accept_unitless_unary -def angle(x: Union[Array, Quantity]) -> Array: - return jnp.angle(x) - - -# docs for the functions above -exp.__doc__ = ''' +@wrap_math_funcs_only_accept_unitless_unary(jnp.exp) +def exp(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Array, Quantity]: + ''' Calculate the exponential of all elements in the input array. Args: @@ -190,9 +69,13 @@ def angle(x: Union[Array, Quantity]) -> Array: Returns: jax.Array: an array -''' + ''' + ... + -exp2.__doc__ = ''' +@wrap_math_funcs_only_accept_unitless_unary(jnp.exp2) +def exp2(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Array, Quantity]: + ''' Calculate 2 raised to the power of the input elements. Args: @@ -200,9 +83,13 @@ def angle(x: Union[Array, Quantity]) -> Array: Returns: jax.Array: an array -''' + ''' + ... + -expm1.__doc__ = ''' +@wrap_math_funcs_only_accept_unitless_unary(jnp.expm1) +def expm1(x: Union[Array, Quantity]) -> Array: + ''' Calculate the exponential of the input elements minus 1. Args: @@ -210,9 +97,13 @@ def angle(x: Union[Array, Quantity]) -> Array: Returns: jax.Array: an array -''' + ''' + ... + -log.__doc__ = ''' +@wrap_math_funcs_only_accept_unitless_unary(jnp.log) +def log(x: Union[Array, Quantity]) -> Array: + ''' Natural logarithm, element-wise. Args: @@ -220,9 +111,13 @@ def angle(x: Union[Array, Quantity]) -> Array: Returns: jax.Array: an array -''' + ''' + ... + -log10.__doc__ = ''' +@wrap_math_funcs_only_accept_unitless_unary(jnp.log10) +def log10(x: Union[Array, Quantity]) -> Array: + ''' Base-10 logarithm of the input elements. Args: @@ -230,9 +125,13 @@ def angle(x: Union[Array, Quantity]) -> Array: Returns: jax.Array: an array -''' + ''' + ... + -log1p.__doc__ = ''' +@wrap_math_funcs_only_accept_unitless_unary(jnp.log1p) +def log1p(x: Union[Array, Quantity]) -> Array: + ''' Natural logarithm of 1 + the input elements. Args: @@ -240,9 +139,13 @@ def angle(x: Union[Array, Quantity]) -> Array: Returns: jax.Array: an array -''' + ''' + ... -log2.__doc__ = ''' + +@wrap_math_funcs_only_accept_unitless_unary(jnp.log2) +def log2(x: Union[Array, Quantity]) -> Array: + ''' Base-2 logarithm of the input elements. Args: @@ -250,9 +153,13 @@ def angle(x: Union[Array, Quantity]) -> Array: Returns: jax.Array: an array -''' + ''' + ... -arccos.__doc__ = ''' + +@wrap_math_funcs_only_accept_unitless_unary(jnp.arccos) +def arccos(x: Union[Array, Quantity]) -> Array: + ''' Compute the arccosine of the input elements. Args: @@ -260,9 +167,13 @@ def angle(x: Union[Array, Quantity]) -> Array: Returns: jax.Array: an array -''' + ''' + ... + -arccosh.__doc__ = ''' +@wrap_math_funcs_only_accept_unitless_unary(jnp.arccosh) +def arccosh(x: Union[Array, Quantity]) -> Array: + ''' Compute the hyperbolic arccosine of the input elements. Args: @@ -270,9 +181,13 @@ def angle(x: Union[Array, Quantity]) -> Array: Returns: jax.Array: an array -''' + ''' + ... + -arcsin.__doc__ = ''' +@wrap_math_funcs_only_accept_unitless_unary(jnp.arcsin) +def arcsin(x: Union[Array, Quantity]) -> Array: + ''' Compute the arcsine of the input elements. Args: @@ -280,9 +195,13 @@ def angle(x: Union[Array, Quantity]) -> Array: Returns: jax.Array: an array -''' + ''' + ... + -arcsinh.__doc__ = ''' +@wrap_math_funcs_only_accept_unitless_unary(jnp.arcsinh) +def arcsinh(x: Union[Array, Quantity]) -> Array: + ''' Compute the hyperbolic arcsine of the input elements. Args: @@ -290,9 +209,13 @@ def angle(x: Union[Array, Quantity]) -> Array: Returns: jax.Array: an array -''' + ''' + ... -arctan.__doc__ = ''' + +@wrap_math_funcs_only_accept_unitless_unary(jnp.arctan) +def arctan(x: Union[Array, Quantity]) -> Array: + ''' Compute the arctangent of the input elements. Args: @@ -300,9 +223,13 @@ def angle(x: Union[Array, Quantity]) -> Array: Returns: jax.Array: an array -''' + ''' + ... -arctanh.__doc__ = ''' + +@wrap_math_funcs_only_accept_unitless_unary(jnp.arctanh) +def arctanh(x: Union[Array, Quantity]) -> Array: + ''' Compute the hyperbolic arctangent of the input elements. Args: @@ -310,9 +237,13 @@ def angle(x: Union[Array, Quantity]) -> Array: Returns: jax.Array: an array -''' + ''' + ... + -cos.__doc__ = ''' +@wrap_math_funcs_only_accept_unitless_unary(jnp.cos) +def cos(x: Union[Array, Quantity]) -> Array: + ''' Compute the cosine of the input elements. Args: @@ -320,9 +251,13 @@ def angle(x: Union[Array, Quantity]) -> Array: Returns: jax.Array: an array -''' + ''' + ... + -cosh.__doc__ = ''' +@wrap_math_funcs_only_accept_unitless_unary(jnp.cosh) +def cosh(x: Union[Array, Quantity]) -> Array: + ''' Compute the hyperbolic cosine of the input elements. Args: @@ -330,9 +265,13 @@ def angle(x: Union[Array, Quantity]) -> Array: Returns: jax.Array: an array -''' + ''' + ... + -sin.__doc__ = ''' +@wrap_math_funcs_only_accept_unitless_unary(jnp.sin) +def sin(x: Union[Array, Quantity]) -> Array: + ''' Compute the sine of the input elements. Args: @@ -340,9 +279,13 @@ def angle(x: Union[Array, Quantity]) -> Array: Returns: jax.Array: an array -''' + ''' + ... -sinc.__doc__ = ''' + +@wrap_math_funcs_only_accept_unitless_unary(jnp.sinc) +def sinc(x: Union[Array, Quantity]) -> Array: + ''' Compute the sinc function of the input elements. Args: @@ -350,9 +293,13 @@ def angle(x: Union[Array, Quantity]) -> Array: Returns: jax.Array: an array -''' + ''' + ... -sinh.__doc__ = ''' + +@wrap_math_funcs_only_accept_unitless_unary(jnp.sinh) +def sinh(x: Union[Array, Quantity]) -> Array: + ''' Compute the hyperbolic sine of the input elements. Args: @@ -360,9 +307,13 @@ def angle(x: Union[Array, Quantity]) -> Array: Returns: jax.Array: an array -''' + ''' + ... + -tan.__doc__ = ''' +@wrap_math_funcs_only_accept_unitless_unary(jnp.tan) +def tan(x: Union[Array, Quantity]) -> Array: + ''' Compute the tangent of the input elements. Args: @@ -370,9 +321,13 @@ def angle(x: Union[Array, Quantity]) -> Array: Returns: jax.Array: an array -''' + ''' + ... + -tanh.__doc__ = ''' +@wrap_math_funcs_only_accept_unitless_unary(jnp.tanh) +def tanh(x: Union[Array, Quantity]) -> Array: + ''' Compute the hyperbolic tangent of the input elements. Args: @@ -380,9 +335,13 @@ def angle(x: Union[Array, Quantity]) -> Array: Returns: jax.Array: an array -''' + ''' + ... + -deg2rad.__doc__ = ''' +@wrap_math_funcs_only_accept_unitless_unary(jnp.deg2rad) +def deg2rad(x: Union[Array, Quantity]) -> Array: + ''' Convert angles from degrees to radians. Args: @@ -390,9 +349,13 @@ def angle(x: Union[Array, Quantity]) -> Array: Returns: jax.Array: an array -''' + ''' + ... -rad2deg.__doc__ = ''' + +@wrap_math_funcs_only_accept_unitless_unary(jnp.rad2deg) +def rad2deg(x: Union[Array, Quantity]) -> Array: + ''' Convert angles from radians to degrees. Args: @@ -400,9 +363,13 @@ def angle(x: Union[Array, Quantity]) -> Array: Returns: jax.Array: an array -''' + ''' + ... -degrees.__doc__ = ''' + +@wrap_math_funcs_only_accept_unitless_unary(jnp.degrees) +def degrees(x: Union[Array, Quantity]) -> Array: + ''' Convert angles from radians to degrees. Args: @@ -410,9 +377,13 @@ def angle(x: Union[Array, Quantity]) -> Array: Returns: jax.Array: an array -''' + ''' + ... -radians.__doc__ = ''' + +@wrap_math_funcs_only_accept_unitless_unary(jnp.radians) +def radians(x: Union[Array, Quantity]) -> Array: + ''' Convert angles from degrees to radians. Args: @@ -420,9 +391,13 @@ def angle(x: Union[Array, Quantity]) -> Array: Returns: jax.Array: an array -''' + ''' + ... -angle.__doc__ = ''' + +@wrap_math_funcs_only_accept_unitless_unary(jnp.angle) +def angle(x: Union[Array, Quantity]) -> Array: + ''' Return the angle of the complex argument. Args: @@ -430,7 +405,8 @@ def angle(x: Union[Array, Quantity]) -> Array: Returns: jax.Array: an array -''' + ''' + ... # math funcs only accept unitless (binary) @@ -438,72 +414,36 @@ def angle(x: Union[Array, Quantity]) -> Array: def wrap_math_funcs_only_accept_unitless_binary(func): @wraps(func) - def f(x, y, *args, **kwargs): - if isinstance(x, Quantity): - x_value = x.value - if isinstance(y, Quantity): - y_value = y.value - if isinstance(x, Quantity) or isinstance(y, Quantity): - fail_for_dimension_mismatch( - x, - error_message="%s expects a dimensionless argument but got {value}" % func.__name__, - value=x, - ) - fail_for_dimension_mismatch( - y, - error_message="%s expects a dimensionless argument but got {value}" % func.__name__, - value=y, - ) - return func(jnp.array(x_value), jnp.array(y_value), *args, **kwargs) - else: - return func(x, y, *args, **kwargs) - - f.__module__ = 'brainunit.math' - return f - - -@wrap_math_funcs_only_accept_unitless_binary + def decorator(*args, **kwargs): + def f(x, y, *args, **kwargs): + if isinstance(x, Quantity): + x_value = x.value + if isinstance(y, Quantity): + y_value = y.value + if isinstance(x, Quantity) or isinstance(y, Quantity): + fail_for_dimension_mismatch( + x, + error_message="%s expects a dimensionless argument but got {value}" % func.__name__, + value=x, + ) + fail_for_dimension_mismatch( + y, + error_message="%s expects a dimensionless argument but got {value}" % func.__name__, + value=y, + ) + return func(jnp.array(x_value), jnp.array(y_value), *args, **kwargs) + else: + return func(x, y, *args, **kwargs) + + f.__module__ = 'brainunit.math' + return f + + return decorator + + +@wrap_math_funcs_only_accept_unitless_binary(jnp.hypot) def hypot(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array: - return jnp.hypot(x, y) - - -@wrap_math_funcs_only_accept_unitless_binary -def arctan2(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array: - return jnp.arctan2(x, y) - - -@wrap_math_funcs_only_accept_unitless_binary -def logaddexp(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array: - return jnp.logaddexp(x, y) - - -@wrap_math_funcs_only_accept_unitless_binary -def logaddexp2(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array: - return jnp.logaddexp2(x, y) - - -@wrap_math_funcs_only_accept_unitless_binary -def percentile(a: Union[Array, Quantity], q: Union[Array, Quantity], *args, **kwargs) -> Array: - return jnp.percentile(a, q, *args, **kwargs) - - -@wrap_math_funcs_only_accept_unitless_binary -def nanpercentile(a: Union[Array, Quantity], q: Union[Array, Quantity], *args, **kwargs) -> Array: - return jnp.nanpercentile(a, q, *args, **kwargs) - - -@wrap_math_funcs_only_accept_unitless_binary -def quantile(a: Union[Array, Quantity], q: Union[Array, Quantity], *args, **kwargs) -> Array: - return jnp.quantile(a, q, *args, **kwargs) - - -@wrap_math_funcs_only_accept_unitless_binary -def nanquantile(a: Union[Array, Quantity], q: Union[Array, Quantity], *args, **kwargs) -> Array: - return jnp.nanquantile(a, q, *args, **kwargs) - - -# docs for the functions above -hypot.__doc__ = ''' + ''' Given the “legs” of a right triangle, return its hypotenuse. Args: @@ -512,9 +452,13 @@ def nanquantile(a: Union[Array, Quantity], q: Union[Array, Quantity], *args, **k Returns: jax.Array: an array -''' + ''' + ... -arctan2.__doc__ = ''' + +@wrap_math_funcs_only_accept_unitless_binary(jnp.arctan2) +def arctan2(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array: + ''' Element-wise arc tangent of `x1/x2` choosing the quadrant correctly. Args: @@ -523,9 +467,13 @@ def nanquantile(a: Union[Array, Quantity], q: Union[Array, Quantity], *args, **k Returns: jax.Array: an array -''' + ''' + ... + -logaddexp.__doc__ = ''' +@wrap_math_funcs_only_accept_unitless_binary(jnp.logaddexp) +def logaddexp(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array: + ''' Logarithm of the sum of exponentiations of the inputs. Args: @@ -534,9 +482,13 @@ def nanquantile(a: Union[Array, Quantity], q: Union[Array, Quantity], *args, **k Returns: jax.Array: an array -''' + ''' + ... + -logaddexp2.__doc__ = ''' +@wrap_math_funcs_only_accept_unitless_binary(jnp.logaddexp2) +def logaddexp2(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array: + ''' Logarithm of the sum of exponentiations of the inputs in base-2. Args: @@ -545,9 +497,13 @@ def nanquantile(a: Union[Array, Quantity], q: Union[Array, Quantity], *args, **k Returns: jax.Array: an array -''' + ''' + ... + -percentile.__doc__ = ''' +@wrap_math_funcs_only_accept_unitless_binary(jnp.percentile) +def percentile(a: Union[Array, Quantity], q: Union[Array, Quantity], *args, **kwargs) -> Array: + ''' Compute the nth percentile of the input array along the specified axis. Args: @@ -555,9 +511,13 @@ def nanquantile(a: Union[Array, Quantity], q: Union[Array, Quantity], *args, **k Returns: jax.Array: an array -''' + ''' + ... -nanpercentile.__doc__ = ''' + +@wrap_math_funcs_only_accept_unitless_binary(jnp.nanpercentile) +def nanpercentile(a: Union[Array, Quantity], q: Union[Array, Quantity], *args, **kwargs) -> Array: + ''' Compute the nth percentile of the input array along the specified axis, ignoring NaNs. Args: @@ -565,9 +525,13 @@ def nanquantile(a: Union[Array, Quantity], q: Union[Array, Quantity], *args, **k Returns: jax.Array: an array -''' + ''' + ... -quantile.__doc__ = ''' + +@wrap_math_funcs_only_accept_unitless_binary(jnp.quantile) +def quantile(a: Union[Array, Quantity], q: Union[Array, Quantity], *args, **kwargs) -> Array: + ''' Compute the qth quantile of the input array along the specified axis. Args: @@ -575,9 +539,13 @@ def nanquantile(a: Union[Array, Quantity], q: Union[Array, Quantity], *args, **k Returns: jax.Array: an array -''' + ''' + ... -nanquantile.__doc__ = ''' + +@wrap_math_funcs_only_accept_unitless_binary(jnp.nanquantile) +def nanquantile(a: Union[Array, Quantity], q: Union[Array, Quantity], *args, **kwargs) -> Array: + ''' Compute the qth quantile of the input array along the specified axis, ignoring NaNs. Args: @@ -585,4 +553,5 @@ def nanquantile(a: Union[Array, Quantity], q: Union[Array, Quantity], *args, **k Returns: jax.Array: an array -''' + ''' + ... diff --git a/brainunit/math/_compat_numpy_funcs_bit_operation.py b/brainunit/math/_compat_numpy_funcs_bit_operation.py index 1325539..8d0e016 100644 --- a/brainunit/math/_compat_numpy_funcs_bit_operation.py +++ b/brainunit/math/_compat_numpy_funcs_bit_operation.py @@ -89,45 +89,27 @@ def invert(x: Union[Quantity, bst.typing.ArrayLike]) -> Array: def wrap_elementwise_bit_operation_binary(func): @wraps(func) - def f(x, y, *args, **kwargs): - if isinstance(x, Quantity) or isinstance(y, Quantity): - raise ValueError(f'Expected integers, got {x} and {y}') - elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray, int, float)): - return func(x, y, *args, **kwargs) - else: - raise ValueError(f'Unsupported types {type(x)} and {type(y)} for {func.__name__}') - - f.__module__ = 'brainunit.math' - return f - - -@wrap_elementwise_bit_operation_binary -def bitwise_and(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Array: - return jnp.bitwise_and(x, y) - - -@wrap_elementwise_bit_operation_binary -def bitwise_or(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Array: - return jnp.bitwise_or(x, y) - - -@wrap_elementwise_bit_operation_binary -def bitwise_xor(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Array: - return jnp.bitwise_xor(x, y) - - -@wrap_elementwise_bit_operation_binary -def left_shift(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Array: - return jnp.left_shift(x, y) - - -@wrap_elementwise_bit_operation_binary -def right_shift(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Array: - return jnp.right_shift(x, y) - - -# docs for functions above -bitwise_and.__doc__ = ''' + def decorator(*args, **kwargs): + def f(x, y, *args, **kwargs): + if isinstance(x, Quantity) or isinstance(y, Quantity): + raise ValueError(f'Expected integers, got {x} and {y}') + elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray, int, float)): + return func(x, y, *args, **kwargs) + else: + raise ValueError(f'Unsupported types {type(x)} and {type(y)} for {func.__name__}') + + f.__module__ = 'brainunit.math' + return f + + return decorator + + +@wrap_elementwise_bit_operation_binary(jnp.bitwise_and) +def bitwise_and( + x: Union[Quantity, bst.typing.ArrayLike], + y: Union[Quantity, bst.typing.ArrayLike] +) -> Array: + ''' Compute the bit-wise AND of two arrays element-wise. Args: @@ -136,9 +118,16 @@ def right_shift(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst Returns: jax.Array: an array -''' + ''' + ... + -bitwise_or.__doc__ = ''' +@wrap_elementwise_bit_operation_binary(jnp.bitwise_or) +def bitwise_or( + x: Union[Quantity, bst.typing.ArrayLike], + y: Union[Quantity, bst.typing.ArrayLike] +) -> Array: + ''' Compute the bit-wise OR of two arrays element-wise. Args: @@ -147,9 +136,16 @@ def right_shift(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst Returns: jax.Array: an array -''' + ''' + ... -bitwise_xor.__doc__ = ''' + +@wrap_elementwise_bit_operation_binary(jnp.bitwise_xor) +def bitwise_xor( + x: Union[Quantity, bst.typing.ArrayLike], + y: Union[Quantity, bst.typing.ArrayLike] +) -> Array: + ''' Compute the bit-wise XOR of two arrays element-wise. Args: @@ -158,9 +154,16 @@ def right_shift(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst Returns: jax.Array: an array -''' + ''' + ... + -left_shift.__doc__ = ''' +@wrap_elementwise_bit_operation_binary(jnp.left_shift) +def left_shift( + x: Union[Quantity, bst.typing.ArrayLike], + y: Union[Quantity, bst.typing.ArrayLike] +) -> Array: + ''' Shift the bits of an integer to the left. Args: @@ -169,9 +172,16 @@ def right_shift(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst Returns: jax.Array: an array -''' + ''' + ... + -right_shift.__doc__ = ''' +@wrap_elementwise_bit_operation_binary(jnp.right_shift) +def right_shift( + x: Union[Quantity, bst.typing.ArrayLike], + y: Union[Quantity, bst.typing.ArrayLike] +) -> Array: + ''' Shift the bits of an integer to the right. Args: @@ -180,4 +190,5 @@ def right_shift(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst Returns: jax.Array: an array -''' + ''' + ... diff --git a/brainunit/math/_compat_numpy_funcs_change_unit.py b/brainunit/math/_compat_numpy_funcs_change_unit.py index 227234c..8f81a43 100644 --- a/brainunit/math/_compat_numpy_funcs_change_unit.py +++ b/brainunit/math/_compat_numpy_funcs_change_unit.py @@ -44,9 +44,13 @@ # math funcs change unit (unary) # ------------------------------ -def wrap_math_funcs_change_unit_unary(change_unit_func: Callable) -> Callable: - def decorator(func: Callable) -> Callable: - @wraps(func) +def wrap_math_funcs_change_unit_unary( + func: Callable, + change_unit_func: Callable +) -> Callable: + @wraps(func) + def decorator(*args, **kwargs) -> Callable: + def f(x, *args, **kwargs): if isinstance(x, Quantity): return _return_check_unitless(Quantity(func(x.value, *args, **kwargs), dim=change_unit_func(x.dim))) @@ -61,50 +65,11 @@ def f(x, *args, **kwargs): return decorator -@wrap_math_funcs_change_unit_unary(lambda x: x ** -1) -def reciprocal(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - return jnp.reciprocal(x) - - -@wrap_math_funcs_change_unit_unary(lambda x: x ** 2) -def var(x: Union[Quantity, bst.typing.ArrayLike], - axis: Optional[Union[int, Sequence[int]]] = None, - ddof: int = 0, - keepdims: bool = False) -> Union[Quantity, jax.Array]: - return jnp.var(x, axis=axis, ddof=ddof, keepdims=keepdims) - - -@wrap_math_funcs_change_unit_unary(lambda x: x ** 2) -def nanvar(x: Union[Quantity, bst.typing.ArrayLike], - axis: Optional[Union[int, Sequence[int]]] = None, - ddof: int = 0, - keepdims: bool = False) -> Union[Quantity, jax.Array]: - return jnp.nanvar(x, axis=axis, ddof=ddof, keepdims=keepdims) - - -@wrap_math_funcs_change_unit_unary(lambda x: x * 2 ** -1) -def frexp(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - return jnp.frexp(x) - - -@wrap_math_funcs_change_unit_unary(lambda x: x ** 0.5) -def sqrt(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - return jnp.sqrt(x) - - -@wrap_math_funcs_change_unit_unary(lambda x: x ** (1 / 3)) -def cbrt(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - return jnp.cbrt(x) - - -@wrap_math_funcs_change_unit_unary(lambda x: x ** 2) -def square(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - return jnp.square(x) - - -# docs for the functions above - -reciprocal.__doc__ = ''' +@wrap_math_funcs_change_unit_unary(jnp.reciprocal, lambda x: x ** -1) +def reciprocal( + x: Union[Quantity, bst.typing.ArrayLike] +) -> Union[Quantity, jax.Array]: + ''' Return the reciprocal of the argument. Args: @@ -112,9 +77,18 @@ def square(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Arra Returns: Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' + ''' + ... -var.__doc__ = ''' + +@wrap_math_funcs_change_unit_unary(jnp.var, lambda x: x ** 2) +def var( + x: Union[Quantity, bst.typing.ArrayLike], + axis: Optional[Union[int, Sequence[int]]] = None, + ddof: int = 0, + keepdims: bool = False +) -> Union[Quantity, jax.Array]: + ''' Compute the variance along the specified axis. Args: @@ -122,9 +96,18 @@ def square(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Arra Returns: Union[jax.Array, Quantity]: Quantity if the final unit is the square of the unit of `x`, else an array. -''' + ''' + ... -nanvar.__doc__ = ''' + +@wrap_math_funcs_change_unit_unary(jnp.nanvar, lambda x: x ** 2) +def nanvar( + x: Union[Quantity, bst.typing.ArrayLike], + axis: Optional[Union[int, Sequence[int]]] = None, + ddof: int = 0, + keepdims: bool = False +) -> Union[Quantity, jax.Array]: + ''' Compute the variance along the specified axis, ignoring NaNs. Args: @@ -132,9 +115,15 @@ def square(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Arra Returns: Union[jax.Array, Quantity]: Quantity if the final unit is the square of the unit of `x`, else an array. -''' + ''' + ... + -frexp.__doc__ = ''' +@wrap_math_funcs_change_unit_unary(jnp.frexp, lambda x: x * 2 ** -1) +def frexp( + x: Union[Quantity, bst.typing.ArrayLike] +) -> Union[Quantity, jax.Array]: + ''' Decompose a floating-point number into its mantissa and exponent. Args: @@ -142,9 +131,15 @@ def square(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Arra Returns: Union[jax.Array, Quantity]: Tuple of Quantity if the final unit is the product of the unit of `x` and 2 raised to the power of the exponent, else a tuple of arrays. -''' + ''' + ... -sqrt.__doc__ = ''' + +@wrap_math_funcs_change_unit_unary(jnp.sqrt, lambda x: x ** 0.5) +def sqrt( + x: Union[Quantity, bst.typing.ArrayLike] +) -> Union[Quantity, jax.Array]: + ''' Compute the square root of each element. Args: @@ -152,9 +147,15 @@ def square(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Arra Returns: Union[jax.Array, Quantity]: Quantity if the final unit is the square root of the unit of `x`, else an array. -''' + ''' + ... -cbrt.__doc__ = ''' + +@wrap_math_funcs_change_unit_unary(jnp.cbrt, lambda x: x ** (1 / 3)) +def cbrt( + x: Union[Quantity, bst.typing.ArrayLike] +) -> Union[Quantity, jax.Array]: + ''' Compute the cube root of each element. Args: @@ -162,9 +163,15 @@ def square(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Arra Returns: Union[jax.Array, Quantity]: Quantity if the final unit is the cube root of the unit of `x`, else an array. -''' + ''' + ... + -square.__doc__ = ''' +@wrap_math_funcs_change_unit_unary(jnp.square, lambda x: x ** 2) +def square( + x: Union[Quantity, bst.typing.ArrayLike] +) -> Union[Quantity, jax.Array]: + ''' Compute the square of each element. Args: @@ -172,7 +179,8 @@ def square(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Arra Returns: Union[jax.Array, Quantity]: Quantity if the final unit is the square of the unit of `x`, else an array. -''' + ''' + ... @set_module_as('brainunit.math') @@ -292,9 +300,12 @@ def nancumprod(x: Union[Quantity, bst.typing.ArrayLike], # math funcs change unit (binary) # ------------------------------- -def wrap_math_funcs_change_unit_binary(change_unit_func): - def decorator(func: Callable) -> Callable: - @wraps(func) +def wrap_math_funcs_change_unit_binary( + func: Callable, + change_unit_func: Callable +): + @wraps(func) + def decorator(*args, **kwargs) -> Callable: def f(x, y, *args, **kwargs): if isinstance(x, Quantity) and isinstance(y, Quantity): return _return_check_unitless( @@ -313,46 +324,16 @@ def f(x, y, *args, **kwargs): f.__module__ = 'brainunit.math' return f - return decorator - - -@wrap_math_funcs_change_unit_binary(lambda x, y: x * y) -def multiply(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]): - return jnp.multiply(x, y) - - -@wrap_math_funcs_change_unit_binary(lambda x, y: x / y) -def divide(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]): - return jnp.divide(x, y) - - -@wrap_math_funcs_change_unit_binary(lambda x, y: x * y) -def cross(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]): - return jnp.cross(x, y) - - -@wrap_math_funcs_change_unit_binary(lambda x, y: x * 2 ** y) -def ldexp(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]): - return jnp.ldexp(x, y) - -@wrap_math_funcs_change_unit_binary(lambda x, y: x / y) -def true_divide(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]): - return jnp.true_divide(x, y) - - -@wrap_math_funcs_change_unit_binary(lambda x, y: x / y) -def divmod(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]): - return jnp.divmod(x, y) - - -@wrap_math_funcs_change_unit_binary(lambda x, y: x * y) -def convolve(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]): - return jnp.convolve(x, y) + return decorator -# docs for the functions above -multiply.__doc__ = ''' +@wrap_math_funcs_change_unit_binary(jnp.multiply, lambda x, y: x * y) +def multiply( + x: Union[Quantity, bst.typing.ArrayLike], + y: Union[Quantity, bst.typing.ArrayLike] +) -> Union[Quantity, bst.typing.ArrayLike]: + ''' Multiply arguments element-wise. Args: @@ -361,9 +342,16 @@ def convolve(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.ty Returns: Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. -''' + ''' + ... + -divide.__doc__ = ''' +@wrap_math_funcs_change_unit_binary(jnp.divide, lambda x, y: x / y) +def divide( + x: Union[Quantity, bst.typing.ArrayLike], + y: Union[Quantity, bst.typing.ArrayLike] +) -> Union[Quantity, bst.typing.ArrayLike]: + ''' Divide arguments element-wise. Args: @@ -371,9 +359,16 @@ def convolve(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.ty Returns: Union[jax.Array, Quantity]: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. -''' + ''' + ... + -cross.__doc__ = ''' +@wrap_math_funcs_change_unit_binary(jnp.cross, lambda x, y: x * y) +def cross( + x: Union[Quantity, bst.typing.ArrayLike], + y: Union[Quantity, bst.typing.ArrayLike] +) -> Union[Quantity, bst.typing.ArrayLike]: + ''' Return the cross product of two (arrays of) vectors. Args: @@ -382,9 +377,16 @@ def convolve(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.ty Returns: Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. -''' + ''' + ... -ldexp.__doc__ = ''' + +@wrap_math_funcs_change_unit_binary(jnp.ldexp, lambda x, y: x * 2 ** y) +def ldexp( + x: Union[Quantity, bst.typing.ArrayLike], + y: Union[Quantity, bst.typing.ArrayLike] +) -> Union[Quantity, bst.typing.ArrayLike]: + ''' Return x1 * 2**x2, element-wise. Args: @@ -392,10 +394,17 @@ def convolve(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.ty y: array_like, Quantity Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and 2 raised to the power of the unit of `y`, else an array. -''' + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and 2 raised to the power of the unit of `y`, else an array. + ''' + ... + -true_divide.__doc__ = ''' +@wrap_math_funcs_change_unit_binary(jnp.true_divide, lambda x, y: x / y) +def true_divide( + x: Union[Quantity, bst.typing.ArrayLike], + y: Union[Quantity, bst.typing.ArrayLike] +) -> Union[Quantity, bst.typing.ArrayLike]: + ''' Returns a true division of the inputs, element-wise. Args: @@ -404,9 +413,16 @@ def convolve(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.ty Returns: Union[jax.Array, Quantity]: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. -''' + ''' + ... -divmod.__doc__ = ''' + +@wrap_math_funcs_change_unit_binary(jnp.divmod, lambda x, y: x / y) +def divmod( + x: Union[Quantity, bst.typing.ArrayLike], + y: Union[Quantity, bst.typing.ArrayLike] +) -> Union[Quantity, bst.typing.ArrayLike]: + ''' Return element-wise quotient and remainder simultaneously. Args: @@ -415,9 +431,16 @@ def convolve(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.ty Returns: Union[jax.Array, Quantity]: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. -''' + ''' + ... + -convolve.__doc__ = ''' +@wrap_math_funcs_change_unit_binary(jnp.convolve, lambda x, y: x * y) +def convolve( + x: Union[Quantity, bst.typing.ArrayLike], + y: Union[Quantity, bst.typing.ArrayLike] +) -> Union[Quantity, bst.typing.ArrayLike]: + ''' Returns the discrete, linear convolution of two one-dimensional sequences. Args: @@ -426,7 +449,8 @@ def convolve(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.ty Returns: Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. -''' + ''' + ... @set_module_as('brainunit.math') diff --git a/brainunit/math/_compat_numpy_funcs_keep_unit.py b/brainunit/math/_compat_numpy_funcs_keep_unit.py index 4a6616e..d3152af 100644 --- a/brainunit/math/_compat_numpy_funcs_keep_unit.py +++ b/brainunit/math/_compat_numpy_funcs_keep_unit.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== from functools import wraps -from typing import (Union) +from typing import (Union, Callable) import brainstate as bst import jax @@ -45,192 +45,26 @@ # math funcs keep unit (unary) # ---------------------------- -def wrap_math_funcs_keep_unit_unary(func): +def wrap_math_funcs_keep_unit_unary(func: Callable): @wraps(func) - def f(x, *args, **kwargs): - if isinstance(x, Quantity): - return Quantity(func(x.value, *args, **kwargs), dim=x.dim) - elif isinstance(x, (jax.Array, np.ndarray)): - return func(x, *args, **kwargs) - else: - raise ValueError(f'Unsupported type: {type(x)} for {func.__name__}') + def decorator(*args, **kwargs) -> Callable: + def f(x, *args, **kwargs): + if isinstance(x, Quantity): + return Quantity(func(x.value, *args, **kwargs), dim=x.dim) + elif isinstance(x, (jax.Array, np.ndarray)): + return func(x, *args, **kwargs) + else: + raise ValueError(f'Unsupported type: {type(x)} for {func.__name__}') - f.__module__ = 'brainunit.math' - return f + f.__module__ = 'brainunit.math' + return f + return decorator -@wrap_math_funcs_keep_unit_unary -def real(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - return jnp.real(x) - - -@wrap_math_funcs_keep_unit_unary -def imag(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - return jnp.imag(x) - - -@wrap_math_funcs_keep_unit_unary -def conj(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - return jnp.conj(x) - - -@wrap_math_funcs_keep_unit_unary -def conjugate(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - return jnp.conjugate(x) - - -@wrap_math_funcs_keep_unit_unary -def negative(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - return jnp.negative(x) - - -@wrap_math_funcs_keep_unit_unary -def positive(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - return jnp.positive(x) - - -@wrap_math_funcs_keep_unit_unary -def abs(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - return jnp.abs(x) - - -@wrap_math_funcs_keep_unit_unary -def round_(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - return jnp.round(x) - - -@wrap_math_funcs_keep_unit_unary -def around(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - return jnp.around(x) - - -@wrap_math_funcs_keep_unit_unary -def round(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - return jnp.round(x) - - -@wrap_math_funcs_keep_unit_unary -def rint(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - return jnp.rint(x) - - -@wrap_math_funcs_keep_unit_unary -def floor(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - return jnp.floor(x) - - -@wrap_math_funcs_keep_unit_unary -def ceil(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - return jnp.ceil(x) - - -@wrap_math_funcs_keep_unit_unary -def trunc(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - return jnp.trunc(x) - - -@wrap_math_funcs_keep_unit_unary -def fix(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - return jnp.fix(x) - - -@wrap_math_funcs_keep_unit_unary -def sum(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - return jnp.sum(x) - - -@wrap_math_funcs_keep_unit_unary -def nancumsum(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - return jnp.nancumsum(x) - - -@wrap_math_funcs_keep_unit_unary -def nansum(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - return jnp.nansum(x) - - -@wrap_math_funcs_keep_unit_unary -def cumsum(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - return jnp.cumsum(x) - - -@wrap_math_funcs_keep_unit_unary -def ediff1d(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - return jnp.ediff1d(x) - - -@wrap_math_funcs_keep_unit_unary -def absolute(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - return jnp.absolute(x) - - -@wrap_math_funcs_keep_unit_unary -def fabs(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - return jnp.fabs(x) - - -@wrap_math_funcs_keep_unit_unary -def median(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - return jnp.median(x) - - -@wrap_math_funcs_keep_unit_unary -def nanmin(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - return jnp.nanmin(x) - - -@wrap_math_funcs_keep_unit_unary -def nanmax(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - return jnp.nanmax(x) - - -@wrap_math_funcs_keep_unit_unary -def ptp(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - return jnp.ptp(x) - -@wrap_math_funcs_keep_unit_unary -def average(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - return jnp.average(x) - - -@wrap_math_funcs_keep_unit_unary -def mean(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - return jnp.mean(x) - - -@wrap_math_funcs_keep_unit_unary -def std(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - return jnp.std(x) - - -@wrap_math_funcs_keep_unit_unary -def nanmedian(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - return jnp.nanmedian(x) - - -@wrap_math_funcs_keep_unit_unary -def nanmean(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - return jnp.nanmean(x) - - -@wrap_math_funcs_keep_unit_unary -def nanstd(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - return jnp.nanstd(x) - - -@wrap_math_funcs_keep_unit_unary -def diff(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - return jnp.diff(x) - - -@wrap_math_funcs_keep_unit_unary -def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - return jnp.modf(x) - - -# docs for the functions above -real.__doc__ = ''' +@wrap_math_funcs_keep_unit_unary(jnp.real) +def real(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' Return the real part of the complex argument. Args: @@ -238,9 +72,13 @@ def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array] Returns: Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' + ''' + ... + -imag.__doc__ = ''' +@wrap_math_funcs_keep_unit_unary(jnp.imag) +def imag(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' Return the imaginary part of the complex argument. Args: @@ -248,9 +86,13 @@ def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array] Returns: Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' + ''' + ... -conj.__doc__ = ''' + +@wrap_math_funcs_keep_unit_unary(jnp.conj) +def conj(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' Return the complex conjugate of the argument. Args: @@ -258,9 +100,13 @@ def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array] Returns: Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' + ''' + ... -conjugate.__doc__ = ''' + +@wrap_math_funcs_keep_unit_unary(jnp.conjugate) +def conjugate(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' Return the complex conjugate of the argument. Args: @@ -268,9 +114,13 @@ def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array] Returns: Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' + ''' + ... + -negative.__doc__ = ''' +@wrap_math_funcs_keep_unit_unary(jnp.negative) +def negative(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' Return the negative of the argument. Args: @@ -278,9 +128,13 @@ def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array] Returns: Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' + ''' + ... + -positive.__doc__ = ''' +@wrap_math_funcs_keep_unit_unary(jnp.positive) +def positive(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' Return the positive of the argument. Args: @@ -288,9 +142,13 @@ def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array] Returns: Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' + ''' + ... -abs.__doc__ = ''' + +@wrap_math_funcs_keep_unit_unary(jnp.abs) +def abs(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' Return the absolute value of the argument. Args: @@ -298,9 +156,13 @@ def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array] Returns: Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' + ''' + ... + -round_.__doc__ = ''' +@wrap_math_funcs_keep_unit_unary(jnp.round_) +def round_(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' Round an array to the nearest integer. Args: @@ -308,9 +170,13 @@ def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array] Returns: Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' + ''' + ... + -around.__doc__ = ''' +@wrap_math_funcs_keep_unit_unary(jnp.around) +def around(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' Round an array to the nearest integer. Args: @@ -318,9 +184,13 @@ def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array] Returns: Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' + ''' + ... -round.__doc__ = ''' + +@wrap_math_funcs_keep_unit_unary(jnp.round) +def round(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' Round an array to the nearest integer. Args: @@ -328,9 +198,13 @@ def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array] Returns: Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' + ''' + ... + -rint.__doc__ = ''' +@wrap_math_funcs_keep_unit_unary(jnp.rint) +def rint(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' Round an array to the nearest integer. Args: @@ -338,9 +212,13 @@ def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array] Returns: Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' + ''' + ... + -floor.__doc__ = ''' +@wrap_math_funcs_keep_unit_unary(jnp.floor) +def floor(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' Return the floor of the argument. Args: @@ -348,9 +226,13 @@ def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array] Returns: Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' + ''' + ... -ceil.__doc__ = ''' + +@wrap_math_funcs_keep_unit_unary(jnp.ceil) +def ceil(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' Return the ceiling of the argument. Args: @@ -358,9 +240,13 @@ def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array] Returns: Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' + ''' + ... -trunc.__doc__ = ''' + +@wrap_math_funcs_keep_unit_unary(jnp.trunc) +def trunc(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' Return the truncated value of the argument. Args: @@ -368,9 +254,13 @@ def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array] Returns: Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' + ''' + ... + -fix.__doc__ = ''' +@wrap_math_funcs_keep_unit_unary(jnp.fix) +def fix(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' Return the nearest integer towards zero. Args: @@ -378,9 +268,13 @@ def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array] Returns: Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' + ''' + ... -sum.__doc__ = ''' + +@wrap_math_funcs_keep_unit_unary(jnp.sum) +def sum(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' Return the sum of the array elements. Args: @@ -388,9 +282,13 @@ def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array] Returns: Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' + ''' + ... + -nancumsum.__doc__ = ''' +@wrap_math_funcs_keep_unit_unary(jnp.nancumsum) +def nancumsum(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' Return the cumulative sum of the array elements, ignoring NaNs. Args: @@ -398,9 +296,13 @@ def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array] Returns: Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' + ''' + ... + -nansum.__doc__ = ''' +@wrap_math_funcs_keep_unit_unary(jnp.nansum) +def nansum(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' Return the sum of the array elements, ignoring NaNs. Args: @@ -408,9 +310,13 @@ def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array] Returns: Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' + ''' + ... + -cumsum.__doc__ = ''' +@wrap_math_funcs_keep_unit_unary(jnp.cumsum) +def cumsum(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' Return the cumulative sum of the array elements. Args: @@ -418,9 +324,13 @@ def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array] Returns: Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' + ''' + ... -ediff1d.__doc__ = ''' + +@wrap_math_funcs_keep_unit_unary(jnp.ediff1d) +def ediff1d(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' Return the differences between consecutive elements of the array. Args: @@ -428,9 +338,13 @@ def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array] Returns: Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' + ''' + ... + -absolute.__doc__ = ''' +@wrap_math_funcs_keep_unit_unary(jnp.absolute) +def absolute(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' Return the absolute value of the argument. Args: @@ -438,9 +352,13 @@ def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array] Returns: Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' + ''' + ... -fabs.__doc__ = ''' + +@wrap_math_funcs_keep_unit_unary(jnp.fabs) +def fabs(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' Return the absolute value of the argument. Args: @@ -448,9 +366,13 @@ def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array] Returns: Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' + ''' + ... -median.__doc__ = ''' + +@wrap_math_funcs_keep_unit_unary(jnp.median) +def median(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' Return the median of the array elements. Args: @@ -458,9 +380,13 @@ def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array] Returns: Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' + ''' + ... + -nanmin.__doc__ = ''' +@wrap_math_funcs_keep_unit_unary(jnp.nanmin) +def nanmin(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' Return the minimum of the array elements, ignoring NaNs. Args: @@ -468,9 +394,13 @@ def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array] Returns: Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' + ''' + ... + -nanmax.__doc__ = ''' +@wrap_math_funcs_keep_unit_unary(jnp.nanmax) +def nanmax(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' Return the maximum of the array elements, ignoring NaNs. Args: @@ -478,9 +408,13 @@ def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array] Returns: Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' + ''' + ... -ptp.__doc__ = ''' + +@wrap_math_funcs_keep_unit_unary(jnp.ptp) +def ptp(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' Return the range of the array elements (maximum - minimum). Args: @@ -488,9 +422,13 @@ def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array] Returns: Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' + ''' + ... + -average.__doc__ = ''' +@wrap_math_funcs_keep_unit_unary(jnp.average) +def average(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' Return the weighted average of the array elements. Args: @@ -498,9 +436,13 @@ def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array] Returns: Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' + ''' + ... + -mean.__doc__ = ''' +@wrap_math_funcs_keep_unit_unary(jnp.mean) +def mean(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' Return the mean of the array elements. Args: @@ -508,9 +450,13 @@ def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array] Returns: Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' + ''' + ... -std.__doc__ = ''' + +@wrap_math_funcs_keep_unit_unary(jnp.std) +def std(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' Return the standard deviation of the array elements. Args: @@ -518,9 +464,13 @@ def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array] Returns: Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' + ''' + ... + -nanmedian.__doc__ = ''' +@wrap_math_funcs_keep_unit_unary(jnp.nanmedian) +def nanmedian(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' Return the median of the array elements, ignoring NaNs. Args: @@ -528,9 +478,13 @@ def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array] Returns: Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' + ''' + ... + -nanmean.__doc__ = ''' +@wrap_math_funcs_keep_unit_unary(jnp.nanmean) +def nanmean(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' Return the mean of the array elements, ignoring NaNs. Args: @@ -538,9 +492,13 @@ def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array] Returns: Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' + ''' + ... -nanstd.__doc__ = ''' + +@wrap_math_funcs_keep_unit_unary(jnp.nanstd) +def nanstd(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' Return the standard deviation of the array elements, ignoring NaNs. Args: @@ -548,9 +506,13 @@ def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array] Returns: Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' + ''' + ... -diff.__doc__ = ''' + +@wrap_math_funcs_keep_unit_unary(jnp.diff) +def diff(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' Return the differences between consecutive elements of the array. Args: @@ -558,9 +520,13 @@ def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array] Returns: Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' + ''' + ... + -modf.__doc__ = ''' +@wrap_math_funcs_keep_unit_unary(jnp.modf) +def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' Return the fractional and integer parts of the array elements. Args: @@ -568,7 +534,8 @@ def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array] Returns: Union[jax.Array, Quantity]: Quantity tuple if `x` is a Quantity, else an array tuple. -''' + ''' + ... # math funcs keep unit (binary) @@ -576,70 +543,24 @@ def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array] def wrap_math_funcs_keep_unit_binary(func): @wraps(func) - def f(x1, x2, *args, **kwargs): - if isinstance(x1, Quantity) and isinstance(x2, Quantity): - return Quantity(func(x1.value, x2.value, *args, **kwargs), dim=x1.dim) - elif isinstance(x1, (jax.Array, np.ndarray)) and isinstance(x2, (jax.Array, np.ndarray)): - return func(x1, x2, *args, **kwargs) - else: - raise ValueError(f'Unsupported type: {type(x1)} and {type(x2)} for {func.__name__}') - - f.__module__ = 'brainunit.math' - return f - - -@wrap_math_funcs_keep_unit_binary -def fmod(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: - return jnp.fmod(x1, x2) - - -@wrap_math_funcs_keep_unit_binary -def mod(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: - return jnp.mod(x1, x2) - - -@wrap_math_funcs_keep_unit_binary -def copysign(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: - return jnp.copysign(x1, x2) - - -@wrap_math_funcs_keep_unit_binary -def heaviside(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: - return jnp.heaviside(x1, x2) - - -@wrap_math_funcs_keep_unit_binary -def maximum(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: - return jnp.maximum(x1, x2) + def decorator(*args, **kwargs): + def f(x1, x2, *args, **kwargs): + if isinstance(x1, Quantity) and isinstance(x2, Quantity): + return Quantity(func(x1.value, x2.value, *args, **kwargs), dim=x1.dim) + elif isinstance(x1, (jax.Array, np.ndarray)) and isinstance(x2, (jax.Array, np.ndarray)): + return func(x1, x2, *args, **kwargs) + else: + raise ValueError(f'Unsupported type: {type(x1)} and {type(x2)} for {func.__name__}') + f.__module__ = 'brainunit.math' + return f -@wrap_math_funcs_keep_unit_binary -def minimum(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: - return jnp.minimum(x1, x2) - + return decorator -@wrap_math_funcs_keep_unit_binary -def fmax(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: - return jnp.fmax(x1, x2) - -@wrap_math_funcs_keep_unit_binary -def fmin(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: - return jnp.fmin(x1, x2) - - -@wrap_math_funcs_keep_unit_binary -def lcm(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: - return jnp.lcm(x1, x2) - - -@wrap_math_funcs_keep_unit_binary -def gcd(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: - return jnp.gcd(x1, x2) - - -# docs for the functions above -fmod.__doc__ = ''' +@wrap_math_funcs_keep_unit_binary(jnp.fmod) +def fmod(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: + ''' Return the element-wise remainder of division. Args: @@ -648,9 +569,13 @@ def gcd(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union Returns: Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. -''' + ''' + ... + -mod.__doc__ = ''' +@wrap_math_funcs_keep_unit_binary(jnp.mod) +def mod(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: + ''' Return the element-wise modulus of division. Args: @@ -659,9 +584,13 @@ def gcd(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union Returns: Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. -''' + ''' + ... + -copysign.__doc__ = ''' +@wrap_math_funcs_keep_unit_binary(jnp.copysign) +def copysign(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: + ''' Return a copy of the first array elements with the sign of the second array. Args: @@ -670,9 +599,13 @@ def gcd(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union Returns: Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. -''' + ''' + ... -heaviside.__doc__ = ''' + +@wrap_math_funcs_keep_unit_binary(jnp.heaviside) +def heaviside(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: + ''' Compute the Heaviside step function. Args: @@ -681,9 +614,13 @@ def gcd(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union Returns: Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. -''' + ''' + ... + -maximum.__doc__ = ''' +@wrap_math_funcs_keep_unit_binary(jnp.maximum) +def maximum(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: + ''' Element-wise maximum of array elements. Args: @@ -692,9 +629,13 @@ def gcd(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union Returns: Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. -''' + ''' + ... + -minimum.__doc__ = ''' +@wrap_math_funcs_keep_unit_binary(jnp.minimum) +def minimum(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: + ''' Element-wise minimum of array elements. Args: @@ -703,9 +644,13 @@ def gcd(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union Returns: Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. -''' + ''' + ... -fmax.__doc__ = ''' + +@wrap_math_funcs_keep_unit_binary(jnp.fmax) +def fmax(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: + ''' Element-wise maximum of array elements ignoring NaNs. Args: @@ -714,9 +659,13 @@ def gcd(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union Returns: Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. -''' + ''' + ... + -fmin.__doc__ = ''' +@wrap_math_funcs_keep_unit_binary(jnp.fmin) +def fmin(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: + ''' Element-wise minimum of array elements ignoring NaNs. Args: @@ -725,9 +674,13 @@ def gcd(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union Returns: Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. -''' + ''' + ... + -lcm.__doc__ = ''' +@wrap_math_funcs_keep_unit_binary(jnp.lcm) +def lcm(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: + ''' Return the least common multiple of `x1` and `x2`. Args: @@ -736,9 +689,13 @@ def gcd(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union Returns: Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. -''' + ''' + ... -gcd.__doc__ = ''' + +@wrap_math_funcs_keep_unit_binary(jnp.gcd) +def gcd(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: + ''' Return the greatest common divisor of `x1` and `x2`. Args: @@ -747,7 +704,8 @@ def gcd(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union Returns: Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. -''' + ''' + ... # math funcs keep unit (n-ary) diff --git a/brainunit/math/_compat_numpy_funcs_logic.py b/brainunit/math/_compat_numpy_funcs_logic.py index e7d69e7..34cb777 100644 --- a/brainunit/math/_compat_numpy_funcs_logic.py +++ b/brainunit/math/_compat_numpy_funcs_logic.py @@ -41,42 +41,30 @@ def wrap_logic_func_unary(func): @wraps(func) - def f(x, *args, **kwargs): - if isinstance(x, Quantity): - raise ValueError(f'Expected booleans, got {x}') - elif isinstance(x, (jax.Array, np.ndarray)): - return func(x, *args, **kwargs) - else: - raise ValueError(f'Unsupported types {type(x)} for {func.__name__}') - - f.__module__ = 'brainunit.math' - return f - - -@wrap_logic_func_unary -def all(x: Union[Quantity, bst.typing.ArrayLike], axis: Optional[int] = None, - out: Optional[Array] = None, keepdims: bool = False, - where: Optional[Array] = None) -> Union[bool, Array]: - return jnp.all(x, axis=axis, out=out, keepdims=keepdims, where=where) - - -@wrap_logic_func_unary -def any(x: Union[Quantity, bst.typing.ArrayLike], axis: Optional[int] = None, - out: Optional[Array] = None, keepdims: bool = False, - where: Optional[Array] = None) -> Union[bool, Array]: - return jnp.any(x, axis=axis, out=out, keepdims=keepdims, where=where) - - -@wrap_logic_func_unary -def logical_not(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[bool, Array]: - return jnp.logical_not(x) - - -alltrue = all -sometrue = any - -# docs for functions above -all.__doc__ = ''' + def decorator(*args, **kwargs): + def f(x, *args, **kwargs): + if isinstance(x, Quantity): + raise ValueError(f'Expected booleans, got {x}') + elif isinstance(x, (jax.Array, np.ndarray)): + return func(x, *args, **kwargs) + else: + raise ValueError(f'Unsupported types {type(x)} for {func.__name__}') + + f.__module__ = 'brainunit.math' + return f + + return decorator + + +@wrap_logic_func_unary(jnp.all) +def all( + x: Union[Quantity, bst.typing.ArrayLike], + axis: Optional[int] = None, + out: Optional[Array] = None, + keepdims: bool = False, + where: Optional[Array] = None +) -> Union[bool, Array]: + ''' Test whether all array elements along a given axis evaluate to True. Args: @@ -88,9 +76,19 @@ def logical_not(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[bool, Array]: Returns: Union[bool, jax.Array]: bool or array -''' - -any.__doc__ = ''' + ''' + ... + + +@wrap_logic_func_unary(jnp.any) +def any( + x: Union[Quantity, bst.typing.ArrayLike], + axis: Optional[int] = None, + out: Optional[Array] = None, + keepdims: bool = False, + where: Optional[Array] = None +) -> Union[bool, Array]: + ''' Test whether any array element along a given axis evaluates to True. Args: @@ -102,9 +100,13 @@ def logical_not(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[bool, Array]: Returns: Union[bool, jax.Array]: bool or array -''' + ''' + ... + -logical_not.__doc__ = ''' +@wrap_logic_func_unary(jnp.logical_not) +def logical_not(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[bool, Array]: + ''' Compute the truth value of NOT x element-wise. Args: @@ -113,7 +115,12 @@ def logical_not(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[bool, Array]: Returns: Union[bool, jax.Array]: bool or array -''' + ''' + ... + + +alltrue = all +sometrue = any # logic funcs (binary) @@ -121,87 +128,28 @@ def logical_not(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[bool, Array]: def wrap_logic_func_binary(func): @wraps(func) - def f(x, y, *args, **kwargs): - if isinstance(x, Quantity) and isinstance(y, Quantity): - fail_for_dimension_mismatch(x, y) - return func(x.value, y.value, *args, **kwargs) - elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): - return func(x, y, *args, **kwargs) - else: - raise ValueError(f'Unsupported types {type(x)} and {type(y)} for {func.__name__}') - - f.__module__ = 'brainunit.math' - return f - - -@wrap_logic_func_binary -def equal(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[bool, Array]: - return jnp.equal(x, y) - - -@wrap_logic_func_binary -def not_equal(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[bool, Array]: - return jnp.not_equal(x, y) - - -@wrap_logic_func_binary -def greater(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[bool, Array]: - return jnp.greater(x, y) - - -@wrap_logic_func_binary -def greater_equal(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[bool, Array]: - return jnp.greater_equal(x, y) - - -@wrap_logic_func_binary -def less(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[bool, Array]: - return jnp.less(x, y) - - -@wrap_logic_func_binary -def less_equal(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[bool, Array]: - return jnp.less_equal(x, y) - - -@wrap_logic_func_binary -def array_equal(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[ - bool, Array]: - return jnp.array_equal(x, y) - - -@wrap_logic_func_binary -def isclose(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike], - rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> Union[bool, Array]: - return jnp.isclose(x, y, rtol=rtol, atol=atol, equal_nan=equal_nan) - - -@wrap_logic_func_binary -def allclose(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike], - rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> Union[bool, Array]: - return jnp.allclose(x, y, rtol=rtol, atol=atol, equal_nan=equal_nan) - - -@wrap_logic_func_binary -def logical_and(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[ - bool, Array]: - return jnp.logical_and(x, y) - - -@wrap_logic_func_binary -def logical_or(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[ - bool, Array]: - return jnp.logical_or(x, y) - - -@wrap_logic_func_binary -def logical_xor(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[ - bool, Array]: - return jnp.logical_xor(x, y) - - -# docs for functions above -equal.__doc__ = ''' + def decorator(*args, **kwargs): + def f(x, y, *args, **kwargs): + if isinstance(x, Quantity) and isinstance(y, Quantity): + fail_for_dimension_mismatch(x, y) + return func(x.value, y.value, *args, **kwargs) + elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): + return func(x, y, *args, **kwargs) + else: + raise ValueError(f'Unsupported types {type(x)} and {type(y)} for {func.__name__}') + + f.__module__ = 'brainunit.math' + return f + + return decorator + + +@wrap_logic_func_binary(jnp.equal) +def equal( + x: Union[Quantity, bst.typing.ArrayLike], + y: Union[Quantity, bst.typing.ArrayLike] +) -> Union[bool, Array]: + ''' Return (x == y) element-wise and have the same unit if x and y are Quantity. Args: @@ -210,9 +158,16 @@ def logical_xor(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst Returns: Union[bool, jax.Array]: bool or array -''' + ''' + ... -not_equal.__doc__ = ''' + +@wrap_logic_func_binary(jnp.not_equal) +def not_equal( + x: Union[Quantity, bst.typing.ArrayLike], + y: Union[Quantity, bst.typing.ArrayLike] +) -> Union[bool, Array]: + ''' Return (x != y) element-wise and have the same unit if x and y are Quantity. Args: @@ -221,9 +176,16 @@ def logical_xor(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst Returns: Union[bool, jax.Array]: bool or array -''' + ''' + ... + -greater.__doc__ = ''' +@wrap_logic_func_binary(jnp.greater) +def greater( + x: Union[Quantity, bst.typing.ArrayLike], + y: Union[Quantity, bst.typing.ArrayLike] +) -> Union[bool, Array]: + ''' Return (x > y) element-wise and have the same unit if x and y are Quantity. Args: @@ -232,9 +194,17 @@ def logical_xor(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst Returns: Union[bool, jax.Array]: bool or array -''' + ''' + ... + -greater_equal.__doc__ = ''' +@wrap_logic_func_binary(jnp.greater_equal) +def greater_equal( + x: Union[Quantity, bst.typing.ArrayLike], + y: Union[Quantity, bst.typing.ArrayLike] +) -> Union[ + bool, Array]: + ''' Return (x >= y) element-wise and have the same unit if x and y are Quantity. Args: @@ -243,9 +213,16 @@ def logical_xor(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst Returns: Union[bool, jax.Array]: bool or array -''' + ''' + ... + -less.__doc__ = ''' +@wrap_logic_func_binary(jnp.less) +def less( + x: Union[Quantity, bst.typing.ArrayLike], + y: Union[Quantity, bst.typing.ArrayLike] +) -> Union[bool, Array]: + ''' Return (x < y) element-wise and have the same unit if x and y are Quantity. Args: @@ -254,9 +231,17 @@ def logical_xor(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst Returns: Union[bool, jax.Array]: bool or array -''' + ''' + ... -less_equal.__doc__ = ''' + +@wrap_logic_func_binary(jnp.less_equal) +def less_equal( + x: Union[Quantity, bst.typing.ArrayLike], + y: Union[Quantity, bst.typing.ArrayLike] +) -> Union[ + bool, Array]: + ''' Return (x <= y) element-wise and have the same unit if x and y are Quantity. Args: @@ -265,9 +250,17 @@ def logical_xor(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst Returns: Union[bool, jax.Array]: bool or array -''' + ''' + ... -array_equal.__doc__ = ''' + +@wrap_logic_func_binary(jnp.array_equal) +def array_equal( + x: Union[Quantity, bst.typing.ArrayLike], + y: Union[Quantity, bst.typing.ArrayLike] +) -> Union[ + bool, Array]: + ''' Return True if two arrays have the same shape, elements, and units (if they are Quantity), False otherwise. Args: @@ -276,9 +269,19 @@ def logical_xor(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst Returns: Union[bool, jax.Array]: bool or array -''' - -isclose.__doc__ = ''' + ''' + ... + + +@wrap_logic_func_binary(jnp.isclose) +def isclose( + x: Union[Quantity, bst.typing.ArrayLike], + y: Union[Quantity, bst.typing.ArrayLike], + rtol: float = 1e-05, + atol: float = 1e-08, + equal_nan: bool = False +) -> Union[bool, Array]: + ''' Returns a boolean array where two arrays are element-wise equal within a tolerance and have the same unit if they are Quantity. Args: @@ -290,9 +293,19 @@ def logical_xor(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst Returns: Union[bool, jax.Array]: bool or array -''' - -allclose.__doc__ = ''' + ''' + ... + + +@wrap_logic_func_binary(jnp.allclose) +def allclose( + x: Union[Quantity, bst.typing.ArrayLike], + y: Union[Quantity, bst.typing.ArrayLike], + rtol: float = 1e-05, + atol: float = 1e-08, + equal_nan: bool = False +) -> Union[bool, Array]: + ''' Returns True if the two arrays are equal within the given tolerance and have the same unit if they are Quantity; False otherwise. Args: @@ -304,9 +317,17 @@ def logical_xor(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst Returns: bool: boolean result -''' + ''' + ... -logical_and.__doc__ = ''' + +@wrap_logic_func_binary(jnp.logical_and) +def logical_and( + x: Union[Quantity, bst.typing.ArrayLike], + y: Union[Quantity, bst.typing.ArrayLike] +) -> Union[ + bool, Array]: + ''' Compute the truth value of x AND y element-wise and have the same unit if x and y are Quantity. Args: @@ -316,9 +337,17 @@ def logical_xor(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst Returns: Union[bool, jax.Array]: bool or array -''' + ''' + ... -logical_or.__doc__ = ''' + +@wrap_logic_func_binary(jnp.logical_or) +def logical_or( + x: Union[Quantity, bst.typing.ArrayLike], + y: Union[Quantity, bst.typing.ArrayLike] +) -> Union[ + bool, Array]: + ''' Compute the truth value of x OR y element-wise and have the same unit if x and y are Quantity. Args: @@ -328,9 +357,17 @@ def logical_xor(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst Returns: Union[bool, jax.Array]: bool or array -''' + ''' + ... -logical_xor.__doc__ = ''' + +@wrap_logic_func_binary(jnp.logical_xor) +def logical_xor( + x: Union[Quantity, bst.typing.ArrayLike], + y: Union[Quantity, bst.typing.ArrayLike] +) -> Union[ + bool, Array]: + ''' Compute the truth value of x XOR y element-wise and have the same unit if x and y are Quantity. Args: @@ -340,4 +377,5 @@ def logical_xor(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst Returns: Union[bool, jax.Array]: bool or array -''' + ''' + ... diff --git a/brainunit/math/_compat_numpy_funcs_match_unit.py b/brainunit/math/_compat_numpy_funcs_match_unit.py index d9926ad..3ae2a6f 100644 --- a/brainunit/math/_compat_numpy_funcs_match_unit.py +++ b/brainunit/math/_compat_numpy_funcs_match_unit.py @@ -35,46 +35,38 @@ def wrap_math_funcs_match_unit_binary(func): @wraps(func) - def f(x, y, *args, **kwargs): - if isinstance(x, Quantity) and isinstance(y, Quantity): - fail_for_dimension_mismatch(x, y) - return Quantity(func(x.value, y.value, *args, **kwargs), dim=x.dim) - elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): - return func(x, y, *args, **kwargs) - elif isinstance(x, Quantity): - if x.is_unitless: - return Quantity(func(x.value, y, *args, **kwargs), dim=x.dim) + def decorator(*args, **kwargs): + def f(x, y, *args, **kwargs): + if isinstance(x, Quantity) and isinstance(y, Quantity): + fail_for_dimension_mismatch(x, y) + return Quantity(func(x.value, y.value, *args, **kwargs), dim=x.dim) + elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): + return func(x, y, *args, **kwargs) + elif isinstance(x, Quantity): + if x.is_unitless: + return Quantity(func(x.value, y, *args, **kwargs), dim=x.dim) + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') + elif isinstance(y, Quantity): + if y.is_unitless: + return Quantity(func(x, y.value, *args, **kwargs), dim=y.dim) + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') else: raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') - elif isinstance(y, Quantity): - if y.is_unitless: - return Quantity(func(x, y.value, *args, **kwargs), dim=y.dim) - else: - raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') - else: - raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') - - f.__module__ = 'brainunit.math' - return f - - -@wrap_math_funcs_match_unit_binary -def add(x: Union[Quantity, Array], y: Union[Quantity, Array]) -> Union[Quantity, Array]: - return jnp.add(x, y) + f.__module__ = 'brainunit.math' + return f -@wrap_math_funcs_match_unit_binary -def subtract(x: Union[Quantity, Array], y: Union[Quantity, Array]) -> Union[Quantity, Array]: - return jnp.subtract(x, y) + return decorator -@wrap_math_funcs_match_unit_binary -def nextafter(x: Union[Quantity, Array], y: Union[Quantity, Array]) -> Union[Quantity, Array]: - return jnp.nextafter(x, y) - - -# docs for the functions above -add.__doc__ = ''' +@wrap_math_funcs_match_unit_binary(jnp.add) +def add( + x: Union[Quantity, Array], + y: Union[Quantity, Array] +) -> Union[Quantity, Array]: + ''' Add arguments element-wise. Args: @@ -83,9 +75,16 @@ def nextafter(x: Union[Quantity, Array], y: Union[Quantity, Array]) -> Union[Qua Returns: Union[jax.Array, Quantity]: Quantity if `x` and `y` are Quantities that have the same unit, else an array. -''' + ''' + ... -subtract.__doc__ = ''' + +@wrap_math_funcs_match_unit_binary(jnp.subtract) +def subtract( + x: Union[Quantity, Array], + y: Union[Quantity, Array] +) -> Union[Quantity, Array]: + ''' Subtract arguments element-wise. Args: @@ -94,9 +93,16 @@ def nextafter(x: Union[Quantity, Array], y: Union[Quantity, Array]) -> Union[Qua Returns: Union[jax.Array, Quantity]: Quantity if `x` and `y` are Quantities that have the same unit, else an array. -''' + ''' + ... + -nextafter.__doc__ = ''' +@wrap_math_funcs_match_unit_binary(jnp.nextafter) +def nextafter( + x: Union[Quantity, Array], + y: Union[Quantity, Array] +) -> Union[Quantity, Array]: + ''' Return the next floating-point value after `x1` towards `x2`. Args: @@ -105,4 +111,5 @@ def nextafter(x: Union[Quantity, Array], y: Union[Quantity, Array]) -> Union[Qua Returns: Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. -''' + ''' + ... diff --git a/brainunit/math/_compat_numpy_funcs_remove_unit.py b/brainunit/math/_compat_numpy_funcs_remove_unit.py index afea533..cc649b1 100644 --- a/brainunit/math/_compat_numpy_funcs_remove_unit.py +++ b/brainunit/math/_compat_numpy_funcs_remove_unit.py @@ -35,38 +35,22 @@ # ------------------------------ def wrap_math_funcs_remove_unit_unary(func): @wraps(func) - def f(x, *args, **kwargs): - if isinstance(x, Quantity): - return func(x.value, *args, **kwargs) - else: - return func(x, *args, **kwargs) + def decorator(*args, **kwargs): + def f(x, *args, **kwargs): + if isinstance(x, Quantity): + return func(x.value, *args, **kwargs) + else: + return func(x, *args, **kwargs) - f.__module__ = 'brainunit.math' - return f + f.__module__ = 'brainunit.math' + return f - -@wrap_math_funcs_remove_unit_unary -def signbit(x: Union[Array, Quantity]) -> Array: - return jnp.signbit(x) - - -@wrap_math_funcs_remove_unit_unary -def sign(x: Union[Array, Quantity]) -> Array: - return jnp.sign(x) - - -@wrap_math_funcs_remove_unit_unary -def histogram(x: Union[Array, Quantity]) -> tuple[Array, Array]: - return jnp.histogram(x) - - -@wrap_math_funcs_remove_unit_unary -def bincount(x: Union[Array, Quantity]) -> Array: - return jnp.bincount(x) + return decorator -# docs for the functions above -signbit.__doc__ = ''' +@wrap_math_funcs_remove_unit_unary(jnp.signbit) +def signbit(x: Union[Array, Quantity]) -> Array: + ''' Returns element-wise True where signbit is set (less than zero). Args: @@ -74,9 +58,13 @@ def bincount(x: Union[Array, Quantity]) -> Array: Returns: jax.Array: an array -''' + ''' + ... -sign.__doc__ = ''' + +@wrap_math_funcs_remove_unit_unary(jnp.sign) +def sign(x: Union[Array, Quantity]) -> Array: + ''' Returns the sign of each element in the input array. Args: @@ -84,9 +72,13 @@ def bincount(x: Union[Array, Quantity]) -> Array: Returns: jax.Array: an array -''' + ''' + ... -histogram.__doc__ = ''' + +@wrap_math_funcs_remove_unit_unary(jnp.histogram) +def histogram(x: Union[Array, Quantity]) -> tuple[Array, Array]: + ''' Compute the histogram of a set of data. Args: @@ -94,9 +86,13 @@ def bincount(x: Union[Array, Quantity]) -> Array: Returns: tuple[jax.Array]: Tuple of arrays (hist, bin_edges) -''' + ''' + ... -bincount.__doc__ = ''' + +@wrap_math_funcs_remove_unit_unary(jnp.bincount) +def bincount(x: Union[Array, Quantity]) -> Array: + ''' Count number of occurrences of each value in array of non-negative integers. Args: @@ -104,49 +100,34 @@ def bincount(x: Union[Array, Quantity]) -> Array: Returns: jax.Array: an array -''' + ''' + ... # math funcs remove unit (binary) # ------------------------------- def wrap_math_funcs_remove_unit_binary(func): @wraps(func) - def f(x, y, *args, **kwargs): - if isinstance(x, Quantity): - x_value = x.value - if isinstance(y, Quantity): - y_value = y.value - if isinstance(x, Quantity) or isinstance(y, Quantity): - return func(jnp.array(x_value), jnp.array(y_value), *args, **kwargs) - else: - return func(x, y, *args, **kwargs) + def decorator(*args, **kwargs): + def f(x, y, *args, **kwargs): + if isinstance(x, Quantity): + x_value = x.value + if isinstance(y, Quantity): + y_value = y.value + if isinstance(x, Quantity) or isinstance(y, Quantity): + return func(jnp.array(x_value), jnp.array(y_value), *args, **kwargs) + else: + return func(x, y, *args, **kwargs) - f.__module__ = 'brainunit.math' - return f + f.__module__ = 'brainunit.math' + return f - -@wrap_math_funcs_remove_unit_binary -def corrcoef(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array: - return jnp.corrcoef(x, y) - - -@wrap_math_funcs_remove_unit_binary -def correlate(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array: - return jnp.correlate(x, y) - - -@wrap_math_funcs_remove_unit_binary -def cov(x: Union[Array, Quantity], y: Optional[Union[Array, Quantity]] = None) -> Array: - return jnp.cov(x, y) - - -@wrap_math_funcs_remove_unit_binary -def digitize(x: Union[Array, Quantity], bins: Union[Array, Quantity]) -> Array: - return jnp.digitize(x, bins) + return decorator -# docs for the functions above -corrcoef.__doc__ = ''' +@wrap_math_funcs_remove_unit_binary(jnp.corrcoef) +def corrcoef(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array: + ''' Return Pearson product-moment correlation coefficients. Args: @@ -155,9 +136,13 @@ def digitize(x: Union[Array, Quantity], bins: Union[Array, Quantity]) -> Array: Returns: jax.Array: an array -''' + ''' + ... -correlate.__doc__ = ''' + +@wrap_math_funcs_remove_unit_binary(jnp.correlate) +def correlate(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array: + ''' Cross-correlation of two sequences. Args: @@ -166,9 +151,13 @@ def digitize(x: Union[Array, Quantity], bins: Union[Array, Quantity]) -> Array: Returns: jax.Array: an array -''' + ''' + ... -cov.__doc__ = ''' + +@wrap_math_funcs_remove_unit_binary(jnp.cov) +def cov(x: Union[Array, Quantity], y: Optional[Union[Array, Quantity]] = None) -> Array: + ''' Covariance matrix. Args: @@ -177,9 +166,13 @@ def digitize(x: Union[Array, Quantity], bins: Union[Array, Quantity]) -> Array: Returns: jax.Array: an array -''' + ''' + ... -digitize.__doc__ = ''' + +@wrap_math_funcs_remove_unit_binary(jnp.digitize) +def digitize(x: Union[Array, Quantity], bins: Union[Array, Quantity]) -> Array: + ''' Return the indices of the bins to which each value in input array belongs. Args: @@ -188,4 +181,5 @@ def digitize(x: Union[Array, Quantity], bins: Union[Array, Quantity]) -> Array: Returns: jax.Array: an array -''' + ''' + ... diff --git a/brainunit/math/_compat_numpy_funcs_window.py b/brainunit/math/_compat_numpy_funcs_window.py index 776450f..2857a23 100644 --- a/brainunit/math/_compat_numpy_funcs_window.py +++ b/brainunit/math/_compat_numpy_funcs_window.py @@ -29,36 +29,37 @@ def wrap_window_funcs(func): @wraps(func) - def f(*args, **kwargs): - return func(*args, **kwargs) - - f.__module__ = 'brainunit.math' - return f - - -@wrap_window_funcs + def decorator(*args, **kwargs): + def f(*args, **kwargs): + return func(*args, **kwargs) + + f.__module__ = 'brainunit.math' + return f + return decorator + +@wrap_window_funcs(jnp.bartlett) def bartlett(M: int) -> Array: - return jnp.bartlett(M) + ... -@wrap_window_funcs +@wrap_window_funcs(jnp.blackman) def blackman(M: int) -> Array: - return jnp.blackman(M) + ... -@wrap_window_funcs +@wrap_window_funcs(jnp.hamming) def hamming(M: int) -> Array: - return jnp.hamming(M) + ... -@wrap_window_funcs +@wrap_window_funcs(jnp.hanning) def hanning(M: int) -> Array: - return jnp.hanning(M) + ... -@wrap_window_funcs +@wrap_window_funcs(jnp.kaiser) def kaiser(M: int, beta: float) -> Array: - return jnp.kaiser(M, beta) + ... # docs for functions above diff --git a/brainunit/math/_compat_numpy_linear_algebra.py b/brainunit/math/_compat_numpy_linear_algebra.py index 88f27e9..af7118e 100644 --- a/brainunit/math/_compat_numpy_linear_algebra.py +++ b/brainunit/math/_compat_numpy_linear_algebra.py @@ -30,59 +30,27 @@ ] - - # linear algebra # -------------- -@wrap_math_funcs_change_unit_binary(lambda x, y: x * y) +@wrap_math_funcs_change_unit_binary(jnp.dot, lambda x, y: x * y) def dot(a: Union[Array, Quantity], b: Union[Array, Quantity]) -> Union[Array, Quantity]: - return jnp.dot(a, b) - - -@wrap_math_funcs_change_unit_binary(lambda x, y: x * y) -def vdot(a: Union[Array, Quantity], b: Union[Array, Quantity]) -> Union[Array, Quantity]: - return jnp.vdot(a, b) - - -@wrap_math_funcs_change_unit_binary(lambda x, y: x * y) -def inner(a: Union[Array, Quantity], b: Union[Array, Quantity]) -> Union[Array, Quantity]: - return jnp.inner(a, b) - - -@wrap_math_funcs_change_unit_binary(lambda x, y: x * y) -def outer(a: Union[Array, Quantity], b: Union[Array, Quantity]) -> Union[Array, Quantity]: - return jnp.outer(a, b) - - -@wrap_math_funcs_change_unit_binary(lambda x, y: x * y) -def kron(a: Union[Array, Quantity], b: Union[Array, Quantity]) -> Union[Array, Quantity]: - return jnp.kron(a, b) - - -@wrap_math_funcs_change_unit_binary(lambda x, y: x * y) -def matmul(a: Union[Array, Quantity], b: Union[Array, Quantity]) -> Union[Array, Quantity]: - return jnp.matmul(a, b) - - -@wrap_math_funcs_keep_unit_unary -def trace(a: Union[Array, Quantity]) -> Union[Array, Quantity]: - return jnp.trace(a) - - -# docs for functions above -dot.__doc__ = ''' + ''' Dot product of two arrays or quantities. - + Args: a: array_like, Quantity b: array_like, Quantity - + Returns: Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. -''' + ''' + ... + -vdot.__doc__ = ''' +@wrap_math_funcs_change_unit_binary(jnp.vdot, lambda x, y: x * y) +def vdot(a: Union[Array, Quantity], b: Union[Array, Quantity]) -> Union[Array, Quantity]: + ''' Return the dot product of two vectors or quantities. Args: @@ -91,9 +59,13 @@ def trace(a: Union[Array, Quantity]) -> Union[Array, Quantity]: Returns: Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. -''' + ''' + ... + -inner.__doc__ = ''' +@wrap_math_funcs_change_unit_binary(jnp.inner, lambda x, y: x * y) +def inner(a: Union[Array, Quantity], b: Union[Array, Quantity]) -> Union[Array, Quantity]: + ''' Inner product of two arrays or quantities. Args: @@ -102,9 +74,13 @@ def trace(a: Union[Array, Quantity]) -> Union[Array, Quantity]: Returns: Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. -''' + ''' + ... + -outer.__doc__ = ''' +@wrap_math_funcs_change_unit_binary(jnp.outer, lambda x, y: x * y) +def outer(a: Union[Array, Quantity], b: Union[Array, Quantity]) -> Union[Array, Quantity]: + ''' Compute the outer product of two vectors or quantities. Args: @@ -113,9 +89,13 @@ def trace(a: Union[Array, Quantity]) -> Union[Array, Quantity]: Returns: Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. -''' + ''' + ... -kron.__doc__ = ''' + +@wrap_math_funcs_change_unit_binary(jnp.kron, lambda x, y: x * y) +def kron(a: Union[Array, Quantity], b: Union[Array, Quantity]) -> Union[Array, Quantity]: + ''' Compute the Kronecker product of two arrays or quantities. Args: @@ -124,9 +104,13 @@ def trace(a: Union[Array, Quantity]) -> Union[Array, Quantity]: Returns: Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. -''' + ''' + ... -matmul.__doc__ = ''' + +@wrap_math_funcs_change_unit_binary(jnp.matmul, lambda x, y: x * y) +def matmul(a: Union[Array, Quantity], b: Union[Array, Quantity]) -> Union[Array, Quantity]: + ''' Matrix product of two arrays or quantities. Args: @@ -135,9 +119,13 @@ def trace(a: Union[Array, Quantity]) -> Union[Array, Quantity]: Returns: Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. -''' + ''' + ... -trace.__doc__ = ''' + +@wrap_math_funcs_keep_unit_unary(jnp.trace) +def trace(a: Union[Array, Quantity]) -> Union[Array, Quantity]: + ''' Return the sum of the diagonal elements of a matrix or quantity. Args: @@ -146,4 +134,5 @@ def trace(a: Union[Array, Quantity]) -> Union[Array, Quantity]: Returns: Union[jax.Array, Quantity]: Quantity if the input is a Quantity, else an array. -''' + ''' + ... diff --git a/brainunit/math/_compat_numpy_misc.py b/brainunit/math/_compat_numpy_misc.py index 0deb591..817c9ef 100644 --- a/brainunit/math/_compat_numpy_misc.py +++ b/brainunit/math/_compat_numpy_misc.py @@ -261,36 +261,14 @@ def intersect1d( return result -@wrap_math_funcs_keep_unit_unary -def nan_to_num(x: Union[bst.typing.ArrayLike, Quantity], nan: float = 0.0, posinf: float = jnp.inf, - neginf: float = -jnp.inf) -> Union[jax.Array, Quantity]: - return jnp.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf) - - -@wrap_math_funcs_keep_unit_unary -def rot90(m: Union[bst.typing.ArrayLike, Quantity], k: int = 1, axes: Tuple[int, int] = (0, 1)) -> Union[ - jax.Array, Quantity]: - return jnp.rot90(m, k=k, axes=axes) - - -@wrap_math_funcs_change_unit_binary(lambda x, y: x * y) -def tensordot(a: Union[bst.typing.ArrayLike, Quantity], b: Union[bst.typing.ArrayLike, Quantity], - axes: Union[int, Tuple[int, int]] = 2) -> Union[jax.Array, Quantity]: - return jnp.tensordot(a, b, axes=axes) - - -@_compatible_with_quantity(return_quantity=False) -def nanargmax(a: Union[bst.typing.ArrayLike, Quantity], axis: int = None) -> jax.Array: - return jnp.nanargmax(a, axis=axis) - - -@_compatible_with_quantity(return_quantity=False) -def nanargmin(a: Union[bst.typing.ArrayLike, Quantity], axis: int = None) -> jax.Array: - return jnp.nanargmin(a, axis=axis) - - -# docs for functions above -nan_to_num.__doc__ = ''' +@wrap_math_funcs_keep_unit_unary(jnp.nan_to_num) +def nan_to_num( + x: Union[bst.typing.ArrayLike, Quantity], + nan: float = 0.0, + posinf: float = jnp.inf, + neginf: float = -jnp.inf +) -> Union[jax.Array, Quantity]: + ''' Replace NaN with zero and infinity with large finite numbers (default behaviour) or with the numbers defined by the user using the `nan`, `posinf` and `neginf` arguments. Args: @@ -301,9 +279,18 @@ def nanargmin(a: Union[bst.typing.ArrayLike, Quantity], axis: int = None) -> jax Returns: array with NaNs replaced by zero and infinities replaced by large finite numbers. -''' + ''' + ... -nanargmax.__doc__ = ''' + +@wrap_math_funcs_keep_unit_unary(jnp.rot90) +def rot90( + m: Union[bst.typing.ArrayLike, Quantity], + k: int = 1, + axes: Tuple[int, int] = (0, 1) +) -> Union[ + jax.Array, Quantity]: + ''' Return the index of the maximum value in an array, ignoring NaNs. Args: @@ -314,9 +301,17 @@ def nanargmin(a: Union[bst.typing.ArrayLike, Quantity], axis: int = None) -> jax Returns: index of the maximum value in the array. -''' + ''' + ... -nanargmin.__doc__ = ''' + +@wrap_math_funcs_change_unit_binary(jnp.tensordot, lambda x, y: x * y) +def tensordot( + a: Union[bst.typing.ArrayLike, Quantity], + b: Union[bst.typing.ArrayLike, Quantity], + axes: Union[int, Tuple[int, int]] = 2 +) -> Union[jax.Array, Quantity]: + ''' Return the index of the minimum value in an array, ignoring NaNs. Args: @@ -327,9 +322,16 @@ def nanargmin(a: Union[bst.typing.ArrayLike, Quantity], axis: int = None) -> jax Returns: index of the minimum value in the array. -''' + ''' + ... + -rot90.__doc__ = ''' +@_compatible_with_quantity(jnp.nanargmax, return_quantity=False) +def nanargmax( + a: Union[bst.typing.ArrayLike, Quantity], + axis: int = None +) -> jax.Array: + ''' Rotate an array by 90 degrees in the plane specified by axes. Args: @@ -339,9 +341,16 @@ def nanargmin(a: Union[bst.typing.ArrayLike, Quantity], axis: int = None) -> jax Returns: rotated array. -''' + ''' + ... -tensordot.__doc__ = ''' + +@_compatible_with_quantity(jnp.nanargmin, return_quantity=False) +def nanargmin( + a: Union[bst.typing.ArrayLike, Quantity], + axis: int = None +) -> jax.Array: + ''' Compute tensor dot product along specified axes for arrays. Args: @@ -351,4 +360,5 @@ def nanargmin(a: Union[bst.typing.ArrayLike, Quantity], axis: int = None) -> jax Returns: tensor dot product of the two arrays. -''' + ''' + ... diff --git a/brainunit/math/_utils.py b/brainunit/math/_utils.py index 61242e0..5f68a0d 100644 --- a/brainunit/math/_utils.py +++ b/brainunit/math/_utils.py @@ -32,10 +32,11 @@ def _is_leaf(a): def _compatible_with_quantity( + fun: Callable, return_quantity: bool = True, ): - def decorator(fun: Callable) -> Callable: - @functools.wraps(fun) + @functools.wraps(fun) + def decorator(*args, **kwargs) -> Callable: def new_fun(*args, **kwargs) -> Union[list[Quantity], Quantity, jax.Array]: unit = None if isinstance(args[0], Quantity):