diff --git a/brainunit/math/_compat_numpy.py b/brainunit/math/_compat_numpy.py index 0dbc908..09dda8f 100644 --- a/brainunit/math/_compat_numpy.py +++ b/brainunit/math/_compat_numpy.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - +import functools from collections.abc import Sequence from functools import wraps -from typing import (Callable, Union, Optional) +from typing import (Callable, Union, Optional, Any) import brainstate as bst import jax @@ -162,85 +162,335 @@ def f(*args, unit: Unit = None, **kwargs): ones = wrap_array_creation_function(jnp.ones) zeros = wrap_array_creation_function(jnp.zeros) +# docs for full, eye, identity, tri, empty, ones, zeros + +full.__doc__ = """ +Returns a Quantity of `shape` and `unit`, filled with `fill_value` if `unit` is provided. +else return an array of `shape` filled with `fill_value`. + + Args: + shape: sequence of integers, describing the shape of the output array. + fill_value: the value to fill the new array with. + dtype: the type of the output array, or `None`. If not `None`, `fill_value` + will be cast to `dtype`. + sharding: an optional sharding specification for the resulting array, + note, sharding will currently be ignored in jitted mode, this might change + in the future. + unit: the unit of the output array, or `None`. + + Returns: + out: Quantity if `unit` is provided, else an array. +""" + +eye.__doc__ = """ +Returns a Quantity of `shape` and `unit`, representing an identity matrix if `unit` is provided. +else return an identity matrix of `shape`. + + Args: + n: the number of rows (and columns) in the output array. + k: the index of the diagonal: 0 (the default) refers to the main diagonal, + a positive value refers to an upper diagonal, and a negative value to a + lower diagonal. + dtype: the type of the output array, or `None`. If not `None`, elements + will be cast to `dtype`. + sharding: an optional sharding specification for the resulting array, + note, sharding will currently be ignored in jitted mode, this might change + in the future. + unit: the unit of the output array, or `None`. + + Returns: + out: Quantity if `unit` is provided, else an array. +""" + +identity.__doc__ = """ +Returns a Quantity of `shape` and `unit`, representing an identity matrix if `unit` is provided. +else return an identity matrix of `shape`. + + Args: + n: the number of rows (and columns) in the output array. + dtype: the type of the output array, or `None`. If not `None`, elements + will be cast to `dtype`. + sharding: an optional sharding specification for the resulting array, + note, sharding will currently be ignored in jitted mode, this might change + in the future. + unit: the unit of the output array, or `None`. + + Returns: + out: Quantity if `unit` is provided, else an array. +""" + +tri.__doc__ = """ +Returns a Quantity of `shape` and `unit`, representing a triangular matrix if `unit` is provided. +else return a triangular matrix of `shape`. + + Args: + n: the number of rows in the output array. + m: the number of columns with default being `n`. + k: the index of the diagonal: 0 (the default) refers to the main diagonal, + a positive value refers to an upper diagonal, and a negative value to a + lower diagonal. + dtype: the type of the output array, or `None`. If not `None`, elements + will be cast to `dtype`. + sharding: an optional sharding specification for the resulting array, + note, sharding will currently be ignored in jitted mode, this might change + in the future. + unit: the unit of the output array, or `None`. + + Returns: + out: Quantity if `unit` is provided, else an array. +""" + +# empty +empty.__doc__ = """ +Returns a Quantity of `shape` and `unit`, with uninitialized values if `unit` is provided. +else return an array of `shape` with uninitialized values. + + Args: + shape: sequence of integers, describing the shape of the output array. + dtype: the type of the output array, or `None`. If not `None`, elements + will be of type `dtype`. + sharding: an optional sharding specification for the resulting array, + note, sharding will currently be ignored in jitted mode, this might change + in the future. + unit: the unit of the output array, or `None`. + + Returns: + out: Quantity if `unit` is provided, else an array. +""" + +# ones +ones.__doc__ = """ +Returns a Quantity of `shape` and `unit`, filled with 1 if `unit` is provided. +else return an array of `shape` filled with 1. + + Args: + shape: sequence of integers, describing the shape of the output array. + dtype: the type of the output array, or `None`. If not `None`, elements + will be cast to `dtype`. + sharding: an optional sharding specification for the resulting array, + note, sharding will currently be ignored in jitted mode, this might change + in the future. + unit: the unit of the output array, or `None`. + + Returns: + out: Quantity if `unit` is provided, else an array. +""" + +# zeros +zeros.__doc__ = """ +Returns a Quantity of `shape` and `unit`, filled with 0 if `unit` is provided. +else return an array of `shape` filled with 0. + + Args: + shape: sequence of integers, describing the shape of the output array. + dtype: the type of the output array, or `None`. If not `None`, elements + will be cast to `dtype`. + sharding: an optional sharding specification for the resulting array, + note, sharding will currently be ignored in jitted mode, this might change + in the future. + unit: the unit of the output array, or `None`. + + Returns: + out: Quantity if `unit` is provided, else an array. +""" + @set_module_as('brainunit.math') -def full_like(a, fill_value, dtype=None, shape=None): - if isinstance(a, Quantity) and isinstance(fill_value, Quantity): - fail_for_dimension_mismatch(a, fill_value, error_message='Units do not match for full_like operation.') - return Quantity(jnp.full_like(a.value, fill_value.value, dtype=dtype, shape=shape), unit=a.unit) - elif isinstance(a, (jax.Array, np.ndarray)) and not isinstance(fill_value, Quantity): - return jnp.full_like(a, fill_value, dtype=dtype, shape=shape) +def full_like(a: Union[Quantity, jax.Array, np.ndarray], + fill_value: Union[jax.Array, np.ndarray], + unit: Unit = None, + dtype: Optional[bst.typing.DTypeLike] = None, + shape: Any = None) -> Union[Quantity, jax.Array]: + ''' + Return a Quantity of `a` and `unit`, filled with `fill_value` if `unit` is provided. + else return an array of `a` filled with `fill_value`. + + Args: + a: array_like, Quantity, shape, or dtype + fill_value: scalar or array_like + unit: Unit, optional + dtype: data-type, optional + shape: sequence of ints, optional + + Returns: + out: Quantity if `unit` is provided, else an array. + ''' + if unit is not None: + assert isinstance(unit, Unit) + if isinstance(a, Quantity): + return jnp.full_like(a.value, fill_value, dtype=dtype, shape=shape) * unit + else: + return jnp.full_like(a, fill_value, dtype=dtype, shape=shape) * unit else: - raise ValueError(f'Unsupported types : {type(a)} abd {type(fill_value)} for full_like') + return jnp.full_like(a, fill_value, dtype=dtype, shape=shape) @set_module_as('brainunit.math') -def diag(a, k=0): - if isinstance(a, Quantity): - return Quantity(jnp.diag(a.value, k=k), unit=a.unit) - elif isinstance(a, (jax.Array, np.ndarray)): - return jnp.diag(a, k=k) +def diag(a: Union[Quantity, jax.Array, np.ndarray], + k: int = 0, + unit: Unit = None) -> Union[Quantity, jax.Array]: + ''' + Extract a diagonal or construct a diagonal array. + + Args: + a: array_like, Quantity + k: int, optional + unit: Unit, optional + + Returns: + out: Quantity if `unit` is provided, else an array. + ''' + if unit is not None: + assert isinstance(unit, Unit) + if isinstance(a, Quantity): + return jnp.diag(a.value, k=k) * unit + else: + return jnp.diag(a, k=k) * unit else: - raise ValueError(f'Unsupported type: {type(a)} for diag') + return jnp.diag(a, k=k) @set_module_as('brainunit.math') -def tril(a, k=0): - if isinstance(a, Quantity): - return Quantity(jnp.tril(a.value, k=k), unit=a.unit) - elif isinstance(a, (jax.Array, np.ndarray)): - return jnp.tril(a, k=k) +def tril(a: Union[Quantity, jax.Array, np.ndarray], + k: int = 0, + unit: Unit = None) -> Union[Quantity, jax.Array]: + ''' + Lower triangle of an array. + + Args: + a: array_like, Quantity + k: int, optional + unit: Unit, optional + + Returns: + out: Quantity if `unit` is provided, else an array. + ''' + if unit is not None: + assert isinstance(unit, Unit) + if isinstance(a, Quantity): + return jnp.tril(a.value, k=k) * unit + else: + return jnp.tril(a, k=k) * unit else: - raise ValueError(f'Unsupported type: {type(a)} for tril') + return jnp.tril(a, k=k) @set_module_as('brainunit.math') -def triu(a, k=0): - if isinstance(a, Quantity): - return Quantity(jnp.triu(a.value, k=k), unit=a.unit) - elif isinstance(a, (jax.Array, np.ndarray)): - return jnp.triu(a, k=k) +def triu(a: Union[Quantity, jax.Array, np.ndarray], + k: int = 0, + unit: Unit = None) -> Union[Quantity, jax.Array]: + ''' + Upper triangle of an array. + + Args: + a: array_like, Quantity + k: int, optional + unit: Unit, optional + + Returns: + out: Quantity if `unit` is provided, else an array. + ''' + if unit is not None: + assert isinstance(unit, Unit) + if isinstance(a, Quantity): + return jnp.triu(a.value, k=k) * unit + else: + return jnp.triu(a, k=k) * unit else: - raise ValueError(f'Unsupported type: {type(a)} for triu') + return jnp.triu(a, k=k) @set_module_as('brainunit.math') -def empty_like(a, dtype=None, shape=None): - if isinstance(a, Quantity): - return Quantity(jnp.empty_like(a.value, dtype=dtype, shape=shape), unit=a.unit) - elif isinstance(a, (jax.Array, np.ndarray)): - return jnp.empty_like(a, dtype=dtype, shape=shape) +def empty_like(a: Union[Quantity, jax.Array, np.ndarray], + dtype: Optional[bst.typing.DTypeLike] = None, + shape: Any = None, + unit: Unit = None) -> Union[Quantity, jax.Array]: + ''' + Return a Quantity of `a` and `unit`, with uninitialized values if `unit` is provided. + else return an array of `a` with uninitialized values. + + Args: + a: array_like, Quantity, shape, or dtype + dtype: data-type, optional + shape: sequence of ints, optional + unit: Unit, optional + + Returns: + out: Quantity if `unit` is provided, else an array. + ''' + if unit is not None: + assert isinstance(unit, Unit) + if isinstance(a, Quantity): + return jnp.empty_like(a.value, dtype=dtype, shape=shape) * unit + else: + return jnp.empty_like(a, dtype=dtype, shape=shape) * unit else: - raise ValueError(f'Unsupported type: {type(a)} for empty_like') + return jnp.empty_like(a, dtype=dtype, shape=shape) @set_module_as('brainunit.math') -def ones_like(a, dtype=None, shape=None): - if isinstance(a, Quantity): - return Quantity(jnp.ones_like(a.value, dtype=dtype, shape=shape), unit=a.unit) - elif isinstance(a, (jax.Array, np.ndarray)): - return jnp.ones_like(a, dtype=dtype, shape=shape) +def ones_like(a: Union[Quantity, jax.Array, np.ndarray], + dtype: Optional[bst.typing.DTypeLike] = None, + shape: Any = None, + unit: Unit = None) -> Union[Quantity, jax.Array]: + ''' + Return a Quantity of `a` and `unit`, filled with 1 if `unit` is provided. + else return an array of `a` filled with 1. + + Args: + a: array_like, Quantity, shape, or dtype + dtype: data-type, optional + shape: sequence of ints, optional + unit: Unit, optional + + Returns: + out: Quantity if `unit` is provided, else an array. + ''' + if unit is not None: + assert isinstance(unit, Unit) + if isinstance(a, Quantity): + return jnp.ones_like(a.value, dtype=dtype, shape=shape) * unit + else: + return jnp.ones_like(a, dtype=dtype, shape=shape) * unit else: - raise ValueError(f'Unsupported type: {type(a)} for ones_like') + return jnp.ones_like(a, dtype=dtype, shape=shape) @set_module_as('brainunit.math') -def zeros_like(a, dtype=None, shape=None): - if isinstance(a, Quantity): - return Quantity(jnp.zeros_like(a.value, dtype=dtype, shape=shape), unit=a.unit) - elif isinstance(a, (jax.Array, np.ndarray)): - return jnp.zeros_like(a, dtype=dtype, shape=shape) +def zeros_like(a: Union[Quantity, jax.Array, np.ndarray], + dtype: Optional[bst.typing.DTypeLike] = None, + shape: Any = None, + unit: Unit = None) -> Union[Quantity, jax.Array]: + ''' + Return a Quantity of `a` and `unit`, filled with 0 if `unit` is provided. + else return an array of `a` filled with 0. + + Args: + a: array_like, Quantity, shape, or dtype + dtype: data-type, optional + shape: sequence of ints, optional + unit: Unit, optional + + Returns: + out: Quantity if `unit` is provided, else an array. + ''' + if unit is not None: + assert isinstance(unit, Unit) + if isinstance(a, Quantity): + return jnp.zeros_like(a.value, dtype=dtype, shape=shape) * unit + else: + return jnp.zeros_like(a, dtype=dtype, shape=shape) * unit else: - raise ValueError(f'Unsupported type: {type(a)} for zeros_like') + return jnp.zeros_like(a, dtype=dtype, shape=shape) @set_module_as('brainunit.math') def asarray( - a, + a: Union[Quantity, jax.Array, np.ndarray, Sequence[Quantity]], dtype: Optional[bst.typing.DTypeLike] = None, order: Optional[str] = None, unit: Optional[Unit] = None, -): +) -> Union[Quantity, jax.Array]: from builtins import all as origin_all from builtins import any as origin_any if isinstance(a, Quantity): @@ -265,6 +515,19 @@ def asarray( @set_module_as('brainunit.math') def arange(*args, **kwargs): + ''' + Return a Quantity of `arange` and `unit`, with uninitialized values if `unit` is provided. + + Args: + start: number, Quantity, optional + stop: number, Quantity, optional + step: number, optional + dtype: dtype, optional + unit: Unit, optional + + Returns: + out: Quantity if start and stop are Quantities that have the same unit, else an array. + ''' # arange has a bit of a complicated argument structure unfortunately # we leave the actual checking of the number of arguments to numpy, though @@ -343,7 +606,26 @@ def arange(*args, **kwargs): @set_module_as('brainunit.math') -def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None): +def linspace(start: Union[Quantity, jax.Array, np.ndarray], + stop: Union[Quantity, jax.Array, np.ndarray], + num: int = 50, + endpoint: bool = True, + retstep: bool = False, + dtype: bst.typing.DTypeLike = None) -> Union[Quantity, jax.Array]: + ''' + Return a Quantity of `linspace` and `unit`, with uninitialized values if `unit` is provided. + + Args: + start: number, Quantity + stop: number, Quantity + num: int, optional + endpoint: bool, optional + retstep: bool, optional + dtype: dtype, optional + + Returns: + out: Quantity if start and stop are Quantities that have the same unit, else an array. + ''' fail_for_dimension_mismatch( start, stop, @@ -360,7 +642,26 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None): @set_module_as('brainunit.math') -def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None): +def logspace(start: Union[Quantity, jax.Array, np.ndarray], + stop: Union[Quantity, jax.Array, np.ndarray], + num: int = 50, + endpoint: bool = True, + base: float = 10.0, + dtype: bst.typing.DTypeLike = None): + ''' + Return a Quantity of `logspace` and `unit`, with uninitialized values if `unit` is provided. + + Args: + start: number, Quantity + stop: number, Quantity + num: int, optional + endpoint: bool, optional + base: float, optional + dtype: dtype, optional + + Returns: + out: Quantity if start and stop are Quantities that have the same unit, else an array. + ''' fail_for_dimension_mismatch( start, stop, @@ -377,7 +678,22 @@ def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None): @set_module_as('brainunit.math') -def fill_diagonal(a, val, wrap=False, inplace=True): +def fill_diagonal(a: Union[Quantity, jax.Array, np.ndarray], + val: Union[Quantity, jax.Array, np.ndarray], + wrap: bool = False, + inplace: bool = True) -> Union[Quantity, jax.Array]: + ''' + Fill the main diagonal of the given array of `a` with `val`. + + Args: + a: array_like, Quantity + val: scalar, Quantity + wrap: bool, optional + inplace: bool, optional + + Returns: + out: Quantity if `a` and `val` are Quantities that have the same unit, else an array. + ''' if isinstance(a, Quantity) and isinstance(val, Quantity): fail_for_dimension_mismatch(a, val) return Quantity(jnp.fill_diagonal(a.value, val.value, wrap=wrap, inplace=inplace), unit=a.unit) @@ -390,7 +706,20 @@ def fill_diagonal(a, val, wrap=False, inplace=True): @set_module_as('brainunit.math') -def array_split(ary, indices_or_sections, axis=0): +def array_split(ary: Union[Quantity, jax.Array, np.ndarray], + indices_or_sections: Union[int, jax.Array, np.ndarray], + axis: int = 0) -> Union[Quantity, jax.Array]: + ''' + Split an array into multiple sub-arrays. + + Args: + ary: array_like, Quantity + indices_or_sections: int, array_like + axis: int, optional + + Returns: + out: Quantity if `ary` is a Quantity, else an array. + ''' if isinstance(ary, Quantity): return Quantity(jnp.array_split(ary.value, indices_or_sections, axis), unit=ary.unit) elif isinstance(ary, (jax.Array, np.ndarray)): @@ -400,7 +729,22 @@ def array_split(ary, indices_or_sections, axis=0): @set_module_as('brainunit.math') -def meshgrid(*xi, copy=True, sparse=False, indexing='xy'): +def meshgrid(*xi: Union[Quantity, jax.Array, np.ndarray], + copy: bool = True, + sparse: bool = False, + indexing: str = 'xy'): + ''' + Return coordinate matrices from coordinate vectors. + + Args: + xi: array_like, Quantity + copy: bool, optional + sparse: bool, optional + indexing: str, optional + + Returns: + out: Quantity if `xi` are Quantities that have the same unit, else an array. + ''' from builtins import all as origin_all if origin_all(isinstance(x, Quantity) for x in xi): fail_for_dimension_mismatch(*xi) @@ -412,7 +756,20 @@ def meshgrid(*xi, copy=True, sparse=False, indexing='xy'): @set_module_as('brainunit.math') -def vander(x, N=None, increasing=False): +def vander(x: Union[Quantity, jax.Array, np.ndarray], + N: bool=None, + increasing: bool=False) -> Union[Quantity, jax.Array]: + ''' + Generate a Vandermonde matrix. + + Args: + x: array_like, Quantity + N: int, optional + increasing: bool, optional + + Returns: + out: Quantity if `x` is a Quantity, else an array. + ''' if isinstance(x, Quantity): return Quantity(jnp.vander(x.value, N=N, increasing=increasing), unit=x.unit) elif isinstance(x, (jax.Array, np.ndarray)):