diff --git a/brainunit/__init__.py b/brainunit/__init__.py index dbd9a17..2b2cc55 100644 --- a/brainunit/__init__.py +++ b/brainunit/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== +__version__ = "0.0.1" from ._base import * from ._base import __all__ as _base_all diff --git a/brainunit/_unit_test.py b/brainunit/_unit_test.py index 11be699..9309ceb 100644 --- a/brainunit/_unit_test.py +++ b/brainunit/_unit_test.py @@ -881,7 +881,7 @@ def test_unit_discarding_functions(): """ Test functions that discard units. """ - from brainunit import ones_like, zeros_like + from brainunit.math import ones_like, zeros_like values = [3 * mV, np.array([1, 2]) * mV, np.arange(12).reshape(3, 4) * mV] for value in values: @@ -897,7 +897,7 @@ def test_unitsafe_functions(): """ Test the unitsafe functions wrapping their numpy counterparts. """ - from braincore.math import ( + from brainunit.math import ( arccos, arccosh, arcsin, @@ -966,7 +966,7 @@ def test_special_case_numpy_functions(): """ Test a couple of functions/methods that need special treatment. """ - from braincore.math import diagonal, dot, ravel, trace, where + from brainunit.math import diagonal, dot, ravel, trace, where quadratic_matrix = np.reshape(np.arange(9), (3, 3)) * mV diff --git a/brainunit/math/__init__.py b/brainunit/math/__init__.py new file mode 100644 index 0000000..5b1a673 --- /dev/null +++ b/brainunit/math/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from ._compat_numpy import * +from ._compat_numpy import __all__ as _compat_numpy_all + +__all__ = _compat_numpy_all diff --git a/brainunit/math/_compat_numpy.py b/brainunit/math/_compat_numpy.py new file mode 100644 index 0000000..5291ef9 --- /dev/null +++ b/brainunit/math/_compat_numpy.py @@ -0,0 +1,1196 @@ +from collections.abc import Sequence +from functools import wraps +from typing import (Callable, Union) + +import jax +import jax.numpy as jnp +import numpy as np +import opt_einsum +from jax import lax +from jax._src.numpy.lax_numpy import _einsum + + +from braincore._common import set_module_as +from brainunit.math._utils import _compatible_with_quantity +from brainunit._base import ( + DIMENSIONLESS, + Quantity, + fail_for_dimension_mismatch, + is_unitless, + _return_check_unitless, + get_unit, +) + +__all__ = [ + # array creation + 'full', 'full_like', 'eye', 'identity', 'diag', 'tri', 'tril', 'triu', + 'empty', 'empty_like', 'ones', 'ones_like', 'zeros', 'zeros_like', + 'array', 'asarray', 'arange', 'linspace', 'logspace', 'fill_diagonal', + 'array_split', 'meshgrid', 'vander', + + # getting attribute funcs + 'ndim', 'isreal', 'isscalar', 'isfinite', 'isinf', + 'isnan', 'shape', 'size', + + # math funcs keep unit (unary) + 'real', 'imag', 'conj', 'conjugate', 'negative', 'positive', + 'abs', 'round', 'around', 'round_', 'rint', + 'floor', 'ceil', 'trunc', 'fix', 'sum', 'nancumsum', 'nansum', + 'cumsum', 'ediff1d', 'absolute', 'fabs', 'median', + 'nanmin', 'nanmax', 'ptp', 'average', 'mean', 'std', + 'nanmedian', 'nanmean', 'nanstd', 'diff', 'modf', + + # math funcs keep unit (binary) + 'fmod', 'mod', 'copysign', 'heaviside', + 'maximum', 'minimum', 'fmax', 'fmin', 'lcm', 'gcd', + + # math funcs keep unit (n-ary) + 'interp', 'clip', + + # math funcs match unit (binary) + 'add', 'subtract', 'nextafter', + + # math funcs change unit (unary) + 'reciprocal', 'prod', 'product', 'nancumprod', 'nanprod', 'cumprod', + 'cumproduct', 'var', 'nanvar', 'cbrt', 'square', 'frexp', 'sqrt', + + # math funcs change unit (binary) + 'multiply', 'divide', 'power', 'cross', 'ldexp', + 'true_divide', 'floor_divide', 'float_power', + 'divmod', 'remainder', 'convolve', + + # math funcs only accept unitless (unary) + 'exp', 'exp2', 'expm1', 'log', 'log10', 'log1p', 'log2', + 'arccos', 'arccosh', 'arcsin', 'arcsinh', 'arctan', + 'arctanh', 'cos', 'cosh', 'sin', 'sinc', 'sinh', 'tan', + 'tanh', 'deg2rad', 'rad2deg', 'degrees', 'radians', 'angle', + 'percentile', 'nanpercentile', 'quantile', 'nanquantile', + + # math funcs only accept unitless (binary) + 'hypot', 'arctan2', 'logaddexp', 'logaddexp2', + + # math funcs remove unit (unary) + 'signbit', 'sign', 'histogram', 'bincount', + + # math funcs remove unit (binary) + 'corrcoef', 'correlate', 'cov', 'digitize', + + # array manipulation + 'reshape', 'moveaxis', 'transpose', 'swapaxes', 'row_stack', + 'concatenate', 'stack', 'vstack', 'hstack', 'dstack', 'column_stack', + 'split', 'dsplit', 'hsplit', 'vsplit', 'tile', 'repeat', 'unique', + 'append', 'flip', 'fliplr', 'flipud', 'roll', 'atleast_1d', 'atleast_2d', + 'atleast_3d', 'expand_dims', 'squeeze', 'sort', 'argsort', 'argmax', 'argmin', + 'argwhere', 'nonzero', 'flatnonzero', 'searchsorted', 'extract', + 'count_nonzero', 'max', 'min', 'amax', 'amin', 'block', 'compress', + 'diagflat', 'diagonal', 'choose', 'ravel', + + # Elementwise bit operations (unary) + 'bitwise_not', 'invert', 'left_shift', 'right_shift', + + # Elementwise bit operations (binary) + 'bitwise_and', 'bitwise_or', 'bitwise_xor', + + # logic funcs (unary) + 'all', 'any', 'logical_not', + + # logic funcs (binary) + 'equal', 'not_equal', 'greater', 'greater_equal', 'less', 'less_equal', + 'array_equal', 'isclose', 'allclose', 'logical_and', + 'logical_or', 'logical_xor', "alltrue", 'sometrue', + + # indexing funcs + 'nonzero', 'where', 'tril_indices', 'tril_indices_from', 'triu_indices', + 'triu_indices_from', 'take', 'select', + + # window funcs + 'bartlett', 'blackman', 'hamming', 'hanning', 'kaiser', + + # constants + 'e', 'pi', 'inf', + + # linear algebra + 'dot', 'vdot', 'inner', 'outer', 'kron', 'matmul', 'trace', + + # data types + 'dtype', 'finfo', 'iinfo', + + # more + 'broadcast_arrays', 'broadcast_shapes', + 'einsum', 'gradient', 'intersect1d', 'nan_to_num', 'nanargmax', 'nanargmin', + 'rot90', 'tensordot', + +] + + +# array creation +# -------------- + +def wrap_array_creation_function(func): + def f(*args, **kwargs): + return Quantity(func(*args, **kwargs)) + + f.__module__ = 'braincore.math' + return f + + +# array creation +# -------------- + +full = wrap_array_creation_function(jnp.full) +full_like = wrap_array_creation_function(jnp.full_like) +eye = wrap_array_creation_function(jnp.eye) +identity = wrap_array_creation_function(jnp.identity) +diag = wrap_array_creation_function(jnp.diag) +tri = wrap_array_creation_function(jnp.tri) +tril = wrap_array_creation_function(jnp.tril) +triu = wrap_array_creation_function(jnp.triu) +empty = wrap_array_creation_function(jnp.empty) +empty_like = wrap_array_creation_function(jnp.empty_like) +ones = wrap_array_creation_function(jnp.ones) +ones_like = wrap_array_creation_function(jnp.ones_like) +zeros = wrap_array_creation_function(jnp.zeros) +zeros_like = wrap_array_creation_function(jnp.zeros_like) +array = wrap_array_creation_function(jnp.array) +asarray = wrap_array_creation_function(jnp.asarray) +arange = wrap_array_creation_function(jnp.arange) +linspace = wrap_array_creation_function(jnp.linspace) +logspace = wrap_array_creation_function(jnp.logspace) +fill_diagonal = wrap_array_creation_function(jnp.fill_diagonal) +array_split = wrap_array_creation_function(jnp.array_split) +meshgrid = wrap_array_creation_function(jnp.meshgrid) +vander = wrap_array_creation_function(jnp.vander) + + +# getting attribute funcs +# ----------------------- + +@set_module_as('braincore.math') +def ndim(a): + if isinstance(a, Quantity): + return a.ndim + else: + return jnp.ndim(a) + + +@set_module_as('braincore.math') +def isreal(a): + if isinstance(a, Quantity): + return a.isreal + else: + return jnp.isreal(a) + + +@set_module_as('braincore.math') +def isscalar(a): + if isinstance(a, Quantity): + return a.isscalar + else: + return jnp.isscalar(a) + + +@set_module_as('braincore.math') +def isfinite(a): + if isinstance(a, Quantity): + return a.isfinite + else: + return jnp.isfinite(a) + + +@set_module_as('braincore.math') +def isinf(a): + if isinstance(a, Quantity): + return a.isinf + else: + return jnp.isinf(a) + + +@set_module_as('braincore.math') +def isnan(a): + if isinstance(a, Quantity): + return a.isnan + else: + return jnp.isnan(a) + + +@set_module_as('braincore.math') +def shape(a): + """ + Return the shape of an array. + + Parameters + ---------- + a : array_like + Input array. + + Returns + ------- + shape : tuple of ints + The elements of the shape tuple give the lengths of the + corresponding array dimensions. + + See Also + -------- + len : ``len(a)`` is equivalent to ``np.shape(a)[0]`` for N-D arrays with + ``N>=1``. + ndarray.shape : Equivalent array method. + + Examples + -------- + >>> braincore.math.shape(braincore.math.eye(3)) + (3, 3) + >>> braincore.math.shape([[1, 3]]) + (1, 2) + >>> braincore.math.shape([0]) + (1,) + >>> braincore.math.shape(0) + () + + """ + if isinstance(a, (Quantity, jax.Array, np.ndarray)): + return a.shape + else: + return np.shape(a) + + +@set_module_as('braincore.math') +def size(a, axis=None): + """ + Return the number of elements along a given axis. + + Parameters + ---------- + a : array_like + Input data. + axis : int, optional + Axis along which the elements are counted. By default, give + the total number of elements. + + Returns + ------- + element_count : int + Number of elements along the specified axis. + + See Also + -------- + shape : dimensions of array + Array.shape : dimensions of array + Array.size : number of elements in array + + Examples + -------- + >>> a = Quantity([[1,2,3], [4,5,6]]) + >>> braincore.math.size(a) + 6 + >>> braincore.math.size(a, 1) + 3 + >>> braincore.math.size(a, 0) + 2 + """ + if isinstance(a, (Quantity, jax.Array, np.ndarray)): + if axis is None: + return a.size + else: + return a.shape[axis] + else: + return np.size(a, axis=axis) + + +# math funcs keep unit (unary) +# ---------------------------- + +def wrap_math_funcs_keep_unit_unary(func): + def f(x, *args, **kwargs): + if isinstance(x, Quantity): + return Quantity(func(x.value, *args, **kwargs), unit=x.unit) + elif isinstance(x, (jax.Array, np.ndarray)): + return func(x, *args, **kwargs) + else: + raise ValueError(f'Unsupported type: {type(x)} for {func.__name__}') + + f.__module__ = 'braincore.math' + return f + + +real = wrap_math_funcs_keep_unit_unary(jnp.real) +imag = wrap_math_funcs_keep_unit_unary(jnp.imag) +conj = wrap_math_funcs_keep_unit_unary(jnp.conj) +conjugate = wrap_math_funcs_keep_unit_unary(jnp.conjugate) +negative = wrap_math_funcs_keep_unit_unary(jnp.negative) +positive = wrap_math_funcs_keep_unit_unary(jnp.positive) +abs = wrap_math_funcs_keep_unit_unary(jnp.abs) +round_ = wrap_math_funcs_keep_unit_unary(jnp.round) +around = wrap_math_funcs_keep_unit_unary(jnp.around) +round = wrap_math_funcs_keep_unit_unary(jnp.round) +rint = wrap_math_funcs_keep_unit_unary(jnp.rint) +floor = wrap_math_funcs_keep_unit_unary(jnp.floor) +ceil = wrap_math_funcs_keep_unit_unary(jnp.ceil) +trunc = wrap_math_funcs_keep_unit_unary(jnp.trunc) +fix = wrap_math_funcs_keep_unit_unary(jnp.fix) +sum = wrap_math_funcs_keep_unit_unary(jnp.sum) +nancumsum = wrap_math_funcs_keep_unit_unary(jnp.nancumsum) +nansum = wrap_math_funcs_keep_unit_unary(jnp.nansum) +cumsum = wrap_math_funcs_keep_unit_unary(jnp.cumsum) +ediff1d = wrap_math_funcs_keep_unit_unary(jnp.ediff1d) +absolute = wrap_math_funcs_keep_unit_unary(jnp.absolute) +fabs = wrap_math_funcs_keep_unit_unary(jnp.fabs) +median = wrap_math_funcs_keep_unit_unary(jnp.median) +nanmin = wrap_math_funcs_keep_unit_unary(jnp.nanmin) +nanmax = wrap_math_funcs_keep_unit_unary(jnp.nanmax) +ptp = wrap_math_funcs_keep_unit_unary(jnp.ptp) +average = wrap_math_funcs_keep_unit_unary(jnp.average) +mean = wrap_math_funcs_keep_unit_unary(jnp.mean) +std = wrap_math_funcs_keep_unit_unary(jnp.std) +nanmedian = wrap_math_funcs_keep_unit_unary(jnp.nanmedian) +nanmean = wrap_math_funcs_keep_unit_unary(jnp.nanmean) +nanstd = wrap_math_funcs_keep_unit_unary(jnp.nanstd) +diff = wrap_math_funcs_keep_unit_unary(jnp.diff) +modf = wrap_math_funcs_keep_unit_unary(jnp.modf) + + +# math funcs keep unit (binary) +# ----------------------------- + +def wrap_math_funcs_keep_unit_binary(func): + def f(x1, x2, *args, **kwargs): + if isinstance(x1, Quantity) and isinstance(x2, Quantity): + return Quantity(func(x1.value, x2.value, *args, **kwargs), unit=x1.unit) + elif isinstance(x1, (jax.Array, np.ndarray)) and isinstance(x2, (jax.Array, np.ndarray)): + return func(x1, x2, *args, **kwargs) + else: + raise ValueError(f'Unsupported type: {type(x1)} and {type(x2)} for {func.__name__}') + + f.__module__ = 'braincore.math' + return f + + +fmod = wrap_math_funcs_keep_unit_binary(jnp.fmod) +mod = wrap_math_funcs_keep_unit_binary(jnp.mod) +copysign = wrap_math_funcs_keep_unit_binary(jnp.copysign) +heaviside = wrap_math_funcs_keep_unit_binary(jnp.heaviside) +maximum = wrap_math_funcs_keep_unit_binary(jnp.maximum) +minimum = wrap_math_funcs_keep_unit_binary(jnp.minimum) +fmax = wrap_math_funcs_keep_unit_binary(jnp.fmax) +fmin = wrap_math_funcs_keep_unit_binary(jnp.fmin) +lcm = wrap_math_funcs_keep_unit_binary(jnp.lcm) +gcd = wrap_math_funcs_keep_unit_binary(jnp.gcd) + + +# math funcs keep unit (n-ary) +# ---------------------------- +@set_module_as('braincore.math') +def interp(x, xp, fp, left=None, right=None, period=None): + unit = None + if isinstance(x, Quantity) or isinstance(xp, Quantity) or isinstance(fp, Quantity): + unit = x.unit if isinstance(x, Quantity) else xp.unit if isinstance(xp, Quantity) else fp.unit + if isinstance(x, Quantity): + x_value = x.value + else: + x_value = x + if isinstance(xp, Quantity): + xp_value = xp.value + else: + xp_value = xp + if isinstance(fp, Quantity): + fp_value = fp.value + else: + fp_value = fp + result = jnp.interp(x_value, xp_value, fp_value, left=left, right=right, period=period) + if unit is not None: + return Quantity(result, unit=unit) + else: + return result + + +@set_module_as('braincore.math') +def clip(a, a_min, a_max): + unit = None + if isinstance(a, Quantity) or isinstance(a_min, Quantity) or isinstance(a_max, Quantity): + unit = a.unit if isinstance(a, Quantity) else a_min.unit if isinstance(a_min, Quantity) else a_max.unit + if isinstance(a, Quantity): + a_value = a.value + else: + a_value = a + if isinstance(a_min, Quantity): + a_min_value = a_min.value + else: + a_min_value = a_min + if isinstance(a_max, Quantity): + a_max_value = a_max.value + else: + a_max_value = a_max + result = jnp.clip(a_value, a_min_value, a_max_value) + if unit is not None: + return Quantity(result, unit=unit) + else: + return result + + +# math funcs match unit (binary) +# ------------------------------ + +def wrap_math_funcs_match_unit_binary(func): + def f(x, y, *args, **kwargs): + if isinstance(x, Quantity) and isinstance(y, Quantity): + fail_for_dimension_mismatch(x, y) + return Quantity(func(x.value, y.value, *args, **kwargs), unit=x.unit) + elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): + return func(x, y, *args, **kwargs) + elif isinstance(x, Quantity): + if x.is_unitless: + return Quantity(func(x.value, y, *args, **kwargs), unit=x.unit) + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') + elif isinstance(y, Quantity): + if y.is_unitless: + return Quantity(func(x, y.value, *args, **kwargs), unit=y.unit) + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') + + f.__module__ = 'braincore.math' + return f + + +add = wrap_math_funcs_match_unit_binary(jnp.add) +subtract = wrap_math_funcs_match_unit_binary(jnp.subtract) +nextafter = wrap_math_funcs_match_unit_binary(jnp.nextafter) + + +# math funcs change unit (unary) +# ------------------------------ + +def wrap_math_funcs_change_unit_unary(func, change_unit_func): + def f(x, *args, **kwargs): + if isinstance(x, Quantity): + return _return_check_unitless(Quantity(func(x.value, *args, **kwargs), unit=change_unit_func(x.unit))) + elif isinstance(x, (jax.Array, np.ndarray)): + return func(x, *args, **kwargs) + else: + raise ValueError(f'Unsupported type: {type(x)} for {func.__name__}') + + f.__module__ = 'braincore.math' + return f + + +reciprocal = wrap_math_funcs_change_unit_unary(jnp.reciprocal, lambda x: x ** -1) + + +@set_module_as('braincore.math') +def prod(x, axis=None, dtype=None, out=None, keepdims=False, initial=None): + if isinstance(x, Quantity): + return x.prod(axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial) + else: + return jnp.prod(x, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial) + + +@set_module_as('braincore.math') +def nanprod(x, axis=None, dtype=None, out=None, keepdims=False, initial=None): + if isinstance(x, Quantity): + return x.nanprod(axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial) + else: + return jnp.nanprod(x, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial) + + +product = prod + + +@set_module_as('braincore.math') +def cumprod(x, axis=None, dtype=None, out=None): + if isinstance(x, Quantity): + return x.cumprod(axis=axis, dtype=dtype, out=out) + else: + return jnp.cumprod(x, axis=axis, dtype=dtype, out=out) + + +@set_module_as('braincore.math') +def nancumprod(x, axis=None, dtype=None, out=None): + if isinstance(x, Quantity): + return x.nancumprod(axis=axis, dtype=dtype, out=out) + else: + return jnp.nancumprod(x, axis=axis, dtype=dtype, out=out) + + +cumproduct = cumprod + +var = wrap_math_funcs_change_unit_unary(jnp.var, lambda x: x ** 2) +nanvar = wrap_math_funcs_change_unit_unary(jnp.nanvar, lambda x: x ** 2) +frexp = wrap_math_funcs_change_unit_unary(jnp.frexp, lambda x, y: x * 2 ** y) +sqrt = wrap_math_funcs_change_unit_unary(jnp.sqrt, lambda x: x ** 0.5) +cbrt = wrap_math_funcs_change_unit_unary(jnp.cbrt, lambda x: x ** (1 / 3)) +square = wrap_math_funcs_change_unit_unary(jnp.square, lambda x: x ** 2) + + +# math funcs change unit (binary) +# ------------------------------- + +def wrap_math_funcs_change_unit_binary(func, change_unit_func): + def f(x, y, *args, **kwargs): + if isinstance(x, Quantity) and isinstance(y, Quantity): + return _return_check_unitless( + Quantity(func(x.value, y.value, *args, **kwargs), unit=change_unit_func(x.unit, y.unit)) + ) + elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): + return func(x, y, *args, **kwargs) + elif isinstance(x, Quantity): + return _return_check_unitless( + Quantity(func(x.value, y, *args, **kwargs), unit=change_unit_func(x.unit, DIMENSIONLESS))) + elif isinstance(y, Quantity): + return _return_check_unitless( + Quantity(func(x, y.value, *args, **kwargs), unit=change_unit_func(DIMENSIONLESS, y.unit))) + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') + + f.__module__ = 'braincore.math' + return f + + +multiply = wrap_math_funcs_change_unit_binary(jnp.multiply, lambda x, y: x * y) +divide = wrap_math_funcs_change_unit_binary(jnp.divide, lambda x, y: x / y) + + +@set_module_as('braincore.math') +def power(x, y, *args, **kwargs): + if isinstance(x, Quantity) and isinstance(y, Quantity): + return _return_check_unitless(Quantity(jnp.power(x.value, y.value, *args, **kwargs), unit=x.unit ** y.unit)) + elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): + return jnp.power(x, y, *args, **kwargs) + elif isinstance(x, Quantity): + return _return_check_unitless(Quantity(jnp.power(x.value, y, *args, **kwargs), unit=x.unit ** y)) + elif isinstance(y, Quantity): + return _return_check_unitless(Quantity(jnp.power(x, y.value, *args, **kwargs), unit=x ** y.unit)) + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.power.__name__}') + + +cross = wrap_math_funcs_change_unit_binary(jnp.cross, lambda x, y: x * y) +ldexp = wrap_math_funcs_change_unit_binary(jnp.ldexp, lambda x, y: x * 2 ** y) +true_divide = wrap_math_funcs_change_unit_binary(jnp.true_divide, lambda x, y: x / y) + + +@set_module_as('braincore.math') +def floor_divide(x, y, *args, **kwargs): + if isinstance(x, Quantity) and isinstance(y, Quantity): + return _return_check_unitless(Quantity(jnp.floor_divide(x.value, y.value, *args, **kwargs), unit=x.unit / y.unit)) + elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): + return jnp.floor_divide(x, y, *args, **kwargs) + elif isinstance(x, Quantity): + return _return_check_unitless(Quantity(jnp.floor_divide(x.value, y, *args, **kwargs), unit=x.unit / y)) + elif isinstance(y, Quantity): + return _return_check_unitless(Quantity(jnp.floor_divide(x, y.value, *args, **kwargs), unit=x / y.unit)) + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.floor_divide.__name__}') + + +@set_module_as('braincore.math') +def float_power(x, y, *args, **kwargs): + if isinstance(x, Quantity) and isinstance(y, Quantity): + return _return_check_unitless(Quantity(jnp.float_power(x.value, y.value, *args, **kwargs), unit=x.unit ** y.unit)) + elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): + return jnp.float_power(x, y, *args, **kwargs) + elif isinstance(x, Quantity): + return _return_check_unitless(Quantity(jnp.float_power(x.value, y, *args, **kwargs), unit=x.unit ** y)) + elif isinstance(y, Quantity): + return _return_check_unitless(Quantity(jnp.float_power(x, y.value, *args, **kwargs), unit=x ** y.unit)) + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.float_power.__name__}') + + +divmod = wrap_math_funcs_change_unit_binary(jnp.divmod, lambda x, y: x / y) + + +@set_module_as('braincore.math') +def remainder(x, y, *args, **kwargs): + if isinstance(x, Quantity) and isinstance(y, Quantity): + return _return_check_unitless(Quantity(jnp.remainder(x.value, y.value, *args, **kwargs), unit=x.unit / y.unit)) + elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): + return jnp.remainder(x, y, *args, **kwargs) + elif isinstance(x, Quantity): + return _return_check_unitless(Quantity(jnp.remainder(x.value, y, *args, **kwargs), unit=x.unit % y)) + elif isinstance(y, Quantity): + return _return_check_unitless(Quantity(jnp.remainder(x, y.value, *args, **kwargs), unit=x % y.unit)) + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.remainder.__name__}') + + +convolve = wrap_math_funcs_change_unit_binary(jnp.convolve, lambda x, y: x * y) + + +# math funcs only accept unitless (unary) +# --------------------------------------- + +def wrap_math_funcs_only_accept_unitless_unary(func): + def f(x, *args, **kwargs): + if isinstance(x, Quantity): + fail_for_dimension_mismatch( + x, + error_message="%s expects a dimensionless argument but got {value}" % func.__name__, + value=x, + ) + return func(jnp.array(x.value), *args, **kwargs) + else: + return func(x, *args, **kwargs) + + f.__module__ = 'braincore.math' + return f + + +exp = wrap_math_funcs_only_accept_unitless_unary(jnp.exp) +exp2 = wrap_math_funcs_only_accept_unitless_unary(jnp.exp2) +expm1 = wrap_math_funcs_only_accept_unitless_unary(jnp.expm1) +log = wrap_math_funcs_only_accept_unitless_unary(jnp.log) +log10 = wrap_math_funcs_only_accept_unitless_unary(jnp.log10) +log1p = wrap_math_funcs_only_accept_unitless_unary(jnp.log1p) +log2 = wrap_math_funcs_only_accept_unitless_unary(jnp.log2) +arccos = wrap_math_funcs_only_accept_unitless_unary(jnp.arccos) +arccosh = wrap_math_funcs_only_accept_unitless_unary(jnp.arccosh) +arcsin = wrap_math_funcs_only_accept_unitless_unary(jnp.arcsin) +arcsinh = wrap_math_funcs_only_accept_unitless_unary(jnp.arcsinh) +arctan = wrap_math_funcs_only_accept_unitless_unary(jnp.arctan) +arctanh = wrap_math_funcs_only_accept_unitless_unary(jnp.arctanh) +cos = wrap_math_funcs_only_accept_unitless_unary(jnp.cos) +cosh = wrap_math_funcs_only_accept_unitless_unary(jnp.cosh) +sin = wrap_math_funcs_only_accept_unitless_unary(jnp.sin) +sinc = wrap_math_funcs_only_accept_unitless_unary(jnp.sinc) +sinh = wrap_math_funcs_only_accept_unitless_unary(jnp.sinh) +tan = wrap_math_funcs_only_accept_unitless_unary(jnp.tan) +tanh = wrap_math_funcs_only_accept_unitless_unary(jnp.tanh) +deg2rad = wrap_math_funcs_only_accept_unitless_unary(jnp.deg2rad) +rad2deg = wrap_math_funcs_only_accept_unitless_unary(jnp.rad2deg) +degrees = wrap_math_funcs_only_accept_unitless_unary(jnp.degrees) +radians = wrap_math_funcs_only_accept_unitless_unary(jnp.radians) +angle = wrap_math_funcs_only_accept_unitless_unary(jnp.angle) +percentile = wrap_math_funcs_only_accept_unitless_unary(jnp.percentile) +nanpercentile = wrap_math_funcs_only_accept_unitless_unary(jnp.nanpercentile) +quantile = wrap_math_funcs_only_accept_unitless_unary(jnp.quantile) +nanquantile = wrap_math_funcs_only_accept_unitless_unary(jnp.nanquantile) + + +# math funcs only accept unitless (binary) +# ---------------------------------------- + +def wrap_math_funcs_only_accept_unitless_binary(func): + def f(x, y, *args, **kwargs): + if isinstance(x, Quantity): + x_value = x.value + if isinstance(y, Quantity): + y_value = y.value + if isinstance(x, Quantity) or isinstance(y, Quantity): + fail_for_dimension_mismatch( + x, + error_message="%s expects a dimensionless argument but got {value}" % func.__name__, + value=x, + ) + fail_for_dimension_mismatch( + y, + error_message="%s expects a dimensionless argument but got {value}" % func.__name__, + value=y, + ) + return func(jnp.array(x_value), jnp.array(y_value), *args, **kwargs) + else: + return func(x, y, *args, **kwargs) + + f.__module__ = 'braincore.math' + return f + + +hypot = wrap_math_funcs_only_accept_unitless_binary(jnp.hypot) +arctan2 = wrap_math_funcs_only_accept_unitless_binary(jnp.arctan2) +logaddexp = wrap_math_funcs_only_accept_unitless_binary(jnp.logaddexp) +logaddexp2 = wrap_math_funcs_only_accept_unitless_binary(jnp.logaddexp2) + + +# math funcs remove unit (unary) +# ------------------------------ +def wrap_math_funcs_remove_unit_unary(func): + def f(x, *args, **kwargs): + if isinstance(x, Quantity): + return func(x.value, *args, **kwargs) + else: + return func(x, *args, **kwargs) + + f.__module__ = 'braincore.math' + return f + + +signbit = wrap_math_funcs_remove_unit_unary(jnp.signbit) +sign = wrap_math_funcs_remove_unit_unary(jnp.sign) +histogram = wrap_math_funcs_remove_unit_unary(jnp.histogram) +bincount = wrap_math_funcs_remove_unit_unary(jnp.bincount) + + +# math funcs remove unit (binary) +# ------------------------------- +def wrap_math_funcs_remove_unit_binary(func): + def f(x, y, *args, **kwargs): + if isinstance(x, Quantity): + x_value = x.value + if isinstance(y, Quantity): + y_value = y.value + if isinstance(x, Quantity) or isinstance(y, Quantity): + return func(jnp.array(x_value), jnp.array(y_value), *args, **kwargs) + else: + return func(x, y, *args, **kwargs) + + f.__module__ = 'braincore.math' + return f + + +corrcoef = wrap_math_funcs_remove_unit_binary(jnp.corrcoef) +correlate = wrap_math_funcs_remove_unit_binary(jnp.correlate) +cov = wrap_math_funcs_remove_unit_binary(jnp.cov) +digitize = wrap_math_funcs_remove_unit_binary(jnp.digitize) + +# array manipulation +# ------------------ + +reshape = _compatible_with_quantity(jnp.reshape) +moveaxis = _compatible_with_quantity(jnp.moveaxis) +transpose = _compatible_with_quantity(jnp.transpose) +swapaxes = _compatible_with_quantity(jnp.swapaxes) +concatenate = _compatible_with_quantity(jnp.concatenate) +stack = _compatible_with_quantity(jnp.stack) +vstack = _compatible_with_quantity(jnp.vstack) +row_stack = vstack +hstack = _compatible_with_quantity(jnp.hstack) +dstack = _compatible_with_quantity(jnp.dstack) +column_stack = _compatible_with_quantity(jnp.column_stack) +split = _compatible_with_quantity(jnp.split) +dsplit = _compatible_with_quantity(jnp.dsplit) +hsplit = _compatible_with_quantity(jnp.hsplit) +vsplit = _compatible_with_quantity(jnp.vsplit) +tile = _compatible_with_quantity(jnp.tile) +repeat = _compatible_with_quantity(jnp.repeat) +unique = _compatible_with_quantity(jnp.unique) +append = _compatible_with_quantity(jnp.append) +flip = _compatible_with_quantity(jnp.flip) +fliplr = _compatible_with_quantity(jnp.fliplr) +flipud = _compatible_with_quantity(jnp.flipud) +roll = _compatible_with_quantity(jnp.roll) +atleast_1d = _compatible_with_quantity(jnp.atleast_1d) +atleast_2d = _compatible_with_quantity(jnp.atleast_2d) +atleast_3d = _compatible_with_quantity(jnp.atleast_3d) +expand_dims = _compatible_with_quantity(jnp.expand_dims) +squeeze = _compatible_with_quantity(jnp.squeeze) +sort = _compatible_with_quantity(jnp.sort) + +max = _compatible_with_quantity(jnp.max) +min = _compatible_with_quantity(jnp.min) + +amax = max +amin = min + +choose = _compatible_with_quantity(jnp.choose) +block = _compatible_with_quantity(jnp.block) +compress = _compatible_with_quantity(jnp.compress) +diagflat = _compatible_with_quantity(jnp.diagflat) + +# return jax.numpy.Array, not Quantity +argsort = _compatible_with_quantity(jnp.argsort, return_quantity=False) +argmax = _compatible_with_quantity(jnp.argmax, return_quantity=False) +argmin = _compatible_with_quantity(jnp.argmin, return_quantity=False) +argwhere = _compatible_with_quantity(jnp.argwhere, return_quantity=False) +nonzero = _compatible_with_quantity(jnp.nonzero, return_quantity=False) +flatnonzero = _compatible_with_quantity(jnp.flatnonzero, return_quantity=False) +searchsorted = _compatible_with_quantity(jnp.searchsorted, return_quantity=False) +extract = _compatible_with_quantity(jnp.extract, return_quantity=False) +count_nonzero = _compatible_with_quantity(jnp.count_nonzero, return_quantity=False) + + +def wrap_function_to_method(func): + @wraps(func) + def f(x, *args, **kwargs): + if isinstance(x, Quantity): + return Quantity(func(x.value, *args, **kwargs), unit=x.unit) + else: + return func(x, *args, **kwargs) + + f.__module__ = 'braincore.math' + return f + + +diagonal = wrap_function_to_method(jnp.diagonal) +ravel = wrap_function_to_method(jnp.ravel) + + +# Elementwise bit operations (unary) +# ---------------------------------- + +def wrap_elementwise_bit_operation_unary(func): + def f(x, *args, **kwargs): + if isinstance(x, Quantity): + raise ValueError(f'Expected integers, got {x}') + elif isinstance(x, (jax.Array, np.ndarray)): + return func(x, *args, **kwargs) + else: + raise ValueError(f'Unsupported types {type(x)} for {func.__name__}') + + f.__module__ = 'braincore.math' + return f + + +bitwise_not = wrap_elementwise_bit_operation_unary(jnp.bitwise_not) +invert = wrap_elementwise_bit_operation_unary(jnp.invert) +left_shift = wrap_elementwise_bit_operation_unary(jnp.left_shift) +right_shift = wrap_elementwise_bit_operation_unary(jnp.right_shift) + + +# Elementwise bit operations (binary) +# ----------------------------------- + +def wrap_elementwise_bit_operation_binary(func): + def f(x, y, *args, **kwargs): + if isinstance(x, Quantity) or isinstance(y, Quantity): + raise ValueError(f'Expected integers, got {x} and {y}') + elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): + return func(x, y, *args, **kwargs) + else: + raise ValueError(f'Unsupported types {type(x)} and {type(y)} for {func.__name__}') + + f.__module__ = 'braincore.math' + return f + + +bitwise_and = wrap_elementwise_bit_operation_binary(jnp.bitwise_and) +bitwise_or = wrap_elementwise_bit_operation_binary(jnp.bitwise_or) +bitwise_xor = wrap_elementwise_bit_operation_binary(jnp.bitwise_xor) + + +# logic funcs (unary) +# ------------------- + +def wrap_logic_func_unary(func): + def f(x, *args, **kwargs): + if isinstance(x, Quantity): + raise ValueError(f'Expected booleans, got {x}') + elif isinstance(x, (jax.Array, np.ndarray)): + return func(x, *args, **kwargs) + else: + raise ValueError(f'Unsupported types {type(x)} for {func.__name__}') + + f.__module__ = 'braincore.math' + return f + + +all = wrap_logic_func_unary(jnp.all) +any = wrap_logic_func_unary(jnp.any) +alltrue = all +sometrue = any +logical_not = wrap_logic_func_unary(jnp.logical_not) + + +# logic funcs (binary) +# -------------------- + +def wrap_logic_func_binary(func): + def f(x, y, *args, **kwargs): + if isinstance(x, Quantity) and isinstance(y, Quantity): + fail_for_dimension_mismatch(x, y) + return func(x.value, y.value, *args, **kwargs) + elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): + return func(x, y, *args, **kwargs) + else: + raise ValueError(f'Unsupported types {type(x)} and {type(y)} for {func.__name__}') + + f.__module__ = 'braincore.math' + return f + + +equal = wrap_logic_func_binary(jnp.equal) +not_equal = wrap_logic_func_binary(jnp.not_equal) +greater = wrap_logic_func_binary(jnp.greater) +greater_equal = wrap_logic_func_binary(jnp.greater_equal) +less = wrap_logic_func_binary(jnp.less) +less_equal = wrap_logic_func_binary(jnp.less_equal) +array_equal = wrap_logic_func_binary(jnp.array_equal) +isclose = wrap_logic_func_binary(jnp.isclose) +allclose = wrap_logic_func_binary(jnp.allclose) +logical_and = wrap_logic_func_binary(jnp.logical_and) + +logical_or = wrap_logic_func_binary(jnp.logical_or) +logical_xor = wrap_logic_func_binary(jnp.logical_xor) + + +# indexing funcs +# -------------- +@set_module_as('braincore.math') +def where(condition, *args, **kwds): # pylint: disable=C0111 + condition = jnp.asarray(condition) + if len(args) == 0: + # nothing to do + return jnp.where(condition, *args, **kwds) + elif len(args) == 2: + # check that x and y have the same dimensions + fail_for_dimension_mismatch( + args[0], args[1], "x and y need to have the same dimensions" + ) + new_args = [] + for arg in args: + if isinstance(arg, Quantity): + new_args.append(arg.value) + if is_unitless(args[0]): + if len(new_args) == 2: + return jnp.where(condition, *new_args, **kwds) + else: + return jnp.where(condition, *args, **kwds) + else: + # as both arguments have the same unit, just use the first one's + dimensionless_args = [jnp.asarray(arg.value) if isinstance(arg, Quantity) else jnp.asarray(arg) for arg in args] + return Quantity.with_units( + jnp.where(condition, *dimensionless_args), args[0].unit + ) + else: + # illegal number of arguments + if len(args) == 1: + raise ValueError("where() takes 2 or 3 positional arguments but 1 was given") + elif len(args) > 2: + raise TypeError("where() takes 2 or 3 positional arguments but {} were given".format(len(args))) + + +tril_indices = jnp.tril_indices + + +@set_module_as('braincore.math') +def tril_indices_from(arr, k=0): + if isinstance(arr, Quantity): + return jnp.tril_indices_from(arr.value, k=k) + else: + return jnp.tril_indices_from(arr, k=k) + + +triu_indices = jnp.triu_indices + + +@set_module_as('braincore.math') +def triu_indices_from(arr, k=0): + if isinstance(arr, Quantity): + return jnp.triu_indices_from(arr.value, k=k) + else: + return jnp.triu_indices_from(arr, k=k) + + +@set_module_as('braincore.math') +def take(a, indices, axis=None, mode=None): + if isinstance(a, Quantity): + return a.take(indices, axis=axis, mode=mode) + else: + return jnp.take(a, indices, axis=axis, mode=mode) + + +@set_module_as('braincore.math') +def select(condlist: list[Union[jnp.array, np.ndarray]], choicelist: Union[Quantity, jax.Array, np.ndarray], default=0): + from builtins import all as origin_all + from builtins import any as origin_any + if origin_all(isinstance(choice, Quantity) for choice in choicelist): + if origin_any(choice.unit != choicelist[0].unit for choice in choicelist): + raise ValueError("All choices must have the same unit") + else: + return Quantity(jnp.select(condlist, [choice.value for choice in choicelist], default=default), + unit=choicelist[0].unit) + elif origin_all(isinstance(choice, (jax.Array, np.ndarray)) for choice in choicelist): + return jnp.select(condlist, choicelist, default=default) + else: + raise ValueError(f"Unsupported types : {type(condlist)} and {type(choicelist)} for select") + + +# window funcs +# ------------ + +def wrap_window_funcs(func): + def f(*args, **kwargs): + return Quantity(func(*args, **kwargs)) + + f.__module__ = 'braincore.math' + return f + + +bartlett = wrap_window_funcs(jnp.bartlett) +blackman = wrap_window_funcs(jnp.blackman) +hamming = wrap_window_funcs(jnp.hamming) +hanning = wrap_window_funcs(jnp.hanning) +kaiser = wrap_window_funcs(jnp.kaiser) + +# constants +# --------- +e = jnp.e +pi = jnp.pi +inf = jnp.inf + +# linear algebra +# -------------- +dot = wrap_math_funcs_change_unit_binary(jnp.dot, lambda x, y: x * y) +vdot = wrap_math_funcs_change_unit_binary(jnp.vdot, lambda x, y: x * y) +inner = wrap_math_funcs_change_unit_binary(jnp.inner, lambda x, y: x * y) +outer = wrap_math_funcs_change_unit_binary(jnp.outer, lambda x, y: x * y) +kron = wrap_math_funcs_change_unit_binary(jnp.kron, lambda x, y: x * y) +matmul = wrap_math_funcs_change_unit_binary(jnp.matmul, lambda x, y: x * y) +trace = wrap_math_funcs_keep_unit_unary(jnp.trace) + +# data types +# ---------- +dtype = jnp.dtype + + +@set_module_as('braincore.math') +def finfo(a): + if isinstance(a, Quantity): + return jnp.finfo(a.value) + else: + return jnp.finfo(a) + + +@set_module_as('braincore.math') +def iinfo(a): + if isinstance(a, Quantity): + return jnp.iinfo(a.value) + else: + return jnp.iinfo(a) + + +# more +# ---- +@set_module_as('braincore.math') +def broadcast_arrays(*args): + from builtins import all as origin_all + from builtins import any as origin_any + if origin_all(isinstance(arg, Quantity) for arg in args): + if origin_any(arg.unit != args[0].unit for arg in args): + raise ValueError("All arguments must have the same unit") + return Quantity(jnp.broadcast_arrays(*[arg.value for arg in args]), unit=args[0].unit) + elif origin_all(isinstance(arg, (jax.Array, np.ndarray)) for arg in args): + return jnp.broadcast_arrays(*args) + else: + raise ValueError(f"Unsupported types : {type(args)} for broadcast_arrays") + + +broadcast_shapes = jnp.broadcast_shapes + + +@set_module_as('braincore.math') +def einsum( + subscripts, /, + *operands, + out: None = None, + optimize: Union[str, bool] = "optimal", + precision: jax.lax.PrecisionLike = None, + preferred_element_type: Union[jax.typing.DTypeLike, None] = None, + _dot_general: Callable[..., jax.Array] = lax.dot_general, +) -> Union[jax.Array, Quantity]: + operands = (subscripts, *operands) + if out is not None: + raise NotImplementedError("The 'out' argument to jnp.einsum is not supported.") + spec = operands[0] if isinstance(operands[0], str) else None + optimize = 'optimal' if optimize is True else optimize + + # Allow handling of shape polymorphism + non_constant_dim_types = { + type(d) for op in operands if not isinstance(op, str) + for d in np.shape(op) if not jax.core.is_constant_dim(d) + } + if not non_constant_dim_types: + contract_path = opt_einsum.contract_path + else: + from jax._src.numpy.lax_numpy import _default_poly_einsum_handler + contract_path = _default_poly_einsum_handler + + operands, contractions = contract_path( + *operands, einsum_call=True, use_blas=True, optimize=optimize) + + unit = None + for i in range(len(contractions) - 1): + if contractions[i][4] == 'False': + + fail_for_dimension_mismatch( + Quantity([], unit=unit), operands[i + 1], 'einsum' + ) + elif contractions[i][4] == 'DOT' or \ + contractions[i][4] == 'TDOT' or \ + contractions[i][4] == 'GEMM' or \ + contractions[i][4] == 'OUTER/EINSUM': + if i == 0: + if isinstance(operands[i], Quantity) and isinstance(operands[i + 1], Quantity): + unit = operands[i].unit * operands[i + 1].unit + elif isinstance(operands[i], Quantity): + unit = operands[i].unit + elif isinstance(operands[i + 1], Quantity): + unit = operands[i + 1].unit + else: + if isinstance(operands[i + 1], Quantity): + unit = unit * operands[i + 1].unit + + contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions) + + einsum = jax.jit(_einsum, static_argnums=(1, 2, 3, 4), inline=True) + if spec is not None: + einsum = jax.named_call(einsum, name=spec) + operands = [op.value if isinstance(op, Quantity) else op for op in operands] + r = einsum(operands, contractions, precision, # type: ignore[operator] + preferred_element_type, _dot_general) + if unit is not None: + return Quantity(r, unit=unit) + else: + return r + + +@set_module_as('braincore.math') +def gradient( + f: Union[jax.Array, np.ndarray, Quantity], + *varargs: Union[jax.Array, np.ndarray, Quantity], + axis: Union[int, Sequence[int], None] = None, + edge_order: Union[int, None] = None, +) -> Union[jax.Array, list[jax.Array], Quantity, list[Quantity]]: + if edge_order is not None: + raise NotImplementedError("The 'edge_order' argument to jnp.gradient is not supported.") + + if len(varargs) == 0: + if isinstance(f, Quantity) and not is_unitless(f): + return Quantity(jnp.gradient(f.value, axis=axis), unit=f.unit) + else: + return jnp.gradient(f) + elif len(varargs) == 1: + unit = get_unit(f) / get_unit(varargs[0]) + if unit is None or unit == DIMENSIONLESS: + return jnp.gradient(f, varargs[0], axis=axis) + else: + return [Quantity(r, unit=unit) for r in jnp.gradient(f.value, varargs[0].value, axis=axis)] + else: + unit_list = [get_unit(f) / get_unit(v) for v in varargs] + f = f.value if isinstance(f, Quantity) else f + varargs = [v.value if isinstance(v, Quantity) else v for v in varargs] + result_list = jnp.gradient(f, *varargs, axis=axis) + return [Quantity(r, unit=unit) if unit is not None else r for r, unit in zip(result_list, unit_list)] + + +@set_module_as('braincore.math') +def intersect1d( + ar1: Union[jax.Array, np.ndarray], + ar2: Union[jax.Array, np.ndarray], + assume_unique: bool = False, + return_indices: bool = False +) -> Union[jax.Array, Quantity, tuple[jax.Array | Quantity, jax.Array, jax.Array]]: + fail_for_dimension_mismatch(ar1, ar2, 'intersect1d') + unit = None + if isinstance(ar1, Quantity): + unit = ar1.unit + ar1 = ar1.value if isinstance(ar1, Quantity) else ar1 + ar2 = ar2.value if isinstance(ar2, Quantity) else ar2 + result = jnp.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices) + if return_indices: + if unit is not None: + return (Quantity(result[0], unit=unit), result[1], result[2]) + else: + return result + else: + if unit is not None: + return Quantity(result, unit=unit) + else: + return result + + +nan_to_num = wrap_math_funcs_keep_unit_unary(jnp.nan_to_num) +nanargmax = _compatible_with_quantity(jnp.nanargmax, return_quantity=False) +nanargmin = _compatible_with_quantity(jnp.nanargmin, return_quantity=False) + +rot90 = wrap_math_funcs_keep_unit_unary(jnp.rot90) +tensordot = wrap_math_funcs_change_unit_binary(jnp.tensordot, lambda x, y: x * y) diff --git a/brainunit/math/_compat_numpy_test.py b/brainunit/math/_compat_numpy_test.py new file mode 100644 index 0000000..95e9635 --- /dev/null +++ b/brainunit/math/_compat_numpy_test.py @@ -0,0 +1,2270 @@ +import unittest + +import jax.numpy as jnp +import pytest + +import braincore as bc +import brainunit.math as bm +import brainunit as U +from brainunit import DimensionMismatchError +from brainunit._base import Quantity +from brainunit._unit_shortcuts import ms, mV + +bc.environ.set(precision=64) + + +def assert_quantity(q, values, unit): + values = jnp.asarray(values) + if isinstance(q, Quantity): + assert q.unit == unit.unit, f"Unit mismatch: {q.unit} != {unit}" + assert jnp.allclose(q.value, values), f"Values do not match: {q.value} != {values}" + else: + assert jnp.allclose(q, values), f"Values do not match: {q} != {values}" + + +class TestArrayCreation(unittest.TestCase): + + def test_full(self): + result = bm.full(3, 4) + self.assertEqual(result.shape, (3,)) + self.assertTrue(jnp.all(result == 4)) + + def test_full_like(self): + array = jnp.array([1, 2, 3]) + result = bm.full_like(array, 4) + self.assertEqual(result.shape, array.shape) + self.assertTrue(jnp.all(result == 4)) + + def test_eye(self): + result = bm.eye(3) + self.assertEqual(result.shape, (3, 3)) + self.assertTrue(jnp.all(result == jnp.eye(3))) + + def test_identity(self): + result = bm.identity(3) + self.assertEqual(result.shape, (3, 3)) + self.assertTrue(jnp.all(result == jnp.identity(3))) + + def test_diag(self): + array = jnp.array([1, 2, 3]) + result = bm.diag(array) + self.assertEqual(result.shape, (3, 3)) + self.assertTrue(jnp.all(result == jnp.diag(array))) + + def test_tri(self): + result = bm.tri(3) + self.assertEqual(result.shape, (3, 3)) + self.assertTrue(jnp.all(result == jnp.tri(3))) + + def test_tril(self): + array = jnp.ones((3, 3)) + result = bm.tril(array) + self.assertEqual(result.shape, (3, 3)) + self.assertTrue(jnp.all(result == jnp.tril(array))) + + def test_triu(self): + array = jnp.ones((3, 3)) + result = bm.triu(array) + self.assertEqual(result.shape, (3, 3)) + self.assertTrue(jnp.all(result == jnp.triu(array))) + + def test_empty(self): + result = bm.empty((2, 2)) + self.assertEqual(result.shape, (2, 2)) + + def test_empty_like(self): + array = jnp.array([1, 2, 3]) + result = bm.empty_like(array) + self.assertEqual(result.shape, array.shape) + + def test_ones(self): + result = bm.ones((2, 2)) + self.assertEqual(result.shape, (2, 2)) + self.assertTrue(jnp.all(result == 1)) + + def test_ones_like(self): + array = jnp.array([1, 2, 3]) + result = bm.ones_like(array) + self.assertEqual(result.shape, array.shape) + self.assertTrue(jnp.all(result == 1)) + + def test_zeros(self): + result = bm.zeros((2, 2)) + self.assertEqual(result.shape, (2, 2)) + self.assertTrue(jnp.all(result == 0)) + + def test_zeros_like(self): + array = jnp.array([1, 2, 3]) + result = bm.zeros_like(array) + self.assertEqual(result.shape, array.shape) + self.assertTrue(jnp.all(result == 0)) + + def test_array(self): + result = bm.array([1, 2, 3]) + self.assertEqual(result.shape, (3,)) + self.assertTrue(jnp.all(result == jnp.array([1, 2, 3]))) + + def test_asarray(self): + result = bm.asarray([1, 2, 3]) + self.assertEqual(result.shape, (3,)) + self.assertTrue(jnp.all(result == jnp.asarray([1, 2, 3]))) + + def test_arange(self): + result = bm.arange(3) + self.assertEqual(result.shape, (3,)) + self.assertTrue(jnp.all(result == jnp.arange(3))) + + def test_linspace(self): + result = bm.linspace(0, 10, 5) + self.assertEqual(result.shape, (5,)) + self.assertTrue(jnp.all(result == jnp.linspace(0, 10, 5))) + + def test_logspace(self): + result = bm.logspace(0, 2, 5) + self.assertEqual(result.shape, (5,)) + self.assertTrue(jnp.all(result == jnp.logspace(0, 2, 5))) + + def test_fill_diagonal(self): + array = jnp.zeros((3, 3)) + result = bm.fill_diagonal(array, 5, inplace=False) + self.assertTrue(jnp.all(result == jnp.array([[5, 0, 0], [0, 5, 0], [0, 0, 5]]))) + + def test_array_split(self): + array = jnp.arange(9) + result = bm.array_split(array, 3) + expected = jnp.array_split(array, 3) + for r, e in zip(result, expected): + self.assertTrue(jnp.all(r == e)) + + def test_meshgrid(self): + x = jnp.array([1, 2, 3]) + y = jnp.array([4, 5]) + result = bm.meshgrid(x, y) + expected = jnp.meshgrid(x, y) + for r, e in zip(result, expected): + self.assertTrue(jnp.all(r == e)) + + def test_vander(self): + array = jnp.array([1, 2, 3]) + result = bm.vander(array) + self.assertEqual(result.shape, (3, 3)) + self.assertTrue(jnp.all(result == jnp.vander(array))) + + +class TestAttributeFunctions(unittest.TestCase): + + def test_ndim(self): + array = jnp.array([[1, 2], [3, 4]]) + self.assertEqual(bm.ndim(array), 2) + + q = [[1, 2], [3, 4]] * ms + self.assertEqual(bm.ndim(q), 2) + + def test_isreal(self): + array = jnp.array([1.0, 2.0]) + self.assertTrue(jnp.all(bm.isreal(array))) + + q = [[1, 2], [3, 4]] * ms + self.assertTrue(jnp.all(bm.isreal(q))) + + def test_isscalar(self): + self.assertTrue(bm.isscalar(1.0)) + self.assertTrue(bm.isscalar(Quantity(1.0))) + + def test_isfinite(self): + array = jnp.array([1.0, jnp.inf]) + self.assertTrue(jnp.all(bm.isfinite(array) == jnp.isfinite(array))) + + q = [1.0, jnp.inf] * ms + self.assertTrue(jnp.all(bm.isfinite(q) == jnp.isfinite(q.value))) + + def test_isinf(self): + array = jnp.array([1.0, jnp.inf]) + self.assertTrue(jnp.all(bm.isinf(array) == jnp.isinf(array))) + + q = [1.0, jnp.inf] * ms + self.assertTrue(jnp.all(bm.isinf(q) == jnp.isinf(q.value))) + + def test_isnan(self): + array = jnp.array([1.0, jnp.nan]) + self.assertTrue(jnp.all(bm.isnan(array) == jnp.isnan(array))) + + q = [1.0, jnp.nan] * ms + self.assertTrue(jnp.all(bm.isnan(q) == jnp.isnan(q.value))) + + def test_shape(self): + array = jnp.array([[1, 2], [3, 4]]) + self.assertEqual(bm.shape(array), (2, 2)) + + q = [[1, 2], [3, 4]] * ms + self.assertEqual(bm.shape(q), (2, 2)) + + def test_size(self): + array = jnp.array([[1, 2], [3, 4]]) + self.assertEqual(bm.size(array), 4) + self.assertEqual(bm.size(array, 1), 2) + + q = [[1, 2], [3, 4]] * ms + self.assertEqual(bm.size(q), 4) + self.assertEqual(bm.size(q, 1), 2) + + +class TestMathFuncsKeepUnitUnary(unittest.TestCase): + + def test_real(self): + complex_array = jnp.array([1 + 2j, 3 + 4j]) + result = bm.real(complex_array) + self.assertTrue(jnp.all(result == jnp.real(complex_array))) + + q = [1 + 2j, 3 + 4j] * U.second + result_q = bm.real(q) + self.assertTrue(jnp.all(result_q == jnp.real(q.value) * U.second)) + + def test_imag(self): + complex_array = jnp.array([1 + 2j, 3 + 4j]) + result = bm.imag(complex_array) + self.assertTrue(jnp.all(result == jnp.imag(complex_array))) + + q = [1 + 2j, 3 + 4j] * U.second + result_q = bm.imag(q) + self.assertTrue(jnp.all(result_q == jnp.imag(q.value) * U.second)) + + def test_conj(self): + complex_array = jnp.array([1 + 2j, 3 + 4j]) + result = bm.conj(complex_array) + self.assertTrue(jnp.all(result == jnp.conj(complex_array))) + + q = [1 + 2j, 3 + 4j] * U.second + result_q = bm.conj(q) + self.assertTrue(jnp.all(result_q == jnp.conj(q.value) * U.second)) + + def test_conjugate(self): + complex_array = jnp.array([1 + 2j, 3 + 4j]) + result = bm.conjugate(complex_array) + self.assertTrue(jnp.all(result == jnp.conjugate(complex_array))) + + q = [1 + 2j, 3 + 4j] * U.second + result_q = bm.conjugate(q) + self.assertTrue(jnp.all(result_q == jnp.conjugate(q.value) * U.second)) + + def test_negative(self): + array = jnp.array([1, 2, 3]) + result = bm.negative(array) + self.assertTrue(jnp.all(result == jnp.negative(array))) + + q = [1, 2, 3] * ms + result_q = bm.negative(q) + expected_q = jnp.negative(jnp.array([1, 2, 3])) * ms + assert_quantity(result_q, expected_q.value, ms) + + def test_positive(self): + array = jnp.array([-1, -2, -3]) + result = bm.positive(array) + self.assertTrue(jnp.all(result == jnp.positive(array))) + + q = [-1, -2, -3] * ms + result_q = bm.positive(q) + expected_q = jnp.positive(jnp.array([-1, -2, -3])) * ms + assert_quantity(result_q, expected_q.value, ms) + + def test_abs(self): + array = jnp.array([-1, -2, 3]) + result = bm.abs(array) + self.assertTrue(jnp.all(result == jnp.abs(array))) + + q = [-1, -2, 3] * ms + result_q = bm.abs(q) + expected_q = jnp.abs(jnp.array([-1, -2, -3])) * ms + assert_quantity(result_q, expected_q.value, ms) + + def test_round(self): + array = jnp.array([1.123, 2.567, 3.891]) + result = bm.round(array) + self.assertTrue(jnp.all(result == jnp.round(array))) + + q = [1.123, 2.567, 3.891] * U.second + result_q = bm.round(q) + expected_q = jnp.round(jnp.array([1.123, 2.567, 3.891])) * U.second + assert_quantity(result_q, expected_q.value, U.second) + + def test_rint(self): + array = jnp.array([1.5, 2.3, 3.8]) + result = bm.rint(array) + self.assertTrue(jnp.all(result == jnp.rint(array))) + + q = [1.5, 2.3, 3.8] * U.second + result_q = bm.rint(q) + expected_q = jnp.rint(jnp.array([1.5, 2.3, 3.8])) * U.second + assert_quantity(result_q, expected_q.value, U.second) + + def test_floor(self): + array = jnp.array([1.5, 2.3, 3.8]) + result = bm.floor(array) + self.assertTrue(jnp.all(result == jnp.floor(array))) + + q = [1.5, 2.3, 3.8] * U.second + result_q = bm.floor(q) + expected_q = jnp.floor(jnp.array([1.5, 2.3, 3.8])) * U.second + assert_quantity(result_q, expected_q.value, U.second) + + def test_ceil(self): + array = jnp.array([1.5, 2.3, 3.8]) + result = bm.ceil(array) + self.assertTrue(jnp.all(result == jnp.ceil(array))) + + q = [1.5, 2.3, 3.8] * U.second + result_q = bm.ceil(q) + expected_q = jnp.ceil(jnp.array([1.5, 2.3, 3.8])) * U.second + assert_quantity(result_q, expected_q.value, U.second) + + def test_trunc(self): + array = jnp.array([1.5, 2.3, 3.8]) + result = bm.trunc(array) + self.assertTrue(jnp.all(result == jnp.trunc(array))) + + q = [1.5, 2.3, 3.8] * U.second + result_q = bm.trunc(q) + expected_q = jnp.trunc(jnp.array([1.5, 2.3, 3.8])) * U.second + assert_quantity(result_q, expected_q.value, U.second) + + def test_fix(self): + array = jnp.array([1.5, 2.3, 3.8]) + result = bm.fix(array) + self.assertTrue(jnp.all(result == jnp.fix(array))) + + q = [1.5, 2.3, 3.8] * U.second + result_q = bm.fix(q) + expected_q = jnp.fix(jnp.array([1.5, 2.3, 3.8])) * U.second + assert_quantity(result_q, expected_q.value, U.second) + + def test_sum(self): + array = jnp.array([1, 2, 3]) + result = bm.sum(array) + self.assertTrue(result == jnp.sum(array)) + + q = [1, 2, 3] * ms + result_q = bm.sum(q) + expected_q = jnp.sum(jnp.array([1, 2, 3])) * ms + assert_quantity(result_q, expected_q.value, ms) + + def test_nancumsum(self): + array = jnp.array([1, jnp.nan, 3]) + result = bm.nancumsum(array) + self.assertTrue(jnp.all(result == jnp.nancumsum(array))) + + q = [1, jnp.nan, 3] * ms + result_q = bm.nancumsum(q) + expected_q = jnp.nancumsum(jnp.array([1, jnp.nan, 3])) * ms + assert_quantity(result_q, expected_q.value, ms) + + def test_nansum(self): + array = jnp.array([1, jnp.nan, 3]) + result = bm.nansum(array) + self.assertTrue(result == jnp.nansum(array)) + + q = [1, jnp.nan, 3] * ms + result_q = bm.nansum(q) + expected_q = jnp.nansum(jnp.array([1, jnp.nan, 3])) * ms + assert_quantity(result_q, expected_q.value, ms) + + def test_cumsum(self): + array = jnp.array([1, 2, 3]) + result = bm.cumsum(array) + self.assertTrue(jnp.all(result == jnp.cumsum(array))) + + q = [1, 2, 3] * ms + result_q = bm.cumsum(q) + expected_q = jnp.cumsum(jnp.array([1, 2, 3])) * ms + assert_quantity(result_q, expected_q.value, ms) + + def test_ediff1d(self): + array = jnp.array([1, 2, 3]) + result = bm.ediff1d(array) + self.assertTrue(jnp.all(result == jnp.ediff1d(array))) + + q = [1, 2, 3] * ms + result_q = bm.ediff1d(q) + expected_q = jnp.ediff1d(jnp.array([1, 2, 3])) * ms + assert_quantity(result_q, expected_q.value, ms) + + def test_absolute(self): + array = jnp.array([-1, -2, 3]) + result = bm.absolute(array) + self.assertTrue(jnp.all(result == jnp.absolute(array))) + + q = [-1, -2, 3] * ms + result_q = bm.absolute(q) + expected_q = jnp.absolute(jnp.array([-1, -2, 3])) * ms + assert_quantity(result_q, expected_q.value, ms) + + def test_fabs(self): + array = jnp.array([-1, -2, 3]) + result = bm.fabs(array) + self.assertTrue(jnp.all(result == jnp.fabs(array))) + + q = [-1, -2, 3] * ms + result_q = bm.fabs(q) + expected_q = jnp.fabs(jnp.array([-1, -2, 3])) * ms + assert_quantity(result_q, expected_q.value, ms) + + def test_median(self): + array = jnp.array([1, 2, 3]) + result = bm.median(array) + self.assertTrue(result == jnp.median(array)) + + q = [1, 2, 3] * ms + result_q = bm.median(q) + expected_q = jnp.median(jnp.array([1, 2, 3])) * ms + assert_quantity(result_q, expected_q.value, ms) + + def test_nanmin(self): + array = jnp.array([1, jnp.nan, 3]) + result = bm.nanmin(array) + self.assertTrue(result == jnp.nanmin(array)) + + q = [1, jnp.nan, 3] * ms + result_q = bm.nanmin(q) + expected_q = jnp.nanmin(jnp.array([1, jnp.nan, 3])) * ms + assert_quantity(result_q, expected_q.value, ms) + + def test_nanmax(self): + array = jnp.array([1, jnp.nan, 3]) + result = bm.nanmax(array) + self.assertTrue(result == jnp.nanmax(array)) + + q = [1, jnp.nan, 3] * ms + result_q = bm.nanmax(q) + expected_q = jnp.nanmax(jnp.array([1, jnp.nan, 3])) * ms + assert_quantity(result_q, expected_q.value, ms) + + def test_ptp(self): + array = jnp.array([1, 2, 3]) + result = bm.ptp(array) + self.assertTrue(result == jnp.ptp(array)) + + q = [1, 2, 3] * ms + result_q = bm.ptp(q) + expected_q = jnp.ptp(jnp.array([1, 2, 3])) * ms + assert_quantity(result_q, expected_q.value, ms) + + def test_average(self): + array = jnp.array([1, 2, 3]) + result = bm.average(array) + self.assertTrue(result == jnp.average(array)) + + q = [1, 2, 3] * ms + result_q = bm.average(q) + expected_q = jnp.average(jnp.array([1, 2, 3])) * ms + assert_quantity(result_q, expected_q.value, ms) + + def test_mean(self): + array = jnp.array([1, 2, 3]) + result = bm.mean(array) + self.assertTrue(result == jnp.mean(array)) + + q = [1, 2, 3] * ms + result_q = bm.mean(q) + expected_q = jnp.mean(jnp.array([1, 2, 3])) * ms + assert_quantity(result_q, expected_q.value, ms) + + def test_std(self): + array = jnp.array([1, 2, 3]) + result = bm.std(array) + self.assertTrue(result == jnp.std(array)) + + q = [1, 2, 3] * ms + result_q = bm.std(q) + expected_q = jnp.std(jnp.array([1, 2, 3])) * ms + assert_quantity(result_q, expected_q.value, ms) + + def test_nanmedian(self): + array = jnp.array([1, jnp.nan, 3]) + result = bm.nanmedian(array) + self.assertTrue(result == jnp.nanmedian(array)) + + q = [1, jnp.nan, 3] * ms + result_q = bm.nanmedian(q) + expected_q = jnp.nanmedian(jnp.array([1, jnp.nan, 3])) * ms + assert_quantity(result_q, expected_q.value, ms) + + def test_nanmean(self): + array = jnp.array([1, jnp.nan, 3]) + result = bm.nanmean(array) + self.assertTrue(result == jnp.nanmean(array)) + + q = [1, jnp.nan, 3] * ms + result_q = bm.nanmean(q) + expected_q = jnp.nanmean(jnp.array([1, jnp.nan, 3])) * ms + assert_quantity(result_q, expected_q.value, ms) + + def test_nanstd(self): + array = jnp.array([1, jnp.nan, 3]) + result = bm.nanstd(array) + self.assertTrue(result == jnp.nanstd(array)) + + q = [1, jnp.nan, 3] * ms + result_q = bm.nanstd(q) + expected_q = jnp.nanstd(jnp.array([1, jnp.nan, 3])) * ms + assert_quantity(result_q, expected_q.value, ms) + + def test_diff(self): + array = jnp.array([1, 2, 3]) + result = bm.diff(array) + self.assertTrue(jnp.all(result == jnp.diff(array))) + + q = [1, 2, 3] * ms + result_q = bm.diff(q) + expected_q = jnp.diff(jnp.array([1, 2, 3])) * ms + assert_quantity(result_q, expected_q.value, ms) + + def test_modf(self): + result = bm.modf(jnp.array([5.5, 7.3])) + expected = jnp.modf(jnp.array([5.5, 7.3])) + self.assertTrue(jnp.all(result[0] == expected[0]) and jnp.all(result[1] == expected[1])) + + +class TestMathFuncsKeepUnitBinary(unittest.TestCase): + + def test_fmod(self): + result = bm.fmod(jnp.array([5, 7]), jnp.array([2, 3])) + self.assertTrue(jnp.all(result == jnp.fmod(jnp.array([5, 7]), jnp.array([2, 3])))) + + q1 = [5, 7] * ms + q2 = [2, 3] * ms + result_q = bm.fmod(q1, q2) + expected_q = jnp.fmod(jnp.array([5, 7]), jnp.array([2, 3])) * ms + assert_quantity(result_q, expected_q.value, ms) + + def test_mod(self): + result = bm.mod(jnp.array([5, 7]), jnp.array([2, 3])) + self.assertTrue(jnp.all(result == jnp.mod(jnp.array([5, 7]), jnp.array([2, 3])))) + + q1 = [5, 7] * ms + q2 = [2, 3] * ms + result_q = bm.mod(q1, q2) + expected_q = jnp.mod(jnp.array([5, 7]), jnp.array([2, 3])) * ms + assert_quantity(result_q, expected_q.value, ms) + + def test_copysign(self): + result = bm.copysign(jnp.array([-1, 2]), jnp.array([1, -3])) + self.assertTrue(jnp.all(result == jnp.copysign(jnp.array([-1, 2]), jnp.array([1, -3])))) + + q1 = [-1, 2] * ms + q2 = [1, -3] * ms + result_q = bm.copysign(q1, q2) + expected_q = jnp.copysign(jnp.array([-1, 2]), jnp.array([1, -3])) * ms + assert_quantity(result_q, expected_q.value, ms) + + def test_heaviside(self): + result = bm.heaviside(jnp.array([-1, 2]), jnp.array([0.5, 0.5])) + self.assertTrue(jnp.all(result == jnp.heaviside(jnp.array([-1, 2]), jnp.array([0.5, 0.5])))) + + def test_maximum(self): + result = bm.maximum(jnp.array([1, 3, 2]), jnp.array([2, 1, 3])) + self.assertTrue(jnp.all(result == jnp.maximum(jnp.array([1, 3, 2]), jnp.array([2, 1, 3])))) + + q1 = [1, 3, 2] * ms + q2 = [2, 1, 3] * ms + result_q = bm.maximum(q1, q2) + expected_q = jnp.maximum(jnp.array([1, 3, 2]), jnp.array([2, 1, 3])) * ms + assert_quantity(result_q, expected_q.value, ms) + + def test_minimum(self): + result = bm.minimum(jnp.array([1, 3, 2]), jnp.array([2, 1, 3])) + self.assertTrue(jnp.all(result == jnp.minimum(jnp.array([1, 3, 2]), jnp.array([2, 1, 3])))) + + q1 = [1, 3, 2] * ms + q2 = [2, 1, 3] * ms + result_q = bm.minimum(q1, q2) + expected_q = jnp.minimum(jnp.array([1, 3, 2]), jnp.array([2, 1, 3])) * ms + assert_quantity(result_q, expected_q.value, ms) + + def test_fmax(self): + result = bm.fmax(jnp.array([1, 3, 2]), jnp.array([2, 1, 3])) + self.assertTrue(jnp.all(result == jnp.fmax(jnp.array([1, 3, 2]), jnp.array([2, 1, 3])))) + + q1 = [1, 3, 2] * ms + q2 = [2, 1, 3] * ms + result_q = bm.fmax(q1, q2) + expected_q = jnp.fmax(jnp.array([1, 3, 2]), jnp.array([2, 1, 3])) * ms + assert_quantity(result_q, expected_q.value, ms) + + def test_fmin(self): + result = bm.fmin(jnp.array([1, 3, 2]), jnp.array([2, 1, 3])) + self.assertTrue(jnp.all(result == jnp.fmin(jnp.array([1, 3, 2]), jnp.array([2, 1, 3])))) + + q1 = [1, 3, 2] * ms + q2 = [2, 1, 3] * ms + result_q = bm.fmin(q1, q2) + expected_q = jnp.fmin(jnp.array([1, 3, 2]), jnp.array([2, 1, 3])) * ms + assert_quantity(result_q, expected_q.value, ms) + + def test_lcm(self): + result = bm.lcm(jnp.array([4, 5, 6]), jnp.array([2, 3, 4])) + self.assertTrue(jnp.all(result == jnp.lcm(jnp.array([4, 5, 6]), jnp.array([2, 3, 4])))) + + q1 = [4, 5, 6] * U.second + q2 = [2, 3, 4] * U.second + q1 = q1.astype(jnp.int64) + q2 = q2.astype(jnp.int64) + result_q = bm.lcm(q1, q2) + expected_q = jnp.lcm(jnp.array([4, 5, 6]), jnp.array([2, 3, 4])) * U.second + assert_quantity(result_q, expected_q.value, U.second) + + def test_gcd(self): + result = bm.gcd(jnp.array([4, 5, 6]), jnp.array([2, 3, 4])) + self.assertTrue(jnp.all(result == jnp.gcd(jnp.array([4, 5, 6]), jnp.array([2, 3, 4])))) + + q1 = [4, 5, 6] * U.second + q2 = [2, 3, 4] * U.second + q1 = q1.astype(jnp.int64) + q2 = q2.astype(jnp.int64) + result_q = bm.gcd(q1, q2) + expected_q = jnp.gcd(jnp.array([4, 5, 6]), jnp.array([2, 3, 4])) * U.second + assert_quantity(result_q, expected_q.value, U.second) + + +class TestMathFuncsKeepUnitUnary(unittest.TestCase): + + def test_interp(self): + x = jnp.array([1, 2, 3]) + xp = jnp.array([0, 1, 2, 3, 4]) + fp = jnp.array([0, 1, 2, 3, 4]) + result = bm.interp(x, xp, fp) + self.assertTrue(jnp.all(result == jnp.interp(x, xp, fp))) + + x = [1, 2, 3] * U.second + xp = [0, 1, 2, 3, 4] * U.second + fp = [0, 1, 2, 3, 4] * U.second + result_q = bm.interp(x, xp, fp) + expected_q = jnp.interp(jnp.array([1, 2, 3]), jnp.array([0, 1, 2, 3, 4]), jnp.array([0, 1, 2, 3, 4])) * U.second + assert_quantity(result_q, expected_q.value, U.second) + + def test_clip(self): + array = jnp.array([1, 2, 3, 4, 5]) + result = bm.clip(array, 2, 4) + self.assertTrue(jnp.all(result == jnp.clip(array, 2, 4))) + + q = [1, 2, 3, 4, 5] * ms + result_q = bm.clip(q, 2 * ms, 4 * ms) + expected_q = jnp.clip(jnp.array([1, 2, 3, 4, 5]), 2, 4) * ms + assert_quantity(result_q, expected_q.value, ms) + + +class TestMathFuncsMatchUnitBinary(unittest.TestCase): + + def test_add(self): + result = bm.add(jnp.array([1, 2]), jnp.array([3, 4])) + self.assertTrue(jnp.all(result == jnp.add(jnp.array([1, 2]), jnp.array([3, 4])))) + + q1 = [1, 2] * ms + q2 = [3, 4] * ms + result_q = bm.add(q1, q2) + expected_q = jnp.add(jnp.array([1, 2]), jnp.array([3, 4])) * ms + assert_quantity(result_q, expected_q.value, ms) + + def test_subtract(self): + result = bm.subtract(jnp.array([5, 6]), jnp.array([3, 2])) + self.assertTrue(jnp.all(result == jnp.subtract(jnp.array([5, 6]), jnp.array([3, 2])))) + + q1 = [5, 6] * ms + q2 = [3, 2] * ms + result_q = bm.subtract(q1, q2) + expected_q = jnp.subtract(jnp.array([5, 6]), jnp.array([3, 2])) * ms + assert_quantity(result_q, expected_q.value, ms) + + def test_nextafter(self): + result = bm.nextafter(jnp.array([1.0, 2.0]), jnp.array([2.0, 3.0])) + self.assertTrue(jnp.all(result == jnp.nextafter(jnp.array([1.0, 2.0]), jnp.array([2.0, 3.0])))) + + +class TestMathFuncsChangeUnitUnary(unittest.TestCase): + + def test_reciprocal(self): + array = jnp.array([1.0, 2.0, 0.5]) + result = bm.reciprocal(array) + self.assertTrue(jnp.all(result == jnp.reciprocal(array))) + + q = [1.0, 2.0, 0.5] * U.second + result_q = bm.reciprocal(q) + expected_q = jnp.reciprocal(jnp.array([1.0, 2.0, 0.5])) * (1 / U.second) + assert_quantity(result_q, expected_q.value, 1 / U.second) + + def test_prod(self): + array = jnp.array([1, 2, 3]) + result = bm.prod(array) + self.assertTrue(result == jnp.prod(array)) + + q = [1, 2, 3] * ms + result_q = bm.prod(q) + expected_q = jnp.prod(jnp.array([1, 2, 3])) * (ms ** 3) + assert_quantity(result_q, expected_q.value, ms ** 3) + + def test_nanprod(self): + array = jnp.array([1, jnp.nan, 3]) + result = bm.nanprod(array) + self.assertTrue(result == jnp.nanprod(array)) + + q = [1, jnp.nan, 3] * ms + result_q = bm.nanprod(q) + expected_q = jnp.nanprod(jnp.array([1, jnp.nan, 3])) * (ms ** 2) + assert_quantity(result_q, expected_q.value, ms ** 2) + + def test_cumprod(self): + array = jnp.array([1, 2, 3]) + result = bm.cumprod(array) + self.assertTrue(jnp.all(result == jnp.cumprod(array))) + + q = [1, 2, 3] * U.second + result_q = bm.cumprod(q) + expected_q = jnp.cumprod(jnp.array([1, 2, 3])) * (U.second ** 3) + assert_quantity(result_q, expected_q.value, U.second ** 3) + + def test_nancumprod(self): + array = jnp.array([1, jnp.nan, 3]) + result = bm.nancumprod(array) + self.assertTrue(jnp.all(result == jnp.nancumprod(array))) + + q = [1, jnp.nan, 3] * U.second + result_q = bm.nancumprod(q) + expected_q = jnp.nancumprod(jnp.array([1, jnp.nan, 3])) * (U.second ** 2) + assert_quantity(result_q, expected_q.value, U.second ** 2) + + def test_var(self): + array = jnp.array([1, 2, 3]) + result = bm.var(array) + self.assertTrue(result == jnp.var(array)) + + q = [1, 2, 3] * ms + result_q = bm.var(q) + expected_q = jnp.var(jnp.array([1, 2, 3])) * (ms ** 2) + assert_quantity(result_q, expected_q.value, ms ** 2) + + def test_nanvar(self): + array = jnp.array([1, jnp.nan, 3]) + result = bm.nanvar(array) + self.assertTrue(result == jnp.nanvar(array)) + + q = [1, jnp.nan, 3] * ms + result_q = bm.nanvar(q) + expected_q = jnp.nanvar(jnp.array([1, jnp.nan, 3])) * (ms ** 2) + assert_quantity(result_q, expected_q.value, ms ** 2) + + def test_frexp(self): + result = bm.frexp(jnp.array([1.0, 2.0])) + expected = jnp.frexp(jnp.array([1.0, 2.0])) + self.assertTrue(jnp.all(result[0] == expected[0]) and jnp.all(result[1] == expected[1])) + + def test_sqrt(self): + result = bm.sqrt(jnp.array([1.0, 4.0])) + self.assertTrue(jnp.all(result == jnp.sqrt(jnp.array([1.0, 4.0])))) + + q = [1.0, 4.0] * (ms ** 2) + result_q = bm.sqrt(q) + expected_q = jnp.sqrt(jnp.array([1.0, 4.0])) * ms + assert_quantity(result_q, expected_q.value, ms) + + def test_cbrt(self): + result = bm.cbrt(jnp.array([1.0, 8.0])) + self.assertTrue(jnp.all(result == jnp.cbrt(jnp.array([1.0, 8.0])))) + + q = [1.0, 8.0] * (ms ** 3) + result_q = bm.cbrt(q) + expected_q = jnp.cbrt(jnp.array([1.0, 8.0])) * ms + assert_quantity(result_q, expected_q.value, ms) + + def test_square(self): + result = bm.square(jnp.array([2.0, 3.0])) + self.assertTrue(jnp.all(result == jnp.square(jnp.array([2.0, 3.0])))) + + q = [2.0, 3.0] * ms + result_q = bm.square(q) + expected_q = jnp.square(jnp.array([2.0, 3.0])) * (ms ** 2) + assert_quantity(result_q, expected_q.value, ms ** 2) + + +class TestMathFuncsChangeUnitBinary(unittest.TestCase): + + def test_multiply(self): + result = bm.multiply(jnp.array([1, 2]), jnp.array([3, 4])) + self.assertTrue(jnp.all(result == jnp.multiply(jnp.array([1, 2]), jnp.array([3, 4])))) + + q1 = [1, 2] * ms + q2 = [3, 4] * mV + result_q = bm.multiply(q1, q2) + expected_q = jnp.multiply(jnp.array([1, 2]), jnp.array([3, 4])) * (ms * mV) + assert_quantity(result_q, expected_q.value, ms * mV) + + def test_divide(self): + result = bm.divide(jnp.array([5, 6]), jnp.array([3, 2])) + self.assertTrue(jnp.all(result == jnp.divide(jnp.array([5, 6]), jnp.array([3, 2])))) + + q1 = [5, 6] * ms + q2 = [3, 2] * mV + result_q = bm.divide(q1, q2) + expected_q = jnp.divide(jnp.array([5, 6]), jnp.array([3, 2])) * (ms / mV) + assert_quantity(result_q, expected_q.value, ms / mV) + + def test_power(self): + result = bm.power(jnp.array([1, 2]), jnp.array([3, 2])) + self.assertTrue(jnp.all(result == jnp.power(jnp.array([1, 2]), jnp.array([3, 2])))) + + q1 = [1, 2] * ms + result_q = bm.power(q1, 2) + expected_q = jnp.power(jnp.array([1, 2]), 2) * (ms ** 2) + assert_quantity(result_q, expected_q.value, ms ** 2) + + def test_cross(self): + result = bm.cross(jnp.array([1, 2, 3]), jnp.array([4, 5, 6])) + self.assertTrue(jnp.all(result == jnp.cross(jnp.array([1, 2, 3]), jnp.array([4, 5, 6])))) + + def test_ldexp(self): + result = bm.ldexp(jnp.array([1.0, 2.0]), jnp.array([2, 3])) + self.assertTrue(jnp.all(result == jnp.ldexp(jnp.array([1.0, 2.0]), jnp.array([2, 3])))) + + def test_true_divide(self): + result = bm.true_divide(jnp.array([5, 6]), jnp.array([2, 3])) + self.assertTrue(jnp.all(result == jnp.true_divide(jnp.array([5, 6]), jnp.array([2, 3])))) + + q1 = [5, 6] * ms + q2 = [2, 3] * mV + result_q = bm.true_divide(q1, q2) + expected_q = jnp.true_divide(jnp.array([5, 6]), jnp.array([2, 3])) * (ms / mV) + assert_quantity(result_q, expected_q.value, ms / mV) + + def test_floor_divide(self): + result = bm.floor_divide(jnp.array([5, 6]), jnp.array([2, 3])) + self.assertTrue(jnp.all(result == jnp.floor_divide(jnp.array([5, 6]), jnp.array([2, 3])))) + + q1 = [5, 6] * ms + q2 = [2, 3] * mV + result_q = bm.floor_divide(q1, q2) + expected_q = jnp.floor_divide(jnp.array([5, 6]), jnp.array([2, 3])) * (ms / mV) + assert_quantity(result_q, expected_q.value, ms / mV) + + def test_float_power(self): + result = bm.float_power(jnp.array([2, 3]), jnp.array([2, 3])) + self.assertTrue(jnp.all(result == jnp.float_power(jnp.array([2, 3]), jnp.array([2, 3])))) + + q1 = [2, 3] * ms + result_q = bm.float_power(q1, 2) + expected_q = jnp.float_power(jnp.array([2, 3]), 2) * (ms ** 2) + assert_quantity(result_q, expected_q.value, ms ** 2) + + def test_divmod(self): + result = bm.divmod(jnp.array([5, 6]), jnp.array([2, 3])) + expected = jnp.divmod(jnp.array([5, 6]), jnp.array([2, 3])) + self.assertTrue(jnp.all(result[0] == expected[0]) and jnp.all(result[1] == expected[1])) + + def test_remainder(self): + result = bm.remainder(jnp.array([5, 7]), jnp.array([2, 3])) + self.assertTrue(jnp.all(result == jnp.remainder(jnp.array([5, 7]), jnp.array([2, 3])))) + + q1 = [5, 7] * (U.second ** 2) + q2 = [2, 3] * U.second + result_q = bm.remainder(q1, q2) + expected_q = jnp.remainder(jnp.array([5, 7]), jnp.array([2, 3])) * U.second + assert_quantity(result_q, expected_q.value, U.second) + + def test_convolve(self): + result = bm.convolve(jnp.array([1, 2, 3]), jnp.array([4, 5, 6])) + self.assertTrue(jnp.all(result == jnp.convolve(jnp.array([1, 2, 3]), jnp.array([4, 5, 6])))) + + +class TestMathFuncsOnlyAcceptUnitlessUnary(unittest.TestCase): + + def test_exp(self): + result = bm.exp(jnp.array([1.0, 2.0])) + self.assertTrue(jnp.all(result == jnp.exp(jnp.array([1.0, 2.0])))) + + result = bm.exp(Quantity(jnp.array([1.0, 2.0]))) + self.assertTrue(jnp.all(result == jnp.exp(jnp.array([1.0, 2.0])))) + + def test_exp2(self): + result = bm.exp2(jnp.array([1.0, 2.0])) + self.assertTrue(jnp.all(result == jnp.exp2(jnp.array([1.0, 2.0])))) + + result = bm.exp2(Quantity(jnp.array([1.0, 2.0]))) + self.assertTrue(jnp.all(result == jnp.exp2(jnp.array([1.0, 2.0])))) + + def test_expm1(self): + result = bm.expm1(jnp.array([1.0, 2.0])) + self.assertTrue(jnp.all(result == jnp.expm1(jnp.array([1.0, 2.0])))) + + result = bm.expm1(Quantity(jnp.array([1.0, 2.0]))) + self.assertTrue(jnp.all(result == jnp.expm1(jnp.array([1.0, 2.0])))) + + def test_log(self): + result = bm.log(jnp.array([1.0, 2.0])) + self.assertTrue(jnp.all(result == jnp.log(jnp.array([1.0, 2.0])))) + + result = bm.log(Quantity(jnp.array([1.0, 2.0]))) + self.assertTrue(jnp.all(result == jnp.log(jnp.array([1.0, 2.0])))) + + def test_log10(self): + result = bm.log10(jnp.array([1.0, 2.0])) + self.assertTrue(jnp.all(result == jnp.log10(jnp.array([1.0, 2.0])))) + + result = bm.log10(Quantity(jnp.array([1.0, 2.0]))) + self.assertTrue(jnp.all(result == jnp.log10(jnp.array([1.0, 2.0])))) + + def test_log1p(self): + result = bm.log1p(jnp.array([1.0, 2.0])) + self.assertTrue(jnp.all(result == jnp.log1p(jnp.array([1.0, 2.0])))) + + result = bm.log1p(Quantity(jnp.array([1.0, 2.0]))) + self.assertTrue(jnp.all(result == jnp.log1p(jnp.array([1.0, 2.0])))) + + def test_log2(self): + result = bm.log2(jnp.array([1.0, 2.0])) + self.assertTrue(jnp.all(result == jnp.log2(jnp.array([1.0, 2.0])))) + + result = bm.log2(Quantity(jnp.array([1.0, 2.0]))) + self.assertTrue(jnp.all(result == jnp.log2(jnp.array([1.0, 2.0])))) + + def test_arccos(self): + result = bm.arccos(jnp.array([0.5, 1.0])) + self.assertTrue(jnp.all(result == jnp.arccos(jnp.array([0.5, 1.0])))) + + result = bm.arccos(Quantity(jnp.array([0.5, 1.0]))) + self.assertTrue(jnp.all(result == jnp.arccos(jnp.array([0.5, 1.0])))) + + def test_arccosh(self): + result = bm.arccosh(jnp.array([1.0, 2.0])) + self.assertTrue(jnp.all(result == jnp.arccosh(jnp.array([1.0, 2.0])))) + + result = bm.arccosh(Quantity(jnp.array([1.0, 2.0]))) + self.assertTrue(jnp.all(result == jnp.arccosh(jnp.array([1.0, 2.0])))) + + def test_arcsin(self): + result = bm.arcsin(jnp.array([0.5, 1.0])) + self.assertTrue(jnp.all(result == jnp.arcsin(jnp.array([0.5, 1.0])))) + + result = bm.arcsin(Quantity(jnp.array([0.5, 1.0]))) + self.assertTrue(jnp.all(result == jnp.arcsin(jnp.array([0.5, 1.0])))) + + def test_arcsinh(self): + result = bm.arcsinh(jnp.array([0.5, 1.0])) + self.assertTrue(jnp.all(result == jnp.arcsinh(jnp.array([0.5, 1.0])))) + + result = bm.arcsinh(Quantity(jnp.array([0.5, 1.0]))) + self.assertTrue(jnp.all(result == jnp.arcsinh(jnp.array([0.5, 1.0])))) + + def test_arctan(self): + result = bm.arctan(jnp.array([0.5, 1.0])) + self.assertTrue(jnp.all(result == jnp.arctan(jnp.array([0.5, 1.0])))) + + result = bm.arctan(Quantity(jnp.array([0.5, 1.0]))) + self.assertTrue(jnp.all(result == jnp.arctan(jnp.array([0.5, 1.0])))) + + def test_arctanh(self): + result = bm.arctanh(jnp.array([0.5, 1.0])) + self.assertTrue(jnp.all(result == jnp.arctanh(jnp.array([0.5, 1.0])))) + + result = bm.arctanh(Quantity(jnp.array([0.5, 1.0]))) + self.assertTrue(jnp.all(result == jnp.arctanh(jnp.array([0.5, 1.0])))) + + def test_cos(self): + result = bm.cos(jnp.array([0.5, 1.0])) + self.assertTrue(jnp.all(result == jnp.cos(jnp.array([0.5, 1.0])))) + + result = bm.cos(Quantity(jnp.array([0.5, 1.0]))) + self.assertTrue(jnp.all(result == jnp.cos(jnp.array([0.5, 1.0])))) + + def test_cosh(self): + result = bm.cosh(jnp.array([0.5, 1.0])) + self.assertTrue(jnp.all(result == jnp.cosh(jnp.array([0.5, 1.0])))) + + result = bm.cosh(Quantity(jnp.array([0.5, 1.0]))) + self.assertTrue(jnp.all(result == jnp.cosh(jnp.array([0.5, 1.0])))) + + def test_sin(self): + result = bm.sin(jnp.array([0.5, 1.0])) + self.assertTrue(jnp.all(result == jnp.sin(jnp.array([0.5, 1.0])))) + + result = bm.sin(Quantity(jnp.array([0.5, 1.0]))) + self.assertTrue(jnp.all(result == jnp.sin(jnp.array([0.5, 1.0])))) + + def test_sinc(self): + result = bm.sinc(jnp.array([0.5, 1.0])) + self.assertTrue(jnp.all(result == jnp.sinc(jnp.array([0.5, 1.0])))) + + result = bm.sinc(Quantity(jnp.array([0.5, 1.0]))) + self.assertTrue(jnp.all(result == jnp.sinc(jnp.array([0.5, 1.0])))) + + def test_sinh(self): + result = bm.sinh(jnp.array([0.5, 1.0])) + self.assertTrue(jnp.all(result == jnp.sinh(jnp.array([0.5, 1.0])))) + + result = bm.sinh(Quantity(jnp.array([0.5, 1.0]))) + self.assertTrue(jnp.all(result == jnp.sinh(jnp.array([0.5, 1.0])))) + + def test_tan(self): + result = bm.tan(jnp.array([0.5, 1.0])) + self.assertTrue(jnp.all(result == jnp.tan(jnp.array([0.5, 1.0])))) + + result = bm.tan(Quantity(jnp.array([0.5, 1.0]))) + self.assertTrue(jnp.all(result == jnp.tan(jnp.array([0.5, 1.0])))) + + def test_tanh(self): + result = bm.tanh(jnp.array([0.5, 1.0])) + self.assertTrue(jnp.all(result == jnp.tanh(jnp.array([0.5, 1.0])))) + + result = bm.tanh(Quantity(jnp.array([0.5, 1.0]))) + self.assertTrue(jnp.all(result == jnp.tanh(jnp.array([0.5, 1.0])))) + + def test_deg2rad(self): + result = bm.deg2rad(jnp.array([90.0, 180.0])) + self.assertTrue(jnp.all(result == jnp.deg2rad(jnp.array([90.0, 180.0])))) + + result = bm.deg2rad(Quantity(jnp.array([90.0, 180.0]))) + self.assertTrue(jnp.all(result == jnp.deg2rad(jnp.array([90.0, 180.0])))) + + def test_rad2deg(self): + result = bm.rad2deg(jnp.array([jnp.pi / 2, jnp.pi])) + self.assertTrue(jnp.all(result == jnp.rad2deg(jnp.array([jnp.pi / 2, jnp.pi])))) + + result = bm.rad2deg(Quantity(jnp.array([jnp.pi / 2, jnp.pi]))) + self.assertTrue(jnp.all(result == jnp.rad2deg(jnp.array([jnp.pi / 2, jnp.pi])))) + + def test_degrees(self): + result = bm.degrees(jnp.array([jnp.pi / 2, jnp.pi])) + self.assertTrue(jnp.all(result == jnp.degrees(jnp.array([jnp.pi / 2, jnp.pi])))) + + result = bm.degrees(Quantity(jnp.array([jnp.pi / 2, jnp.pi]))) + self.assertTrue(jnp.all(result == jnp.degrees(jnp.array([jnp.pi / 2, jnp.pi])))) + + def test_radians(self): + result = bm.radians(jnp.array([90.0, 180.0])) + self.assertTrue(jnp.all(result == jnp.radians(jnp.array([90.0, 180.0])))) + + result = bm.radians(Quantity(jnp.array([90.0, 180.0]))) + self.assertTrue(jnp.all(result == jnp.radians(jnp.array([90.0, 180.0])))) + + def test_angle(self): + result = bm.angle(jnp.array([1.0 + 1.0j, 1.0 - 1.0j])) + self.assertTrue(jnp.all(result == jnp.angle(jnp.array([1.0 + 1.0j, 1.0 - 1.0j])))) + + result = bm.angle(Quantity(jnp.array([1.0 + 1.0j, 1.0 - 1.0j]))) + self.assertTrue(jnp.all(result == jnp.angle(jnp.array([1.0 + 1.0j, 1.0 - 1.0j])))) + + def test_percentile(self): + array = jnp.array([1, 2, 3, 4]) + result = bm.percentile(array, 50) + self.assertTrue(result == jnp.percentile(array, 50)) + + def test_nanpercentile(self): + array = jnp.array([1, jnp.nan, 3, 4]) + result = bm.nanpercentile(array, 50) + self.assertTrue(result == jnp.nanpercentile(array, 50)) + + def test_quantile(self): + array = jnp.array([1, 2, 3, 4]) + result = bm.quantile(array, 0.5) + self.assertTrue(result == jnp.quantile(array, 0.5)) + + def test_nanquantile(self): + array = jnp.array([1, jnp.nan, 3, 4]) + result = bm.nanquantile(array, 0.5) + self.assertTrue(result == jnp.nanquantile(array, 0.5)) + + +class TestMathFuncsOnlyAcceptUnitlessBinary(unittest.TestCase): + + def test_hypot(self): + result = bm.hypot(jnp.array([3.0, 4.0]), jnp.array([4.0, 3.0])) + self.assertTrue(jnp.all(result == jnp.hypot(jnp.array([3.0, 4.0]), jnp.array([4.0, 3.0])))) + + result = bm.hypot(Quantity(jnp.array([3.0, 4.0])), Quantity(jnp.array([4.0, 3.0]))) + self.assertTrue(jnp.all(result == jnp.hypot(jnp.array([3.0, 4.0]), jnp.array([4.0, 3.0])))) + + def test_arctan2(self): + result = bm.arctan2(jnp.array([1.0, 2.0]), jnp.array([2.0, 3.0])) + self.assertTrue(jnp.all(result == jnp.arctan2(jnp.array([1.0, 2.0]), jnp.array([2.0, 3.0])))) + + result = bm.arctan2(Quantity(jnp.array([1.0, 2.0])), Quantity(jnp.array([2.0, 3.0]))) + self.assertTrue(jnp.all(result == jnp.arctan2(jnp.array([1.0, 2.0]), jnp.array([2.0, 3.0])))) + + def test_logaddexp(self): + result = bm.logaddexp(jnp.array([1.0, 2.0]), jnp.array([2.0, 3.0])) + self.assertTrue(jnp.all(result == jnp.logaddexp(jnp.array([1.0, 2.0]), jnp.array([2.0, 3.0])))) + + result = bm.logaddexp(Quantity(jnp.array([1.0, 2.0])), Quantity(jnp.array([2.0, 3.0]))) + self.assertTrue(jnp.all(result == jnp.logaddexp(jnp.array([1.0, 2.0]), jnp.array([2.0, 3.0])))) + + def test_logaddexp2(self): + result = bm.logaddexp2(jnp.array([1.0, 2.0]), jnp.array([2.0, 3.0])) + self.assertTrue(jnp.all(result == jnp.logaddexp2(jnp.array([1.0, 2.0]), jnp.array([2.0, 3.0])))) + + result = bm.logaddexp2(Quantity(jnp.array([1.0, 2.0])), Quantity(jnp.array([2.0, 3.0]))) + self.assertTrue(jnp.all(result == jnp.logaddexp2(jnp.array([1.0, 2.0]), jnp.array([2.0, 3.0])))) + + +class TestMathFuncsRemoveUnitUnary(unittest.TestCase): + + def test_signbit(self): + array = jnp.array([-1.0, 2.0]) + result = bm.signbit(array) + self.assertTrue(jnp.all(result == jnp.signbit(array))) + + q = [-1.0, 2.0] * U.second + result_q = bm.signbit(q) + expected_q = jnp.signbit(jnp.array([-1.0, 2.0])) + assert_quantity(result_q, expected_q, None) + + def test_sign(self): + array = jnp.array([-1.0, 2.0]) + result = bm.sign(array) + self.assertTrue(jnp.all(result == jnp.sign(array))) + + q = [-1.0, 2.0] * U.second + result_q = bm.sign(q) + expected_q = jnp.sign(jnp.array([-1.0, 2.0])) + assert_quantity(result_q, expected_q, None) + + def test_histogram(self): + array = jnp.array([1, 2, 1]) + result, _ = bm.histogram(array) + expected, _ = jnp.histogram(array) + self.assertTrue(jnp.all(result == expected)) + + q = [1, 2, 1] * U.second + result_q, _ = bm.histogram(q) + expected_q, _ = jnp.histogram(jnp.array([1, 2, 1])) + assert_quantity(result_q, expected_q, None) + + def test_bincount(self): + array = jnp.array([1, 1, 2, 2, 2, 3]) + result = bm.bincount(array) + self.assertTrue(jnp.all(result == jnp.bincount(array))) + + q = [1, 1, 2, 2, 2, 3] * U.second + q = q.astype(jnp.int64) + result_q = bm.bincount(q) + expected_q = jnp.bincount(jnp.array([1, 1, 2, 2, 2, 3])) + assert_quantity(result_q, expected_q, None) + + +class TestMathFuncsRemoveUnitBinary(unittest.TestCase): + + def test_corrcoef(self): + x = jnp.array([1, 2, 3]) + y = jnp.array([4, 5, 6]) + result = bm.corrcoef(x, y) + self.assertTrue(jnp.all(result == jnp.corrcoef(x, y))) + + x = [1, 2, 3] * U.second + y = [4, 5, 6] * U.second + result = bm.corrcoef(x, y) + expected = jnp.corrcoef(jnp.array([1, 2, 3]), jnp.array([4, 5, 6])) + assert_quantity(result, expected, None) + + def test_correlate(self): + x = jnp.array([1, 2, 3]) + y = jnp.array([0, 1, 0.5]) + result = bm.correlate(x, y) + self.assertTrue(jnp.all(result == jnp.correlate(x, y))) + + x = [1, 2, 3] * U.second + y = [0, 1, 0.5] * U.second + result = bm.correlate(x, y) + expected = jnp.correlate(jnp.array([1, 2, 3]), jnp.array([0, 1, 0.5])) + assert_quantity(result, expected, None) + + def test_cov(self): + x = jnp.array([1, 2, 3]) + y = jnp.array([4, 5, 6]) + result = bm.cov(x, y) + self.assertTrue(jnp.all(result == jnp.cov(x, y))) + + x = [1, 2, 3] * U.second + y = [4, 5, 6] * U.second + result = bm.cov(x, y) + expected = jnp.cov(jnp.array([1, 2, 3]), jnp.array([4, 5, 6])) + assert_quantity(result, expected, None) + + def test_digitize(self): + array = jnp.array([0.2, 6.4, 3.0, 1.6]) + bins = jnp.array([0.0, 1.0, 2.5, 4.0, 10.0]) + result = bm.digitize(array, bins) + self.assertTrue(jnp.all(result == jnp.digitize(array, bins))) + + array = [0.2, 6.4, 3.0, 1.6] * U.second + bins = [0.0, 1.0, 2.5, 4.0, 10.0] * U.second + result = bm.digitize(array, bins) + expected = jnp.digitize(jnp.array([0.2, 6.4, 3.0, 1.6]), jnp.array([0.0, 1.0, 2.5, 4.0, 10.0])) + assert_quantity(result, expected, None) + + +class TestArrayManipulation(unittest.TestCase): + + def test_reshape(self): + array = jnp.array([1, 2, 3, 4]) + result = bm.reshape(array, (2, 2)) + self.assertTrue(jnp.all(result == jnp.reshape(array, (2, 2)))) + + q = [1, 2, 3, 4] * U.second + result_q = bm.reshape(q, (2, 2)) + expected_q = jnp.reshape(jnp.array([1, 2, 3, 4]), (2, 2)) + assert_quantity(result_q, expected_q, U.second) + + def test_moveaxis(self): + array = jnp.zeros((3, 4, 5)) + result = bm.moveaxis(array, 0, -1) + self.assertTrue(jnp.all(result == jnp.moveaxis(array, 0, -1))) + + q = jnp.zeros((3, 4, 5)) * U.second + result_q = bm.moveaxis(q, 0, -1) + expected_q = jnp.moveaxis(jnp.zeros((3, 4, 5)), 0, -1) + assert_quantity(result_q, expected_q, U.second) + + def test_transpose(self): + array = jnp.ones((2, 3)) + result = bm.transpose(array) + self.assertTrue(jnp.all(result == jnp.transpose(array))) + + q = jnp.ones((2, 3)) * U.second + result_q = bm.transpose(q) + expected_q = jnp.transpose(jnp.ones((2, 3))) + assert_quantity(result_q, expected_q, U.second) + + def test_swapaxes(self): + array = jnp.zeros((3, 4, 5)) + result = bm.swapaxes(array, 0, 2) + self.assertTrue(jnp.all(result == jnp.swapaxes(array, 0, 2))) + + q = jnp.zeros((3, 4, 5)) * U.second + result_q = bm.swapaxes(q, 0, 2) + expected_q = jnp.swapaxes(jnp.zeros((3, 4, 5)), 0, 2) + assert_quantity(result_q, expected_q, U.second) + + def test_row_stack(self): + a = jnp.array([1, 2, 3]) + b = jnp.array([4, 5, 6]) + result = bm.row_stack((a, b)) + self.assertTrue(jnp.all(result == jnp.vstack((a, b)))) + + q1 = [1, 2, 3] * U.second + q2 = [4, 5, 6] * U.second + result_q = bm.row_stack((q1, q2)) + expected_q = jnp.vstack((jnp.array([1, 2, 3]), jnp.array([4, 5, 6]))) + assert_quantity(result_q, expected_q, U.second) + + def test_concatenate(self): + a = jnp.array([[1, 2], [3, 4]]) + b = jnp.array([[5, 6]]) + result = bm.concatenate((a, b), axis=0) + self.assertTrue(jnp.all(result == jnp.concatenate((a, b), axis=0))) + + q1 = [[1, 2], [3, 4]] * U.second + q2 = [[5, 6]] * U.second + result_q = bm.concatenate((q1, q2), axis=0) + expected_q = jnp.concatenate((jnp.array([[1, 2], [3, 4]]), jnp.array([[5, 6]])), axis=0) + assert_quantity(result_q, expected_q, U.second) + + def test_stack(self): + a = jnp.array([1, 2, 3]) + b = jnp.array([4, 5, 6]) + result = bm.stack((a, b), axis=1) + self.assertTrue(jnp.all(result == jnp.stack((a, b), axis=1))) + + q1 = [1, 2, 3] * U.second + q2 = [4, 5, 6] * U.second + result_q = bm.stack((q1, q2), axis=1) + expected_q = jnp.stack((jnp.array([1, 2, 3]), jnp.array([4, 5, 6])), axis=1) + assert_quantity(result_q, expected_q, U.second) + + def test_vstack(self): + a = jnp.array([1, 2, 3]) + b = jnp.array([4, 5, 6]) + result = bm.vstack((a, b)) + self.assertTrue(jnp.all(result == jnp.vstack((a, b)))) + + q1 = [1, 2, 3] * U.second + q2 = [4, 5, 6] * U.second + result_q = bm.vstack((q1, q2)) + expected_q = jnp.vstack((jnp.array([1, 2, 3]), jnp.array([4, 5, 6]))) + assert_quantity(result_q, expected_q, U.second) + + def test_hstack(self): + a = jnp.array((1, 2, 3)) + b = jnp.array((4, 5, 6)) + result = bm.hstack((a, b)) + self.assertTrue(jnp.all(result == jnp.hstack((a, b)))) + + q1 = [1, 2, 3] * U.second + q2 = [4, 5, 6] * U.second + result_q = bm.hstack((q1, q2)) + expected_q = jnp.hstack((jnp.array([1, 2, 3]), jnp.array([4, 5, 6]))) + assert_quantity(result_q, expected_q, U.second) + + def test_dstack(self): + a = jnp.array([[1], [2], [3]]) + b = jnp.array([[4], [5], [6]]) + result = bm.dstack((a, b)) + self.assertTrue(jnp.all(result == jnp.dstack((a, b)))) + + q1 = [[1], [2], [3]] * U.second + q2 = [[4], [5], [6]] * U.second + result_q = bm.dstack((q1, q2)) + expected_q = jnp.dstack((jnp.array([[1], [2], [3]]), jnp.array([[4], [5], [6]]))) + assert_quantity(result_q, expected_q, U.second) + + def test_column_stack(self): + a = jnp.array((1, 2, 3)) + b = jnp.array((4, 5, 6)) + result = bm.column_stack((a, b)) + self.assertTrue(jnp.all(result == jnp.column_stack((a, b)))) + + q1 = [1, 2, 3] * U.second + q2 = [4, 5, 6] * U.second + result_q = bm.column_stack((q1, q2)) + expected_q = jnp.column_stack((jnp.array([1, 2, 3]), jnp.array([4, 5, 6]))) + assert_quantity(result_q, expected_q, U.second) + + def test_split(self): + array = jnp.arange(9) + result = bm.split(array, 3) + expected = jnp.split(array, 3) + for r, e in zip(result, expected): + self.assertTrue(jnp.all(r == e)) + + q = jnp.arange(9) * U.second + result_q = bm.split(q, 3) + expected_q = jnp.split(jnp.arange(9), 3) + for r, e in zip(result_q, expected_q): + assert_quantity(r, e, ms) + + def test_dsplit(self): + array = jnp.arange(16.0).reshape(2, 2, 4) + result = bm.dsplit(array, 2) + expected = jnp.dsplit(array, 2) + for r, e in zip(result, expected): + self.assertTrue(jnp.all(r == e)) + + q = jnp.arange(16.0).reshape(2, 2, 4) * U.second + result_q = bm.dsplit(q, 2) + expected_q = jnp.dsplit(jnp.arange(16.0).reshape(2, 2, 4), 2) + for r, e in zip(result_q, expected_q): + assert_quantity(r, e, U.second) + + def test_hsplit(self): + array = jnp.arange(16.0).reshape(4, 4) + result = bm.hsplit(array, 2) + expected = jnp.hsplit(array, 2) + for r, e in zip(result, expected): + self.assertTrue(jnp.all(r == e)) + + q = jnp.arange(16.0).reshape(4, 4) * U.second + result_q = bm.hsplit(q, 2) + expected_q = jnp.hsplit(jnp.arange(16.0).reshape(4, 4), 2) + for r, e in zip(result_q, expected_q): + assert_quantity(r, e, U.second) + + def test_vsplit(self): + array = jnp.arange(16.0).reshape(4, 4) + result = bm.vsplit(array, 2) + expected = jnp.vsplit(array, 2) + for r, e in zip(result, expected): + self.assertTrue(jnp.all(r == e)) + + q = jnp.arange(16.0).reshape(4, 4) * U.second + result_q = bm.vsplit(q, 2) + expected_q = jnp.vsplit(jnp.arange(16.0).reshape(4, 4), 2) + for r, e in zip(result_q, expected_q): + assert_quantity(r, e, U.second) + + def test_tile(self): + array = jnp.array([0, 1, 2]) + result = bm.tile(array, 2) + self.assertTrue(jnp.all(result == jnp.tile(array, 2))) + + q = jnp.array([0, 1, 2]) * U.second + result_q = bm.tile(q, 2) + expected_q = jnp.tile(jnp.array([0, 1, 2]), 2) + assert_quantity(result_q, expected_q, U.second) + + def test_repeat(self): + array = jnp.array([0, 1, 2]) + result = bm.repeat(array, 2) + self.assertTrue(jnp.all(result == jnp.repeat(array, 2))) + + q = [0, 1, 2] * U.second + result_q = bm.repeat(q, 2) + expected_q = jnp.repeat(jnp.array([0, 1, 2]), 2) + assert_quantity(result_q, expected_q, U.second) + + def test_unique(self): + array = jnp.array([0, 1, 2, 1, 0]) + result = bm.unique(array) + self.assertTrue(jnp.all(result == jnp.unique(array))) + + q = [0, 1, 2, 1, 0] * U.second + result_q = bm.unique(q) + expected_q = jnp.unique(jnp.array([0, 1, 2, 1, 0])) + assert_quantity(result_q, expected_q, U.second) + + def test_append(self): + array = jnp.array([0, 1, 2]) + result = bm.append(array, 3) + self.assertTrue(jnp.all(result == jnp.append(array, 3))) + + q = [0, 1, 2] * U.second + result_q = bm.append(q, 3) + expected_q = jnp.append(jnp.array([0, 1, 2]), 3) + assert_quantity(result_q, expected_q, U.second) + + def test_flip(self): + array = jnp.array([0, 1, 2]) + result = bm.flip(array) + self.assertTrue(jnp.all(result == jnp.flip(array))) + + q = [0, 1, 2] * U.second + result_q = bm.flip(q) + expected_q = jnp.flip(jnp.array([0, 1, 2])) + assert_quantity(result_q, expected_q, U.second) + + def test_fliplr(self): + array = jnp.array([[0, 1, 2], [3, 4, 5]]) + result = bm.fliplr(array) + self.assertTrue(jnp.all(result == jnp.fliplr(array))) + + q = [[0, 1, 2], [3, 4, 5]] * U.second + result_q = bm.fliplr(q) + expected_q = jnp.fliplr(jnp.array([[0, 1, 2], [3, 4, 5]])) + assert_quantity(result_q, expected_q, U.second) + + def test_flipud(self): + array = jnp.array([[0, 1, 2], [3, 4, 5]]) + result = bm.flipud(array) + self.assertTrue(jnp.all(result == jnp.flipud(array))) + + q = [[0, 1, 2], [3, 4, 5]] * U.second + result_q = bm.flipud(q) + expected_q = jnp.flipud(jnp.array([[0, 1, 2], [3, 4, 5]])) + assert_quantity(result_q, expected_q, ms) + + def test_roll(self): + array = jnp.array([0, 1, 2]) + result = bm.roll(array, 1) + self.assertTrue(jnp.all(result == jnp.roll(array, 1))) + + q = [0, 1, 2] * U.second + result_q = bm.roll(q, 1) + expected_q = jnp.roll(jnp.array([0, 1, 2]), 1) + assert_quantity(result_q, expected_q, ms) + + def test_atleast_1d(self): + array = jnp.array(0) + result = bm.atleast_1d(array) + self.assertTrue(jnp.all(result == jnp.atleast_1d(array))) + + q = 0 * U.second + result_q = bm.atleast_1d(q) + expected_q = jnp.atleast_1d(jnp.array(0)) + assert_quantity(result_q, expected_q, U.second) + + def test_atleast_2d(self): + array = jnp.array([0, 1, 2]) + result = bm.atleast_2d(array) + self.assertTrue(jnp.all(result == jnp.atleast_2d(array))) + + q = [0, 1, 2] * U.second + result_q = bm.atleast_2d(q) + expected_q = jnp.atleast_2d(jnp.array([0, 1, 2])) + assert_quantity(result_q, expected_q, U.second) + + def test_atleast_3d(self): + array = jnp.array([[0, 1, 2], [3, 4, 5]]) + result = bm.atleast_3d(array) + self.assertTrue(jnp.all(result == jnp.atleast_3d(array))) + + q = [[0, 1, 2], [3, 4, 5]] * U.second + result_q = bm.atleast_3d(q) + expected_q = jnp.atleast_3d(jnp.array([[0, 1, 2], [3, 4, 5]])) + assert_quantity(result_q, expected_q, U.second) + + def test_expand_dims(self): + array = jnp.array([1, 2, 3]) + result = bm.expand_dims(array, axis=0) + self.assertTrue(jnp.all(result == jnp.expand_dims(array, axis=0))) + + q = [1, 2, 3] * U.second + result_q = bm.expand_dims(q, axis=0) + expected_q = jnp.expand_dims(jnp.array([1, 2, 3]), axis=0) + assert_quantity(result_q, expected_q, U.second) + + def test_squeeze(self): + array = jnp.array([[[0], [1], [2]]]) + result = bm.squeeze(array) + self.assertTrue(jnp.all(result == jnp.squeeze(array))) + + q = [[[0], [1], [2]]] * U.second + result_q = bm.squeeze(q) + expected_q = jnp.squeeze(jnp.array([[[0], [1], [2]]])) + assert_quantity(result_q, expected_q, U.second) + + def test_sort(self): + array = jnp.array([2, 3, 1]) + result = bm.sort(array) + self.assertTrue(jnp.all(result == jnp.sort(array))) + + q = [2, 3, 1] * U.second + result_q = bm.sort(q) + expected_q = jnp.sort(jnp.array([2, 3, 1])) + assert_quantity(result_q, expected_q, U.second) + + def test_max(self): + array = jnp.array([1, 2, 3]) + result = bm.max(array) + self.assertTrue(result == jnp.max(array)) + + q = [1, 2, 3] * U.second + result_q = bm.max(q) + expected_q = jnp.max(jnp.array([1, 2, 3])) + assert_quantity(result_q, expected_q, U.second) + + def test_min(self): + array = jnp.array([1, 2, 3]) + result = bm.min(array) + self.assertTrue(result == jnp.min(array)) + + q = [1, 2, 3] * U.second + result_q = bm.min(q) + expected_q = jnp.min(jnp.array([1, 2, 3])) + assert_quantity(result_q, expected_q, U.second) + + def test_amin(self): + array = jnp.array([1, 2, 3]) + result = bm.amin(array) + self.assertTrue(result == jnp.min(array)) + + q = [1, 2, 3] * U.second + result_q = bm.amin(q) + expected_q = jnp.min(jnp.array([1, 2, 3])) + assert_quantity(result_q, expected_q, U.second) + + def test_amax(self): + array = jnp.array([1, 2, 3]) + result = bm.amax(array) + self.assertTrue(result == jnp.max(array)) + + q = [1, 2, 3] * U.second + result_q = bm.amax(q) + expected_q = jnp.max(jnp.array([1, 2, 3])) + assert_quantity(result_q, expected_q, U.second) + + def test_block(self): + array = jnp.array([[1, 2], [3, 4]]) + result = bm.block(array) + self.assertTrue(jnp.all(result == jnp.block(array))) + + q = [[1, 2], [3, 4]] * U.second + result_q = bm.block(q) + expected_q = jnp.block(jnp.array([[1, 2], [3, 4]])) + assert_quantity(result_q, expected_q, U.second) + + def test_compress(self): + array = jnp.array([1, 2, 3, 4]) + result = bm.compress(jnp.array([0, 1, 1, 0]), array) + self.assertTrue(jnp.all(result == jnp.compress(jnp.array([0, 1, 1, 0]), array))) + + q = [1, 2, 3, 4] * U.second + a = [0, 1, 1, 0] * U.second + result_q = bm.compress(q, a) + expected_q = jnp.compress(jnp.array([1, 2, 3, 4]), jnp.array([0, 1, 1, 0])) + assert_quantity(result_q, expected_q, U.second) + + def test_diagflat(self): + array = jnp.array([1, 2, 3]) + result = bm.diagflat(array) + self.assertTrue(jnp.all(result == jnp.diagflat(array))) + + q = [1, 2, 3] * U.second + result_q = bm.diagflat(q) + expected_q = jnp.diagflat(jnp.array([1, 2, 3])) + assert_quantity(result_q, expected_q, U.second) + + def test_diagonal(self): + array = jnp.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) + result = bm.diagonal(array) + self.assertTrue(jnp.all(result == jnp.diagonal(array))) + + q = [[0, 1, 2], [3, 4, 5], [6, 7, 8]] * U.second + result_q = bm.diagonal(q) + expected_q = jnp.diagonal(jnp.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]])) + assert_quantity(result_q, expected_q, U.second) + + def test_choose(self): + choices = [jnp.array([1, 2, 3]), jnp.array([4, 5, 6]), jnp.array([7, 8, 9])] + result = bm.choose(jnp.array([0, 1, 2]), choices) + self.assertTrue(jnp.all(result == jnp.choose(jnp.array([0, 1, 2]), choices))) + + q = [0, 1, 2] * U.second + q = q.astype(jnp.int64) + result_q = bm.choose(q, choices) + expected_q = jnp.choose(jnp.array([0, 1, 2]), choices) + assert_quantity(result_q, expected_q, U.second) + + def test_ravel(self): + array = jnp.array([[1, 2, 3], [4, 5, 6]]) + result = bm.ravel(array) + self.assertTrue(jnp.all(result == jnp.ravel(array))) + + q = [[1, 2, 3], [4, 5, 6]] * U.second + result_q = bm.ravel(q) + expected_q = jnp.ravel(jnp.array([[1, 2, 3], [4, 5, 6]])) + assert_quantity(result_q, expected_q, U.second) + + # return_quantity = False + def test_argsort(self): + array = jnp.array([2, 3, 1]) + result = bm.argsort(array) + self.assertTrue(jnp.all(result == jnp.argsort(array))) + + q = [2, 3, 1] * U.second + result_q = bm.argsort(q) + expected_q = jnp.argsort(jnp.array([2, 3, 1])) + assert jnp.all(result_q == expected_q) + + def test_argmax(self): + array = jnp.array([2, 3, 1]) + result = bm.argmax(array) + self.assertTrue(result == jnp.argmax(array)) + + q = [2, 3, 1] * U.second + result_q = bm.argmax(q) + expected_q = jnp.argmax(jnp.array([2, 3, 1])) + assert result_q == expected_q + + def test_argmin(self): + array = jnp.array([2, 3, 1]) + result = bm.argmin(array) + self.assertTrue(result == jnp.argmin(array)) + + q = [2, 3, 1] * U.second + result_q = bm.argmin(q) + expected_q = jnp.argmin(jnp.array([2, 3, 1])) + assert result_q == expected_q + + def test_argwhere(self): + array = jnp.array([0, 1, 2]) + result = bm.argwhere(array) + self.assertTrue(jnp.all(result == jnp.argwhere(array))) + + q = [0, 1, 2] * U.second + result_q = bm.argwhere(q) + expected_q = jnp.argwhere(jnp.array([0, 1, 2])) + assert jnp.all(result_q == expected_q) + + def test_nonzero(self): + array = jnp.array([0, 1, 2]) + result = bm.nonzero(array) + expected = jnp.nonzero(array) + for r, e in zip(result, expected): + self.assertTrue(jnp.array_equal(r, e)) + + q = [0, 1, 2] * U.second + result_q = bm.nonzero(q) + expected_q = jnp.nonzero(jnp.array([0, 1, 2])) + for r, e in zip(result_q, expected_q): + assert jnp.all(r == e) + + def test_flatnonzero(self): + array = jnp.array([0, 1, 2]) + result = bm.flatnonzero(array) + self.assertTrue(jnp.all(result == jnp.flatnonzero(array))) + + q = [0, 1, 2] * U.second + result_q = bm.flatnonzero(q) + expected_q = jnp.flatnonzero(jnp.array([0, 1, 2])) + assert jnp.all(result_q == expected_q) + + def test_searchsorted(self): + array = jnp.array([1, 2, 3]) + result = bm.searchsorted(array, 2) + self.assertTrue(result == jnp.searchsorted(array, 2)) + + q = [0, 1, 2] * U.second + result_q = bm.searchsorted(q, 2) + expected_q = jnp.searchsorted(jnp.array([0, 1, 2]), 2) + assert result_q == expected_q + + def test_extract(self): + array = jnp.array([1, 2, 3]) + result = bm.extract(array > 1, array) + self.assertTrue(jnp.all(result == jnp.extract(array > 1, array))) + + q = [1, 2, 3] * U.second + a = array * U.second + result_q = bm.extract(q > 1 * U.second, a) + expected_q = jnp.extract(jnp.array([0, 1, 2]), jnp.array([1, 2, 3])) + assert jnp.all(result_q == expected_q) + + def test_count_nonzero(self): + array = jnp.array([1, 0, 2, 0, 3, 0]) + result = bm.count_nonzero(array) + self.assertTrue(result == jnp.count_nonzero(array)) + + q = [1, 0, 2, 0, 3, 0] * U.second + result_q = bm.count_nonzero(q) + expected_q = jnp.count_nonzero(jnp.array([1, 0, 2, 0, 3, 0])) + assert result_q == expected_q + + +class TestElementwiseBitOperationsUnary(unittest.TestCase): + def test_bitwise_not(self): + result = bm.bitwise_not(jnp.array([0b1100])) + self.assertTrue(jnp.all(result == jnp.bitwise_not(jnp.array([0b1100])))) + + with pytest.raises(ValueError): + q = [0b1100] * U.second + result_q = bm.bitwise_not(q) + + def test_invert(self): + result = bm.invert(jnp.array([0b1100])) + self.assertTrue(jnp.all(result == jnp.invert(jnp.array([0b1100])))) + + with pytest.raises(ValueError): + q = [0b1100] * U.second + result_q = bm.invert(q) + + def test_left_shift(self): + result = bm.left_shift(jnp.array([0b0100]), 2) + self.assertTrue(jnp.all(result == jnp.left_shift(jnp.array([0b0100]), 2))) + + with pytest.raises(ValueError): + q = [0b0100] * U.second + result_q = bm.left_shift(q, 2) + + def test_right_shift(self): + result = bm.right_shift(jnp.array([0b0100]), 2) + self.assertTrue(jnp.all(result == jnp.right_shift(jnp.array([0b0100]), 2))) + + with pytest.raises(ValueError): + q = [0b0100] * U.second + result_q = bm.right_shift(q, 2) + + +class TestElementwiseBitOperationsBinary(unittest.TestCase): + + def test_bitwise_and(self): + result = bm.bitwise_and(jnp.array([0b1100]), jnp.array([0b1010])) + self.assertTrue(jnp.all(result == jnp.bitwise_and(jnp.array([0b1100]), jnp.array([0b1010])))) + + with pytest.raises(ValueError): + q1 = [0b1100] * U.second + q2 = [0b1010] * U.second + result_q = bm.bitwise_and(q1, q2) + + def test_bitwise_or(self): + result = bm.bitwise_or(jnp.array([0b1100]), jnp.array([0b1010])) + self.assertTrue(jnp.all(result == jnp.bitwise_or(jnp.array([0b1100]), jnp.array([0b1010])))) + + with pytest.raises(ValueError): + q1 = [0b1100] * U.second + q2 = [0b1010] * U.second + result_q = bm.bitwise_or(q1, q2) + + def test_bitwise_xor(self): + result = bm.bitwise_xor(jnp.array([0b1100]), jnp.array([0b1010])) + self.assertTrue(jnp.all(result == jnp.bitwise_xor(jnp.array([0b1100]), jnp.array([0b1010])))) + + with pytest.raises(ValueError): + q1 = [0b1100] * U.second + q2 = [0b1010] * U.second + result_q = bm.bitwise_xor(q1, q2) + + +class TestLogicFuncsUnary(unittest.TestCase): + def test_all(self): + result = bm.all(jnp.array([True, True, True])) + self.assertTrue(result == jnp.all(jnp.array([True, True, True]))) + + with pytest.raises(ValueError): + q = [True, True, True] * U.second + result_q = bm.all(q) + + def test_any(self): + result = bm.any(jnp.array([False, True, False])) + self.assertTrue(result == jnp.any(jnp.array([False, True, False]))) + + with pytest.raises(ValueError): + q = [False, True, False] * U.second + result_q = bm.any(q) + + def test_logical_not(self): + result = bm.logical_not(jnp.array([True, False])) + self.assertTrue(jnp.all(result == jnp.logical_not(jnp.array([True, False])))) + + with pytest.raises(ValueError): + q = [True, False] * U.second + result_q = bm.logical_not(q) + + +class TestLogicFuncsBinary(unittest.TestCase): + + def test_equal(self): + result = bm.equal(jnp.array([1, 2, 3]), jnp.array([1, 2, 3])) + self.assertTrue(jnp.all(result == jnp.equal(jnp.array([1, 2, 3]), jnp.array([1, 2, 3])))) + + q1 = [1, 2, 3] * U.second + q2 = [2, 3, 4] * U.second + result_q = bm.equal(q1, q2) + expected_q = jnp.equal(jnp.array([1, 2, 3]), jnp.array([2, 3, 4])) + assert_quantity(result_q, expected_q, None) + + with pytest.raises(DimensionMismatchError): + q1 = [1, 2, 3] * U.second + q2 = [1, 2, 4] * U.volt + result_q = bm.equal(q1, q2) + + def test_not_equal(self): + result = bm.not_equal(jnp.array([1, 2, 3]), jnp.array([1, 2, 4])) + self.assertTrue(jnp.all(result == jnp.not_equal(jnp.array([1, 2, 3]), jnp.array([1, 2, 4])))) + + q1 = [1, 2, 3] * U.second + q2 = [2, 3, 4] * U.second + result_q = bm.not_equal(q1, q2) + expected_q = jnp.not_equal(jnp.array([1, 2, 3]), jnp.array([2, 3, 4])) + assert_quantity(result_q, expected_q, None) + + def test_greater(self): + result = bm.greater(jnp.array([1, 2, 3]), jnp.array([0, 2, 4])) + self.assertTrue(jnp.all(result == jnp.greater(jnp.array([1, 2, 3]), jnp.array([0, 2, 4])))) + + q1 = [1, 2, 3] * U.second + q2 = [2, 3, 4] * U.second + result_q = bm.greater(q1, q2) + expected_q = jnp.greater(jnp.array([1, 2, 3]), jnp.array([2, 3, 4])) + assert_quantity(result_q, expected_q, None) + + def test_greater_equal(self): + result = bm.greater_equal(jnp.array([1, 2, 3]), jnp.array([1, 2, 2])) + self.assertTrue(jnp.all(result == jnp.greater_equal(jnp.array([1, 2, 3]), jnp.array([1, 2, 2])))) + + q1 = [1, 2, 3] * U.second + q2 = [2, 3, 4] * U.second + result_q = bm.greater_equal(q1, q2) + expected_q = jnp.greater_equal(jnp.array([1, 2, 3]), jnp.array([2, 3, 4])) + assert_quantity(result_q, expected_q, None) + + def test_less(self): + result = bm.less(jnp.array([1, 2, 3]), jnp.array([2, 2, 2])) + self.assertTrue(jnp.all(result == jnp.less(jnp.array([1, 2, 3]), jnp.array([2, 2, 2])))) + + q1 = [1, 2, 3] * U.second + q2 = [2, 3, 4] * U.second + result_q = bm.less(q1, q2) + expected_q = jnp.less(jnp.array([1, 2, 3]), jnp.array([2, 3, 4])) + assert_quantity(result_q, expected_q, None) + + def test_less_equal(self): + result = bm.less_equal(jnp.array([1, 2, 3]), jnp.array([2, 2, 2])) + self.assertTrue(jnp.all(result == jnp.less_equal(jnp.array([1, 2, 3]), jnp.array([2, 2, 2])))) + + q1 = [1, 2, 3] * U.second + q2 = [2, 3, 4] * U.second + result_q = bm.less_equal(q1, q2) + expected_q = jnp.less_equal(jnp.array([1, 2, 3]), jnp.array([2, 3, 4])) + assert_quantity(result_q, expected_q, None) + + def test_array_equal(self): + result = bm.array_equal(jnp.array([1, 2, 3]), jnp.array([1, 2, 3])) + self.assertTrue(result == jnp.array_equal(jnp.array([1, 2, 3]), jnp.array([1, 2, 3]))) + + q1 = [1, 2, 3] * U.second + q2 = [2, 3, 4] * U.second + result_q = bm.array_equal(q1, q2) + expected_q = jnp.array_equal(jnp.array([1, 2, 3]), jnp.array([2, 3, 4])) + assert_quantity(result_q, expected_q, None) + + def test_isclose(self): + result = bm.isclose(jnp.array([1.0, 2.0]), jnp.array([1.0, 2.1]), atol=0.2) + self.assertTrue(jnp.all(result == jnp.isclose(jnp.array([1.0, 2.0]), jnp.array([1.0, 2.1]), atol=0.2))) + + q1 = [1.0, 2.0] * U.second + q2 = [2.0, 3.0] * U.second + result_q = bm.isclose(q1, q2, atol=0.2) + expected_q = jnp.isclose(jnp.array([1.0, 2.0]), jnp.array([2.0, 3.0]), atol=0.2) + assert_quantity(result_q, expected_q, None) + + def test_allclose(self): + result = bm.allclose(jnp.array([1.0, 2.0]), jnp.array([1.0, 2.1]), atol=0.2) + self.assertTrue(result == jnp.allclose(jnp.array([1.0, 2.0]), jnp.array([1.0, 2.1]), atol=0.2)) + + q1 = [1.0, 2.0] * U.second + q2 = [2.0, 3.0] * U.second + result_q = bm.allclose(q1, q2, atol=0.2) + expected_q = jnp.allclose(jnp.array([1.0, 2.0]), jnp.array([2.0, 3.0]), atol=0.2) + assert_quantity(result_q, expected_q, None) + + def test_logical_and(self): + result = bm.logical_and(jnp.array([True, False]), jnp.array([False, False])) + self.assertTrue(jnp.all(result == jnp.logical_and(jnp.array([True, False]), jnp.array([False, False])))) + + q1 = [True, False] * U.second + q2 = [False, False] * U.second + result_q = bm.logical_and(q1, q2) + expected_q = jnp.logical_and(jnp.array([True, False]), jnp.array([False, False])) + assert_quantity(result_q, expected_q, None) + + def test_logical_or(self): + result = bm.logical_or(jnp.array([True, False]), jnp.array([False, False])) + self.assertTrue(jnp.all(result == jnp.logical_or(jnp.array([True, False]), jnp.array([False, False])))) + + q1 = [True, False] * U.second + q2 = [False, False] * U.second + result_q = bm.logical_or(q1, q2) + expected_q = jnp.logical_or(jnp.array([True, False]), jnp.array([False, False])) + assert_quantity(result_q, expected_q, None) + + def test_logical_xor(self): + result = bm.logical_xor(jnp.array([True, False]), jnp.array([False, False])) + self.assertTrue(jnp.all(result == jnp.logical_xor(jnp.array([True, False]), jnp.array([False, False])))) + + q1 = [True, False] * U.second + q2 = [False, False] * U.second + result_q = bm.logical_xor(q1, q2) + expected_q = jnp.logical_xor(jnp.array([True, False]), jnp.array([False, False])) + assert_quantity(result_q, expected_q, None) + + +class TestIndexingFuncs(unittest.TestCase): + + def test_where(self): + array = jnp.array([1, 2, 3, 4, 5]) + result = bm.where(array > 2, array, 0) + self.assertTrue(jnp.all(result == jnp.where(array > 2, array, 0))) + + q = [1, 2, 3, 4, 5] * U.second + result_q = bm.where(q > 2 * U.second, q, 0) + expected_q = jnp.where(jnp.array([1, 2, 3, 4, 5]) > 2, jnp.array([1, 2, 3, 4, 5]), 0) + assert_quantity(result_q, expected_q, U.second) + + def test_tril_indices(self): + result = bm.tril_indices(3) + expected = jnp.tril_indices(3) + for i in range(2): + self.assertTrue(jnp.all(result[i] == expected[i])) + + def test_tril_indices_from(self): + array = jnp.ones((3, 3)) + result = bm.tril_indices_from(array) + expected = jnp.tril_indices_from(array) + for i in range(2): + self.assertTrue(jnp.all(result[i] == expected[i])) + + def test_triu_indices(self): + result = bm.triu_indices(3) + expected = jnp.triu_indices(3) + for i in range(2): + self.assertTrue(jnp.all(result[i] == expected[i])) + + def test_triu_indices_from(self): + array = jnp.ones((3, 3)) + result = bm.triu_indices_from(array) + expected = jnp.triu_indices_from(array) + for i in range(2): + self.assertTrue(jnp.all(result[i] == expected[i])) + + def test_take(self): + array = jnp.array([4, 3, 5, 7, 6, 8]) + indices = jnp.array([0, 1, 4]) + result = bm.take(array, indices) + self.assertTrue(jnp.all(result == jnp.take(array, indices))) + + q = [4, 3, 5, 7, 6, 8] * U.second + i = jnp.array([0, 1, 4]) + result_q = bm.take(q, i) + expected_q = jnp.take(jnp.array([4, 3, 5, 7, 6, 8]), jnp.array([0, 1, 4])) + assert_quantity(result_q, expected_q, U.second) + + def test_select(self): + condlist = [jnp.array([True, False, True]), jnp.array([False, True, False])] + choicelist = [jnp.array([1, 2, 3]), jnp.array([4, 5, 6])] + result = bm.select(condlist, choicelist, default=0) + self.assertTrue(jnp.all(result == jnp.select(condlist, choicelist, default=0))) + + c = [jnp.array([True, False, True]), jnp.array([False, True, False])] + ch = [[1, 2, 3] * U.second, [4, 5, 6] * U.second] + result_q = bm.select(c, ch, default=0) + expected_q = jnp.select([jnp.array([True, False, True]), jnp.array([False, True, False])], + [jnp.array([1, 2, 3]), jnp.array([4, 5, 6])], default=0) + assert_quantity(result_q, expected_q, U.second) + + +class TestWindowFuncs(unittest.TestCase): + + def test_bartlett(self): + result = bm.bartlett(5) + self.assertTrue(jnp.all(result == jnp.bartlett(5))) + + def test_blackman(self): + result = bm.blackman(5) + self.assertTrue(jnp.all(result == jnp.blackman(5))) + + def test_hamming(self): + result = bm.hamming(5) + self.assertTrue(jnp.all(result == jnp.hamming(5))) + + def test_hanning(self): + result = bm.hanning(5) + self.assertTrue(jnp.all(result == jnp.hanning(5))) + + def test_kaiser(self): + result = bm.kaiser(5, 0.5) + self.assertTrue(jnp.all(result == jnp.kaiser(5, 0.5))) + + +class TestConstants(unittest.TestCase): + + def test_constants(self): + self.assertTrue(bm.e == jnp.e) + self.assertTrue(bm.pi == jnp.pi) + self.assertTrue(bm.inf == jnp.inf) + + +class TestLinearAlgebra(unittest.TestCase): + + def test_dot(self): + a = jnp.array([1, 2]) + b = jnp.array([3, 4]) + result = bm.dot(a, b) + self.assertTrue(result == jnp.dot(a, b)) + + q1 = [1, 2] * U.second + q2 = [3, 4] * U.volt + result_q = bm.dot(q1, q2) + expected_q = jnp.dot(jnp.array([1, 2]), jnp.array([3, 4])) + assert_quantity(result_q, expected_q, U.second * U.volt) + + def test_vdot(self): + a = jnp.array([1, 2]) + b = jnp.array([3, 4]) + result = bm.vdot(a, b) + self.assertTrue(result == jnp.vdot(a, b)) + + q1 = [1, 2] * U.second + q2 = [3, 4] * U.volt + result_q = bm.vdot(q1, q2) + expected_q = jnp.vdot(jnp.array([1, 2]), jnp.array([3, 4])) + assert_quantity(result_q, expected_q, U.second * U.volt) + + def test_inner(self): + a = jnp.array([1, 2]) + b = jnp.array([3, 4]) + result = bm.inner(a, b) + self.assertTrue(result == jnp.inner(a, b)) + + q1 = [1, 2] * U.second + q2 = [3, 4] * U.volt + result_q = bm.inner(q1, q2) + expected_q = jnp.inner(jnp.array([1, 2]), jnp.array([3, 4])) + assert_quantity(result_q, expected_q, U.second * U.volt) + + def test_outer(self): + a = jnp.array([1, 2]) + b = jnp.array([3, 4]) + result = bm.outer(a, b) + self.assertTrue(jnp.all(result == jnp.outer(a, b))) + + q1 = [1, 2] * U.second + q2 = [3, 4] * U.volt + result_q = bm.outer(q1, q2) + expected_q = jnp.outer(jnp.array([1, 2]), jnp.array([3, 4])) + assert_quantity(result_q, expected_q, U.second * U.volt) + + def test_kron(self): + a = jnp.array([1, 2]) + b = jnp.array([3, 4]) + result = bm.kron(a, b) + self.assertTrue(jnp.all(result == jnp.kron(a, b))) + + q1 = [1, 2] * U.second + q2 = [3, 4] * U.volt + result_q = bm.kron(q1, q2) + expected_q = jnp.kron(jnp.array([1, 2]), jnp.array([3, 4])) + assert_quantity(result_q, expected_q, U.second * U.volt) + + def test_matmul(self): + a = jnp.array([[1, 2], [3, 4]]) + b = jnp.array([[5, 6], [7, 8]]) + result = bm.matmul(a, b) + self.assertTrue(jnp.all(result == jnp.matmul(a, b))) + + q1 = [[1, 2], [3, 4]] * U.second + q2 = [[5, 6], [7, 8]] * U.volt + result_q = bm.matmul(q1, q2) + expected_q = jnp.matmul(jnp.array([[1, 2], [3, 4]]), jnp.array([[5, 6], [7, 8]])) + assert_quantity(result_q, expected_q, U.second * U.volt) + + def test_trace(self): + a = jnp.array([[1, 2], [3, 4]]) + result = bm.trace(a) + self.assertTrue(result == jnp.trace(a)) + + q = [[1, 2], [3, 4]] * U.second + result_q = bm.trace(q) + expected_q = jnp.trace(jnp.array([[1, 2], [3, 4]])) + assert_quantity(result_q, expected_q, U.second) + + +class TestDataTypes(unittest.TestCase): + + def test_dtype(self): + array = jnp.array([1, 2, 3]) + result = bm.dtype(array) + self.assertTrue(result == jnp.dtype(array)) + + q = [1, 2, 3] * U.second + q = q.astype(jnp.int64) + result_q = bm.dtype(q) + expected_q = jnp.dtype(jnp.array([1, 2, 3], dtype=jnp.int64)) + self.assertTrue(result_q == expected_q) + + def test_finfo(self): + result = bm.finfo(jnp.float32) + self.assertTrue(result == jnp.finfo(jnp.float32)) + + q = 1 * U.second + q = q.astype(jnp.float64) + result_q = bm.finfo(q) + expected_q = jnp.finfo(jnp.float64) + self.assertTrue(result_q == expected_q) + + def test_iinfo(self): + result = bm.iinfo(jnp.int32) + expected = jnp.iinfo(jnp.int32) + self.assertEqual(result.min, expected.min) + self.assertEqual(result.max, expected.max) + self.assertEqual(result.dtype, expected.dtype) + + q = 1 * U.second + q = q.astype(jnp.int32) + result_q = bm.iinfo(q) + expected_q = jnp.iinfo(jnp.int32) + self.assertEqual(result_q.min, expected_q.min) + self.assertEqual(result_q.max, expected_q.max) + self.assertEqual(result_q.dtype, expected_q.dtype) + + +class TestMore(unittest.TestCase): + def test_broadcast_arrays(self): + a = jnp.array([1, 2, 3]) + b = jnp.array([[4], [5]]) + result = bm.broadcast_arrays(a, b) + self.assertTrue(jnp.all(result[0] == jnp.broadcast_arrays(a, b)[0])) + self.assertTrue(jnp.all(result[1] == jnp.broadcast_arrays(a, b)[1])) + + q1 = [1, 2, 3] * U.second + q2 = [[4], [5]] * U.second + result_q = bm.broadcast_arrays(q1, q2) + expected_q = jnp.broadcast_arrays(jnp.array([1, 2, 3]), jnp.array([[4], [5]])) + assert_quantity(result_q, expected_q, U.second) + + def test_broadcast_shapes(self): + a = jnp.array([1, 2, 3]) + b = jnp.array([[4], [5]]) + result = bm.broadcast_shapes(a.shape, b.shape) + self.assertTrue(result == jnp.broadcast_shapes(a.shape, b.shape)) + + def test_einsum(self): + a = jnp.array([1, 2, 3]) + b = jnp.array([4, 5]) + result = bm.einsum('i,j->ij', a, b) + self.assertTrue(jnp.all(result == jnp.einsum('i,j->ij', a, b))) + + q1 = [1, 2, 3] * U.second + q2 = [4, 5] * U.volt + result_q = bm.einsum('i,j->ij', q1, q2) + expected_q = jnp.einsum('i,j->ij', jnp.array([1, 2, 3]), jnp.array([4, 5])) + assert_quantity(result_q, expected_q, U.second * U.volt) + + q1 = [1, 2, 3] * U.second + q2 = [1, 2, 3] * U.second + result_q = bm.einsum('i,i->i', q1, q2) + expected_q = jnp.einsum('i,i->i', jnp.array([1, 2, 3]), jnp.array([1, 2, 3])) + assert_quantity(result_q, expected_q, U.second) + + def test_gradient(self): + f = jnp.array([1, 2, 4, 7, 11, 16], dtype=float) + result = bm.gradient(f) + self.assertTrue(jnp.all(bm.allclose(result, jnp.gradient(f)))) + + q = [1, 2, 4, 7, 11, 16] * U.second + result_q = bm.gradient(q) + expected_q = jnp.gradient(jnp.array([1, 2, 4, 7, 11, 16])) + assert_quantity(result_q, expected_q, U.second) + + q1 = jnp.array([[1, 2, 6], [3, 4, 5]]) * U.second + dx = 2. * U.meter + # y = [1., 1.5, 3.5] * U.second + result_q = bm.gradient(q1, dx) + expected_q = jnp.gradient(jnp.array([[1, 2, 6], [3, 4, 5]]), 2.) + assert_quantity(result_q[0], expected_q[0], U.second / U.meter) + assert_quantity(result_q[1], expected_q[1], U.second / U.meter) + + def test_intersect1d(self): + a = jnp.array([1, 2, 3, 4, 5]) + b = jnp.array([3, 4, 5, 6, 7]) + result = bm.intersect1d(a, b) + self.assertTrue(jnp.all(result == jnp.intersect1d(a, b))) + + q1 = [1, 2, 3, 4, 5] * U.second + q2 = [3, 4, 5, 6, 7] * U.second + result_q = bm.intersect1d(q1, q2) + expected_q = jnp.intersect1d(jnp.array([1, 2, 3, 4, 5]), jnp.array([3, 4, 5, 6, 7])) + assert_quantity(result_q, expected_q, U.second) + + def test_nan_to_num(self): + a = jnp.array([1, 2, 3, 4, jnp.nan]) + result = bm.nan_to_num(a) + self.assertTrue(jnp.all(result == jnp.nan_to_num(a))) + + q = [1, 2, 3, 4, jnp.nan] * U.second + result_q = bm.nan_to_num(q) + expected_q = jnp.nan_to_num(jnp.array([1, 2, 3, 4, jnp.nan])) + assert_quantity(result_q, expected_q, U.second) + + def nanargmax(self): + a = jnp.array([1, 2, 3, 4, jnp.nan]) + result = bm.nanargmax(a) + self.assertTrue(result == jnp.nanargmax(a)) + + q = [1, 2, 3, 4, jnp.nan] * U.second + result_q = bm.nanargmax(q) + expected_q = jnp.nanargmax(jnp.array([1, 2, 3, 4, jnp.nan])) + self.assertTrue(result_q == expected_q) + + def nanargmin(self): + a = jnp.array([1, 2, 3, 4, jnp.nan]) + result = bm.nanargmin(a) + self.assertTrue(result == jnp.nanargmin(a)) + + q = [1, 2, 3, 4, jnp.nan] * U.second + result_q = bm.nanargmin(q) + expected_q = jnp.nanargmin(jnp.array([1, 2, 3, 4, jnp.nan])) + self.assertTrue(result_q == expected_q) + + def test_rot90(self): + a = jnp.array([[1, 2], [3, 4]]) + result = bm.rot90(a) + self.assertTrue(jnp.all(result == jnp.rot90(a))) + + q = [[1, 2], [3, 4]] * U.second + result_q = bm.rot90(q) + expected_q = jnp.rot90(jnp.array([[1, 2], [3, 4]])) + assert_quantity(result_q, expected_q, U.second) + + def test_tensordot(self): + a = jnp.array([[1, 2], [3, 4]]) + b = jnp.array([[1, 2], [3, 4]]) + result = bm.tensordot(a, b) + self.assertTrue(jnp.all(result == jnp.tensordot(a, b))) + + q1 = [[1, 2], [3, 4]] * U.second + q2 = [[1, 2], [3, 4]] * U.second + result_q = bm.tensordot(q1, q2) + expected_q = jnp.tensordot(jnp.array([[1, 2], [3, 4]]), jnp.array([[1, 2], [3, 4]])) + assert_quantity(result_q, expected_q, U.second ** 2) diff --git a/brainunit/math/_utils.py b/brainunit/math/_utils.py new file mode 100644 index 0000000..7b5f0f2 --- /dev/null +++ b/brainunit/math/_utils.py @@ -0,0 +1,88 @@ +import functools +from typing import Callable + +from jax.tree_util import tree_map + +from brainunit import Quantity + + +def _as_jax_array_(obj): + return obj.value if isinstance(obj, Quantity) else obj + + +def _is_leaf(a): + return isinstance(a, Quantity) + + +def _compatible_with_quantity( + fun: Callable, + return_quantity: bool = True, + module: str = '' +): + func_to_wrap = fun.__np_wrapped__ if hasattr(fun, '__np_wrapped__') else fun + + @functools.wraps(func_to_wrap) + def new_fun(*args, **kwargs): + unit = None + if isinstance(args[0], Quantity): + unit = args[0].unit + elif isinstance(args[0], tuple): + if len(args[0]) == 1: + unit = args[0][0].unit if isinstance(args[0][0], Quantity) else None + elif len(args[0]) == 2: + # check all args[0] have the same unit + if all(isinstance(a, Quantity) for a in args[0]): + if all(a.unit == args[0][0].unit for a in args[0]): + unit = args[0][0].unit + else: + raise ValueError(f'Units do not match for {fun.__name__} operation.') + elif all(not isinstance(a, Quantity) for a in args[0]): + unit = None + else: + raise ValueError(f'Units do not match for {fun.__name__} operation.') + args = tree_map(_as_jax_array_, args, is_leaf=_is_leaf) + out = None + if len(kwargs): + # compatible with PyTorch syntax + if 'dim' in kwargs: + kwargs['axis'] = kwargs.pop('dim') + if 'keepdim' in kwargs: + kwargs['keepdims'] = kwargs.pop('keepdim') + # compatible with TensorFlow syntax + if 'keep_dims' in kwargs: + kwargs['keepdims'] = kwargs.pop('keep_dims') + # compatible with NumPy/PyTorch syntax + if 'out' in kwargs: + out = kwargs.pop('out') + if not isinstance(out, Quantity): + raise TypeError(f'"out" must be an instance of brainpy Array. While we got {type(out)}') + # format + kwargs = tree_map(_as_jax_array_, kwargs, is_leaf=_is_leaf) + + if not return_quantity: + unit = None + + r = fun(*args, **kwargs) + if unit is not None: + if isinstance(r, (list, tuple)): + return [Quantity(rr, unit=unit) for rr in r] + else: + if out is None: + return Quantity(r, unit=unit) + else: + out.value = r + if out is None: + return r + else: + out.value = r + + new_fun.__doc__ = ( + f'Similar to ``jax.numpy.{module + fun.__name__}`` function, ' + f'while it is compatible with brainpy Array/Variable. \n\n' + f'Note that this function is also compatible with:\n\n' + f'1. NumPy or PyTorch syntax when receiving ``out`` argument.\n' + f'2. PyTorch syntax when receiving ``keepdim`` or ``dim`` argument.\n' + f'3. TensorFlow syntax when receiving ``keep_dims`` argument.' + ) + + return new_fun diff --git a/docs/apis/functional.rst b/docs/apis/functional.rst index 705a08d..ee1bcf6 100644 --- a/docs/apis/functional.rst +++ b/docs/apis/functional.rst @@ -1,76 +1,76 @@ -``braintools.functional`` module -================================ - -.. currentmodule:: braintools.functional -.. automodule:: braintools.functional - -Activation Functions --------------------- - -.. autosummary:: - :toctree: generated/ - :nosignatures: - :template: classtemplate.rst - - tanh - relu - squareplus - softplus - soft_sign - sigmoid - silu - swish - log_sigmoid - elu - leaky_relu - hard_tanh - celu - selu - gelu - glu - logsumexp - log_softmax - softmax - standardize - one_hot - relu6 - hard_sigmoid - hard_silu - hard_swish - hard_shrink - rrelu - mish - soft_shrink - prelu - tanh_shrink - softmin - - -Normalization -------------- - -.. autosummary:: - :toctree: generated/ - :nosignatures: - :template: classtemplate.rst - - weight_standardization - - -Spike Operations ----------------- - -.. autosummary:: - :toctree: generated/ - :nosignatures: - :template: classtemplate.rst - - spike_bitwise_or - spike_bitwise_and - spike_bitwise_iand - spike_bitwise_not - spike_bitwise_xor - spike_bitwise_ixor - spike_bitwise - - +``brainunit.functional`` module +================================ + +.. currentmodule:: brainunit.functional +.. automodule:: brainunit.functional + +Activation Functions +-------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + tanh + relu + squareplus + softplus + soft_sign + sigmoid + silu + swish + log_sigmoid + elu + leaky_relu + hard_tanh + celu + selu + gelu + glu + logsumexp + log_softmax + softmax + standardize + one_hot + relu6 + hard_sigmoid + hard_silu + hard_swish + hard_shrink + rrelu + mish + soft_shrink + prelu + tanh_shrink + softmin + + +Normalization +------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + weight_standardization + + +Spike Operations +---------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + spike_bitwise_or + spike_bitwise_and + spike_bitwise_iand + spike_bitwise_not + spike_bitwise_xor + spike_bitwise_ixor + spike_bitwise + + diff --git a/docs/apis/init.rst b/docs/apis/init.rst index 5c52c1e..065a0a9 100644 --- a/docs/apis/init.rst +++ b/docs/apis/init.rst @@ -1,43 +1,29 @@ -``braintools.init`` module -========================== - -.. currentmodule:: braintools.init -.. automodule:: braintools.init - - -Initializers ------------- - -.. autosummary:: - :toctree: generated/ - :nosignatures: - :template: classtemplate.rst - - Initializer - ZeroInit - Constant - Identity - Normal - TruncatedNormal - Uniform - VarianceScaling - KaimingUniform - KaimingNormal - XavierUniform - XavierNormal - LecunUniform - LecunNormal - Orthogonal - DeltaOrthogonal - - -Initialization Helpers ----------------------- - -.. autosummary:: - :toctree: generated/ - - parameter - state - noise - to_size \ No newline at end of file +``brainunit.init`` module +========================== + +.. currentmodule:: brainunit.init +.. automodule:: brainunit.init + +.. autosummary:: + :toctree: generated/ + + parameter + state + noise + to_size + Initializer + ZeroInit + Constant + Identity + Normal + TruncatedNormal + Uniform + VarianceScaling + KaimingUniform + KaimingNormal + XavierUniform + XavierNormal + LecunUniform + LecunNormal + Orthogonal + DeltaOrthogonal diff --git a/docs/apis/metric.rst b/docs/apis/metric.rst index 33cfc52..bcda0d1 100644 --- a/docs/apis/metric.rst +++ b/docs/apis/metric.rst @@ -1,121 +1,121 @@ -``braintools.metric`` module -============================ - -.. currentmodule:: braintools.metric -.. automodule:: braintools.metric - -Classification Losses ---------------------- - -.. autosummary:: - :toctree: generated/ - :nosignatures: - :template: classtemplate.rst - - sigmoid_binary_cross_entropy - hinge_loss - perceptron_loss - softmax_cross_entropy - softmax_cross_entropy_with_integer_labels - multiclass_hinge_loss - multiclass_perceptron_loss - poly_loss_cross_entropy - kl_divergence - kl_divergence_with_log_targets - convex_kl_divergence - ctc_loss - ctc_loss_with_forward_probs - sigmoid_focal_loss - - -Correlation ------------ - -.. autosummary:: - :toctree: generated/ - :nosignatures: - :template: classtemplate.rst - - cross_correlation - voltage_fluctuation - matrix_correlation - weighted_correlation - functional_connectivity - functional_connectivity_dynamics - - -Fenchel-Young Loss ------------------- - -.. autosummary:: - :toctree: generated/ - :nosignatures: - :template: classtemplate.rst - - make_fenchel_young_loss - - -Spike Firing ------------- - -.. autosummary:: - :toctree: generated/ - :nosignatures: - :template: classtemplate.rst - - raster_plot - firing_rate - - -Local Field Potential ---------------------- - -.. autosummary:: - :toctree: generated/ - :nosignatures: - :template: classtemplate.rst - - unitary_LFP - - -Ranking Losses --------------- - -.. autosummary:: - :toctree: generated/ - :nosignatures: - :template: classtemplate.rst - - ranking_softmax_loss - - -Regression Losses ------------------ - -.. autosummary:: - :toctree: generated/ - :nosignatures: - :template: classtemplate.rst - - squared_error - absolute_error - l1_loss - l2_loss - l2_norm - huber_loss - log_cosh - cosine_similarity - cosine_distance - - -Smoothing Losses ----------------- - -.. autosummary:: - :toctree: generated/ - :nosignatures: - :template: classtemplate.rst - - smooth_labels - - +``brainunit.metric`` module +============================ + +.. currentmodule:: brainunit.metric +.. automodule:: brainunit.metric + +Classification Losses +--------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + sigmoid_binary_cross_entropy + hinge_loss + perceptron_loss + softmax_cross_entropy + softmax_cross_entropy_with_integer_labels + multiclass_hinge_loss + multiclass_perceptron_loss + poly_loss_cross_entropy + kl_divergence + kl_divergence_with_log_targets + convex_kl_divergence + ctc_loss + ctc_loss_with_forward_probs + sigmoid_focal_loss + + +Correlation +----------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + cross_correlation + voltage_fluctuation + matrix_correlation + weighted_correlation + functional_connectivity + functional_connectivity_dynamics + + +Fenchel-Young Loss +------------------ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + make_fenchel_young_loss + + +Spike Firing +------------ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + raster_plot + firing_rate + + +Local Field Potential +--------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + unitary_LFP + + +Ranking Losses +-------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ranking_softmax_loss + + +Regression Losses +----------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + squared_error + absolute_error + l1_loss + l2_loss + l2_norm + huber_loss + log_cosh + cosine_similarity + cosine_distance + + +Smoothing Losses +---------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + smooth_labels + + diff --git a/docs/apis/optim.rst b/docs/apis/optim.rst index c73cab8..31c7a53 100644 --- a/docs/apis/optim.rst +++ b/docs/apis/optim.rst @@ -1,50 +1,50 @@ -``braintools.optim`` module -=========================== - -.. currentmodule:: braintools.optim -.. automodule:: braintools.optim - -SGD Optimizers --------------- - -.. autosummary:: - :toctree: generated/ - :nosignatures: - :template: classtemplate.rst - - to_same_dict_tree - OptimState - Optimizer - SGD - Momentum - MomentumNesterov - Adagrad - Adadelta - RMSProp - Adam - LARS - Adan - AdamW - - -Learning Rate Schedulers ------------------------- - -.. autosummary:: - :toctree: generated/ - :nosignatures: - :template: classtemplate.rst - - LearningRateScheduler - ConstantLR - StepLR - MultiStepLR - CosineAnnealingLR - CosineAnnealingWarmRestarts - ExponentialLR - ExponentialDecayLR - InverseTimeDecayLR - PolynomialDecayLR - PiecewiseConstantLR - - +``brainunit.optim`` module +=========================== + +.. currentmodule:: brainunit.optim +.. automodule:: brainunit.optim + +SGD Optimizers +-------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + to_same_dict_tree + OptimState + Optimizer + SGD + Momentum + MomentumNesterov + Adagrad + Adadelta + RMSProp + Adam + LARS + Adan + AdamW + + +Learning Rate Schedulers +------------------------ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + LearningRateScheduler + ConstantLR + StepLR + MultiStepLR + CosineAnnealingLR + CosineAnnealingWarmRestarts + ExponentialLR + ExponentialDecayLR + InverseTimeDecayLR + PolynomialDecayLR + PiecewiseConstantLR + + diff --git a/docs/auto_generater.py b/docs/auto_generater.py index 79b515e..b3220c7 100644 --- a/docs/auto_generater.py +++ b/docs/auto_generater.py @@ -319,9 +319,9 @@ def _section(header, numpy_mod, brainpy_mod, jax_mod, klass=None, is_jax=False): def main(): os.makedirs('apis/auto/', exist_ok=True) - _write_module(module_name='braintools.init', + _write_module(module_name='brainunit.init', filename='apis/init.rst', - header='``braintools.init`` module') + header='``brainunit.init`` module') module_and_name = [ ('_classification', 'Classification Losses'), @@ -333,9 +333,9 @@ def main(): ('_regression', 'Regression Losses'), ('_smoothing', 'Smoothing Losses'), ] - _write_submodules(module_name='braintools.metric', + _write_submodules(module_name='brainunit.metric', filename='apis/metric.rst', - header='``braintools.metric`` module', + header='``brainunit.metric`` module', submodule_names=[k[0] for k in module_and_name], section_names=[k[1] for k in module_and_name]) @@ -344,9 +344,9 @@ def main(): ('_normalization', 'Normalization'), ('_spikes', 'Spike Operations'), ] - _write_submodules(module_name='braintools.functional', + _write_submodules(module_name='brainunit.functional', filename='apis/functional.rst', - header='``braintools.functional`` module', + header='``brainunit.functional`` module', submodule_names=[k[0] for k in module_and_name], section_names=[k[1] for k in module_and_name]) @@ -354,9 +354,9 @@ def main(): ('_sgd_optimizer', 'SGD Optimizers'), ('_lr_scheduler', 'Learning Rate Schedulers'), ] - _write_submodules(module_name='braintools.optim', + _write_submodules(module_name='brainunit.optim', filename='apis/optim.rst', - header='``braintools.optim`` module', + header='``brainunit.optim`` module', submodule_names=[k[0] for k in module_and_name], section_names=[k[1] for k in module_and_name]) diff --git a/docs/index.rst b/docs/index.rst index 4e4b932..d8dbd9c 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -57,5 +57,6 @@ See also the BDP ecosystem :hidden: :maxdepth: 2 + tutorials/physical_units.rst api.rst diff --git a/docs/tutorials.rst b/docs/tutorials.rst new file mode 100644 index 0000000..ed16e77 --- /dev/null +++ b/docs/tutorials.rst @@ -0,0 +1,7 @@ +Tutorials Documentation +======================= + +.. toctree:: + :maxdepth: 2 + + tutorials/physical_units \ No newline at end of file diff --git a/docs/tutorials/physical_units.ipynb b/docs/tutorials/physical_units.ipynb new file mode 100644 index 0000000..ea89738 --- /dev/null +++ b/docs/tutorials/physical_units.ipynb @@ -0,0 +1,2996 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Physical Units" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Braincore includes a system for physical units. The base units are defined by their standard SI unit names:\n", + "`amp`/`ampere`, `kilogram`/`kilogramme`, `second`, `metre`/`meter`, `kilogram`, `mole`/`mol`, `kelvin`, and `candela`. In addition to these base units, braincore defines a set of derived units: `coulomb`, `farad`, `gram`/`gramme`, `hertz`, `joule`, `liter`/\n", + "`litre`, `molar`, `pascal`, `ohm`, `siemens`, `volt`, `watt`,\n", + "together with prefixed versions (e.g. `msiemens = 0.001*siemens`) using the\n", + "prefixes `p, n, u, m, k, M, G, T` (two exceptions to this rule: `kilogram`\n", + "is not defined with any additional prefixes, and `metre` and `meter` are\n", + "additionaly defined with the \"centi\" prefix, i.e. `cmetre`/`cmeter`).\n", + "For convenience, a couple of additional useful standard abbreviations such as\n", + "`cm` (instead of `cmetre`/`cmeter`), `nS` (instead of `nsiemens`),\n", + "`ms` (instead of `msecond`), `Hz` (instead of `hertz`), `mM`\n", + "(instead of `mmolar`) are included. To avoid clashes with common variable\n", + "names, no one-letter abbreviations are provided (e.g. you can use `mV` or\n", + "`nS`, but *not* `V` or `S`)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Importing units\n", + "Braincore generates standard names for units, combining the unit name (e.g. “siemens”) with a prefixes (e.g. “m”), and also generates squared and cubed versions by appending a number. For example, the units “msiemens”, “siemens2”, “usiemens3” are all predefined. You can import these units from the package `briancore.units` – accordingly, an `from braincore.units import *` will result in everything being imported.\n", + "\n", + "We recommend importing only the units you need, to have a cleaner namespace. For example, `import braincore.units as U` and then using `U.msiemens` instead of `msiemens`." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\13107\\.conda\\envs\\brainpy-dev\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "import braincore.units as U" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Using units\n", + "You can generate a physical quantity by multiplying a scalar or ndarray with its physical unit:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "20. * msecond" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tau = 20 * U.ms\n", + "tau" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ArrayImpl([10., 20., 30.], dtype=float32) * hertz" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rates = [10, 20, 30] * U.Hz\n", + "rates" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ArrayImpl([[10., 20., 30.],\n", + " [20., 30., 40.]], dtype=float32) * hertz" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rates = [[10, 20, 30], [20, 30, 40]] * U.Hz\n", + "rates" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Braincore will check the consistency of operations on units and raise an error for dimensionality mismatches:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Cannot calculate ... += 1, units do not match (units are s and 1).\n" + ] + } + ], + "source": [ + "try:\n", + " tau += 1 # ms? second?\n", + "except Exception as e:\n", + " print(e)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Cannot calculate 3.0 + 3.0, units do not match (units are kg and A).\n" + ] + } + ], + "source": [ + "try:\n", + " 3 * U.kgram + 3 * U.amp \n", + "except Exception as e:\n", + " print(e)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Basics\n", + "Numpy functions have been overwritten to correctly work with units.\n", + "\n", + "The important attributes of a `Quantity` object are:\n", + "- `value`: the numerical value of the quantity\n", + "- `unit`: the unit of the quantity\n", + "- `ndim`: the number of dimensions of quantity's value\n", + "- `shape`: the shape of the quantity's value\n", + "- `size`: the size of the quantity's value\n", + "- `dtype`: the dtype of the quantity's value" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### An example" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ArrayImpl([[10., 20., 30.],\n", + " [20., 30., 40.]], dtype=float32) * hertz" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rates" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([[10., 20., 30.],\n", + " [20., 30., 40.]], dtype=float32)" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rates.value" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "second ** -1" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rates.unit" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(2, (2, 3), 6, dtype('float32'))" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rates.ndim, rates.shape, rates.size, rates.dtype" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Quantity Creation\n", + "Creating a Quantity object can be accomplished in several ways, categorized based on the type of input used. Here, we present the methods grouped by their input types and characteristics for better clarity." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "import jax.numpy as jnp\n", + "import braincore as bc\n", + "bc.environ.set(precision=64) # we recommend using 64-bit precision for better numerical stability" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "#### Scalar and Array Multiplication\n", + "- Multiplying a Scalar with a Unit\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "5. * msecond" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "5 * U.ms" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- Multiplying a Jax nunmpy value type with a Unit:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "5. * msecond" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "jnp.float64(5) * U.ms" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- Multiplying a Jax numpy array with a Unit:" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ArrayImpl([1., 2., 3.]) * msecond" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "jnp.array([1, 2, 3]) * U.ms" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- Multiplying a List with a Unit:" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ArrayImpl([1., 2., 3.]) * msecond" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "[1, 2, 3] * U.ms" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Direct Quantity Creation\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "from braincore.units import Quantity" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- Creating a Quantity Directly with a Value" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Quantity(5.)" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Quantity(5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- Creating a Quantity Directly with a Value and Unit" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "5. * second" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Quantity(5, unit=U.ms.unit)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- Creating a Quantity with a Jax numpy Array of Values and a Unit" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ArrayImpl([1., 2., 3.]) * second" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Quantity(jnp.array([1, 2, 3]), unit=U.ms.unit)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- Creating a Quantity with a List of Values and a Unit" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ArrayImpl([1., 2., 3.]) * second" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Quantity([1, 2, 3], unit=U.ms.unit)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- Creating a Quantity with a List of Quantities" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ArrayImpl([0.5, 1. ]) * second" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Quantity([500 * U.ms, 1 * U.second])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- Using the with_units Method" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ArrayImpl([0.5, 1. ]) * second" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Quantity.with_units(jnp.array([0.5, 1]), second=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Unitless Quantity\n", + "Quantities can be unitless, which means they have no units. If there is no unit provided, the quantity is assumed to be unitless. The following are examples of creating unitless quantities:" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Quantity(ArrayImpl([1., 2., 3.]))" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Quantity([1, 2, 3])" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Quantity(ArrayImpl([1., 2., 3.]))" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Quantity(jnp.array([1, 2, 3]))" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Quantity(ArrayImpl([], dtype=float64))" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Quantity([])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Illegal Quantity Creation\n", + "The following are examples of illegal quantity creation:" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "All elements must have the same unit\n" + ] + } + ], + "source": [ + "try:\n", + " Quantity([500 * U.ms, 1])\n", + "except Exception as e:\n", + " print(e)" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Value 'some' with dtype , >=) are supported:" + ] + }, + { + "cell_type": "code", + "execution_count": 112, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(ArrayImpl([10., 12., 14., 16., 18.]) * mvolt,\n", + " ArrayImpl([ 8., 12., 16., 20., 24.]) * mvolt)" + ] + }, + "execution_count": 112, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "q1 = jnp.arange(10, 20, 2) * U.mV\n", + "q2 = jnp.arange(8, 27, 4) * U.mV\n", + "q1, q2" + ] + }, + { + "cell_type": "code", + "execution_count": 116, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(Array([False, True, False, False, False], dtype=bool),\n", + " Array([ True, False, True, True, True], dtype=bool))" + ] + }, + "execution_count": 116, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "q1 == q2, q1 != q2" + ] + }, + { + "cell_type": "code", + "execution_count": 117, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(Array([False, False, True, True, True], dtype=bool),\n", + " Array([False, True, True, True, True], dtype=bool))" + ] + }, + "execution_count": 117, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "q1 < q2, q1 <= q2" + ] + }, + { + "cell_type": "code", + "execution_count": 118, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(Array([ True, False, False, False, False], dtype=bool),\n", + " Array([ True, True, False, False, False], dtype=bool))" + ] + }, + "execution_count": 118, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "q1 > q2, q1 >= q2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Binary Operations\n", + "The binary operations add (+), subtract (-), multiply (*), divide (/), floor divide (//), remainder (%), divmod (divmod), power (**), matmul (@), shift (<<, >>), round(round) are supported:" + ] + }, + { + "cell_type": "code", + "execution_count": 134, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(ArrayImpl([1., 2., 3.]) * mvolt, ArrayImpl([2., 3., 4.]) * mvolt)" + ] + }, + "execution_count": 134, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "q1 = jnp.array([1, 2, 3]) * U.mV\n", + "q2 = jnp.array([2, 3, 4]) * U.mV\n", + "q1, q2" + ] + }, + { + "cell_type": "code", + "execution_count": 135, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(ArrayImpl([3., 5., 7.]) * mvolt, ArrayImpl([-1., -1., -1.]) * mvolt)" + ] + }, + "execution_count": 135, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "q1 + q2, q1 - q2" + ] + }, + { + "cell_type": "code", + "execution_count": 136, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ArrayImpl([ 2., 6., 12.]) * mvolt2" + ] + }, + "execution_count": 136, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "q1 * q2" + ] + }, + { + "cell_type": "code", + "execution_count": 137, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(Array([0.5 , 0.66666667, 0.75 ], dtype=float64),\n", + " Array([0., 0., 0.], dtype=float64),\n", + " ArrayImpl([1., 2., 3.]) * mvolt)" + ] + }, + "execution_count": 137, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "q1 / q2, q1 // q2, q1 % q2" + ] + }, + { + "cell_type": "code", + "execution_count": 138, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(Array([0., 0., 0.], dtype=float64), ArrayImpl([1., 2., 3.]) * mvolt)" + ] + }, + "execution_count": 138, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "divmod(q1, q2)" + ] + }, + { + "cell_type": "code", + "execution_count": 139, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ArrayImpl([1., 4., 9.]) * mvolt2" + ] + }, + "execution_count": 139, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "q1 ** 2" + ] + }, + { + "cell_type": "code", + "execution_count": 140, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "20. * mvolt2" + ] + }, + "execution_count": 140, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "q1 @ q2" + ] + }, + { + "cell_type": "code", + "execution_count": 132, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(ArrayImpl([ 0., 4., 8., 12., 16.]) * volt,\n", + " ArrayImpl([0., 0., 0., 0., 1.]) * volt)" + ] + }, + "execution_count": 132, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "q1 = Quantity(jnp.arange(5, dtype=jnp.int32), unit=U.mV.unit)\n", + "q1 << 2, q1 >> 2" + ] + }, + { + "cell_type": "code", + "execution_count": 151, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "round(80.23456, 2) : 80.23 mV\n", + "round(100.000056, 3) : 100. mV\n", + "round(-100.000056, 3) : -100. mV\n" + ] + } + ], + "source": [ + "q1 = 80.23456 * U.mV\n", + "q2 = 100.000056 * U.mV\n", + "q3 = -100.000056 * U.mV\n", + "print(\"round(80.23456, 2) : \", q1.round(5))\n", + "print(\"round(100.000056, 3) : \", q2.round(6))\n", + "print(\"round(-100.000056, 3) : \", q3.round(6))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Shape Manipulation\n", + "The shape of an array can be changed with various commands. Note that the following three commands all return a modified array, but do not change the original array:" + ] + }, + { + "cell_type": "code", + "execution_count": 152, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ArrayImpl([[1., 2.],\n", + " [3., 4.]]) * mvolt" + ] + }, + "execution_count": 152, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "q = [[1, 2], [3, 4]] * U.mV\n", + "q" + ] + }, + { + "cell_type": "code", + "execution_count": 153, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ArrayImpl([1., 2., 3., 4.]) * mvolt" + ] + }, + "execution_count": 153, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "q.flatten()" + ] + }, + { + "cell_type": "code", + "execution_count": 154, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ArrayImpl([[1., 3.],\n", + " [2., 4.]]) * mvolt" + ] + }, + "execution_count": 154, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "q.swapaxes(0, 1)" + ] + }, + { + "cell_type": "code", + "execution_count": 156, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ArrayImpl([1., 3.]) * mvolt" + ] + }, + "execution_count": 156, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "q.take(jnp.array([0, 2]))" + ] + }, + { + "cell_type": "code", + "execution_count": 157, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ArrayImpl([[1., 3.],\n", + " [2., 4.]]) * mvolt" + ] + }, + "execution_count": 157, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "q.transpose()" + ] + }, + { + "cell_type": "code", + "execution_count": 158, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ArrayImpl([[1., 2., 1., 2.],\n", + " [3., 4., 3., 4.]]) * mvolt" + ] + }, + "execution_count": 158, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "q.tile(2)" + ] + }, + { + "cell_type": "code", + "execution_count": 159, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ArrayImpl([[[1., 2.],\n", + " [3., 4.]]]) * mvolt" + ] + }, + "execution_count": 159, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "q.unsqueeze(0)" + ] + }, + { + "cell_type": "code", + "execution_count": 162, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ArrayImpl([[[1., 2.],\n", + " [3., 4.]]]) * mvolt" + ] + }, + "execution_count": 162, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "q.expand_dims(0)" + ] + }, + { + "cell_type": "code", + "execution_count": 163, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ArrayImpl([[[1., 2.],\n", + " [3., 4.]]]) * mvolt" + ] + }, + "execution_count": 163, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "expand_as_shape = (1, 2, 2)\n", + "q.expand_as(jnp.zeros(expand_as_shape).shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 173, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ArrayImpl([[ 1., 30.],\n", + " [10., 4.]]) * mvolt" + ] + }, + "execution_count": 173, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "q_put = [[1, 2], [3, 4]] * U.mV\n", + "q_put.put([[1, 0], [0, 1]], [10, 30] * U.mV)\n", + "q_put" + ] + }, + { + "cell_type": "code", + "execution_count": 174, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ArrayImpl([[1., 2.],\n", + " [3., 4.]]) * mvolt" + ] + }, + "execution_count": 174, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "q_squeeze = [[1, 2], [3, 4]] * U.mV\n", + "q_squeeze.squeeze()" + ] + }, + { + "cell_type": "code", + "execution_count": 175, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[ArrayImpl([[1., 2.]]) * mvolt, ArrayImpl([[3., 4.]]) * mvolt]" + ] + }, + "execution_count": 175, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "q_spilt = [[1, 2], [3, 4]] * U.mV\n", + "q_spilt.split(2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Numpy Methods\n", + "All methods that make sense on quantities should work, i.e. they check for the correct units of their arguments and return quantities with units were appropriate.\n", + "\n", + "These methods defined at `braincore.math`, so you can use them by importing `import braincore.math as bm` and then using `bm.method_name`.\n", + "#### Functions that remove unit\n", + "- all\n", + "- any\n", + "- nonzero\n", + "- argmax\n", + "- argmin\n", + "- argsort\n", + "- ones_like\n", + "- zeros_like\n", + "\n", + "#### Functions that keep unit\n", + "- round\n", + "- std\n", + "- sum\n", + "- trace\n", + "- cumsum\n", + "- diagonal\n", + "- max\n", + "- mean\n", + "- min\n", + "- ptp\n", + "- ravel\n", + "- absolute\n", + "- rint\n", + "- negative\n", + "- positive\n", + "- conj\n", + "- conjugate\n", + "- floor\n", + "- ceil\n", + "- trunc\n", + "\n", + "#### Functions that change unit\n", + "- var\n", + "- multiply\n", + "- divide\n", + "- true_divide\n", + "- floor_divide\n", + "- dot\n", + "- matmul\n", + "- sqrt\n", + "- square\n", + "- reciprocal\n", + "\n", + "#### Functions that need to match unit\n", + "- add\n", + "- subtract\n", + "- maximum\n", + "- minimum\n", + "- remainder\n", + "- mod\n", + "- fmod\n", + "\n", + "#### Functions that only work with unitless quantities\n", + "- sin\n", + "- sinh\n", + "- arcsinh\n", + "- cos\n", + "- cosh\n", + "- arccos\n", + "- arccosh\n", + "- tan\n", + "- tanh\n", + "- arctan\n", + "- arctanh\n", + "- log\n", + "- log10\n", + "- exp\n", + "- expm1\n", + "- log1p\n", + "\n", + "#### Functions that compare quantities\n", + "- less\n", + "- less_equal\n", + "- greater\n", + "- greater_equal\n", + "- equal\n", + "- not_equal\n", + "\n", + "#### Functions that work on all quantities and return boolean arrays(Logical operations)\n", + "- logical_and\n", + "- logical_or\n", + "- logical_xor\n", + "- logical_not\n", + "- isreal\n", + "- iscomplex\n", + "- isfinite\n", + "- isinf\n", + "- isnan" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/docs/tutorials/physical_units.rst b/docs/tutorials/physical_units.rst new file mode 100644 index 0000000..3cf79e4 --- /dev/null +++ b/docs/tutorials/physical_units.rst @@ -0,0 +1,7 @@ +Physical Units +============== + +.. toctree:: + :maxdepth: 1 + + physical_units.ipynb \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 0577db5..c0e6453 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,9 +5,9 @@ build-backend = "setuptools.build_meta" [tool.setuptools.packages.find] exclude = [ - "docs*", "tests*", "examples*", "build*", "dist*", - "braintools.egg-info*", "braintools/__pycache__*", - "braintools/__init__.py" + "docs*", "build*", "dist*", + "brainunit.egg-info*", "brainunit/__pycache__*", + "brainunit/__init__.py" ] @@ -16,8 +16,8 @@ universal = true [project] -name = "braintools" -description = "The Toolbox for Brain Dynamics Programming." +name = "brainunit" +description = "A Unit-aware System for Brain Dynamics Programming." readme = 'README.md' license = { text = 'Apache-2.0 license' } requires-python = '>=3.9' @@ -54,11 +54,11 @@ dependencies = [ dynamic = ['version'] [tool.flit.module] -name = "braintools" +name = "brainunit" [project.urls] -homepage = 'http://github.com/brainpy/braintools' -repository = 'http://github.com/brainpy/braintools' +homepage = 'http://github.com/brainpy/brainunit' +repository = 'http://github.com/brainpy/brainunit' [project.optional-dependencies] testing = [ diff --git a/setup.py b/setup.py index ab4d550..751d0c3 100644 --- a/setup.py +++ b/setup.py @@ -29,6 +29,7 @@ with open(os.path.join(here, 'brainunit/', '__init__.py'), 'r') as f: init_py = f.read() version = re.search('__version__ = "(.*)"', init_py).groups()[0] +print(version) if len(sys.argv) > 2 and sys.argv[2] == '--python-tag=py3': version = version else: @@ -40,7 +41,7 @@ # installation packages packages = find_packages( - exclude=["docs*", "tests*", "examples*", "build*", + exclude=["docs*", "build*", "dist*", "brainunit.egg-info*", "brainunit/__pycache__*"] )