diff --git a/README.md b/README.md index ee962c1..c8c4bf8 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@

-[``brainunit``](https://github.com/brainpy/brainunit) provides common toolboxes for brain dynamics programming (BDP). +[``brainunit``](https://github.com/brainpy/brainunit) provides a unit-aware mathematical system for brain dynamics programming (BDP). ## Installation diff --git a/brainunit/_base.py b/brainunit/_base.py index 5aac5f7..577e281 100644 --- a/brainunit/_base.py +++ b/brainunit/_base.py @@ -22,16 +22,11 @@ from contextlib import contextmanager from typing import Union, Optional, Sequence, Callable, Tuple, Any, List -import brainstate as bst import jax import jax.numpy as jnp import numpy as np -from jax.tree_util import register_pytree_node_class from jax.interpreters.partial_eval import DynamicJaxprTracer - -from ._misc import get_dtype - - +from jax.tree_util import register_pytree_node_class __all__ = [ 'Quantity', @@ -755,7 +750,7 @@ def in_best_unit(x, precision=None): def array_with_unit( floatval, unit: Dimension, - dtype: bst.typing.DTypeLike = None + dtype: jax.typing.DTypeLike = None ) -> 'Quantity': """ Create a new `Array` with the given dimensions. Calls @@ -961,7 +956,7 @@ class Quantity(object): def __init__( self, value: Any, - dtype: Optional[bst.typing.DTypeLike] = None, + dtype: Optional[jax.typing.DTypeLike] = None, dim: Dimension = DIMENSIONLESS, unit: Optional['Unit'] = None, ): @@ -987,17 +982,14 @@ def __init__( # array value if isinstance(value, Quantity): - dtype = dtype or get_dtype(value) self._dim = value.dim self._value = jnp.array(value.value, dtype=dtype) return elif isinstance(value, (np.ndarray, jax.Array)): - dtype = dtype or get_dtype(value) value = jnp.array(value, dtype=dtype) elif isinstance(value, (jnp.number, numbers.Number)): - dtype = dtype or get_dtype(value) value = jnp.array(value, dtype=dtype) elif isinstance(value, (jax.core.ShapedArray, jax.ShapeDtypeStruct)): @@ -1279,7 +1271,20 @@ def _check_tracer(self): @property def dtype(self): """Variable dtype.""" - return get_dtype(self._value) + a = self._value + if hasattr(a, 'dtype'): + return a.dtype + else: + if isinstance(a, bool): + return bool + elif isinstance(a, int): + return jax.dtypes.canonicalize_dtype(int) + elif isinstance(a, float): + return jax.dtypes.canonicalize_dtype(float) + elif isinstance(a, complex): + return jax.dtypes.canonicalize_dtype(complex) + else: + raise TypeError(f'Can not get dtype of {a}.') @property def shape(self) -> Tuple[int, ...]: @@ -2480,7 +2485,7 @@ def __init__( name: str = None, dispname: str = None, iscompound: bool = None, - dtype: bst.typing.DTypeLike = None, + dtype: jax.typing.DTypeLike = None, ): if dim is None: dim = DIMENSIONLESS diff --git a/brainunit/_misc.py b/brainunit/_misc.py deleted file mode 100644 index f5b18da..0000000 --- a/brainunit/_misc.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - - -import brainstate as bst - - -def get_dtype(a): - """ - Get the dtype of a. - """ - if hasattr(a, 'dtype'): - return a.dtype - else: - if isinstance(a, bool): - return bool - elif isinstance(a, int): - return bst.environ.ditype() - elif isinstance(a, float): - return bst.environ.dftype() - elif isinstance(a, complex): - return bst.environ.dctype() - else: - raise TypeError(f'Can not get dtype of {a}.') - - diff --git a/brainunit/_unit_test.py b/brainunit/_unit_test.py index 3a8773b..60c6dd3 100644 --- a/brainunit/_unit_test.py +++ b/brainunit/_unit_test.py @@ -16,13 +16,12 @@ import itertools import warnings +import brainstate as bst import jax.numpy as jnp import numpy as np import pytest from numpy.testing import assert_equal -import brainstate as bst - array = np.array bst.environ.set(precision=64) diff --git a/brainunit/math/_compat_numpy_array_creation.py b/brainunit/math/_compat_numpy_array_creation.py index 1e31f4d..0146ada 100644 --- a/brainunit/math/_compat_numpy_array_creation.py +++ b/brainunit/math/_compat_numpy_array_creation.py @@ -12,23 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== + from collections.abc import Sequence -from functools import wraps -from typing import (Callable, Union, Optional, Any) +from typing import (Union, Optional, Any) -import brainstate as bst import jax import jax.numpy as jnp import numpy as np from brainstate._utils import set_module_as from jax import Array -from .._base import (DIMENSIONLESS, - Quantity, - Unit, - fail_for_dimension_mismatch, - is_unitless, - ) +from .._base import ( + DIMENSIONLESS, + Quantity, + Unit, + fail_for_dimension_mismatch, + is_unitless, +) __all__ = [ # array creation @@ -231,10 +231,12 @@ def zeros( @set_module_as('brainunit.math') -def full_like(a: Union[Quantity, bst.typing.ArrayLike], - fill_value: Union[Quantity, bst.typing.ArrayLike], - dtype: Optional[bst.typing.DTypeLike] = None, - shape: Any = None) -> Union[Quantity, jax.Array]: +def full_like( + a: Union[Quantity, jax.typing.ArrayLike], + fill_value: Union[Quantity, jax.typing.ArrayLike], + dtype: Optional[jax.typing.DTypeLike] = None, + shape: Any = None +) -> Union[Quantity, jax.Array]: ''' Return a Quantity if `a` and `fill_value` are Quantities that have the same unit or only `fill_value` is a Quantity. else return an array of `a` filled with `fill_value`. @@ -262,7 +264,7 @@ def full_like(a: Union[Quantity, bst.typing.ArrayLike], @set_module_as('brainunit.math') -def diag(a: Union[Quantity, bst.typing.ArrayLike], +def diag(a: Union[Quantity, jax.typing.ArrayLike], k: int = 0, unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]: ''' @@ -292,7 +294,7 @@ def diag(a: Union[Quantity, bst.typing.ArrayLike], @set_module_as('brainunit.math') -def tril(a: Union[Quantity, bst.typing.ArrayLike], +def tril(a: Union[Quantity, jax.typing.ArrayLike], k: int = 0, unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]: ''' @@ -322,7 +324,7 @@ def tril(a: Union[Quantity, bst.typing.ArrayLike], @set_module_as('brainunit.math') -def triu(a: Union[Quantity, bst.typing.ArrayLike], +def triu(a: Union[Quantity, jax.typing.ArrayLike], k: int = 0, unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]: ''' @@ -352,8 +354,8 @@ def triu(a: Union[Quantity, bst.typing.ArrayLike], @set_module_as('brainunit.math') -def empty_like(a: Union[Quantity, bst.typing.ArrayLike], - dtype: Optional[bst.typing.DTypeLike] = None, +def empty_like(a: Union[Quantity, jax.typing.ArrayLike], + dtype: Optional[jax.typing.DTypeLike] = None, shape: Any = None, unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]: ''' @@ -385,8 +387,8 @@ def empty_like(a: Union[Quantity, bst.typing.ArrayLike], @set_module_as('brainunit.math') -def ones_like(a: Union[Quantity, bst.typing.ArrayLike], - dtype: Optional[bst.typing.DTypeLike] = None, +def ones_like(a: Union[Quantity, jax.typing.ArrayLike], + dtype: Optional[jax.typing.DTypeLike] = None, shape: Any = None, unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]: ''' @@ -418,8 +420,8 @@ def ones_like(a: Union[Quantity, bst.typing.ArrayLike], @set_module_as('brainunit.math') -def zeros_like(a: Union[Quantity, bst.typing.ArrayLike], - dtype: Optional[bst.typing.DTypeLike] = None, +def zeros_like(a: Union[Quantity, jax.typing.ArrayLike], + dtype: Optional[jax.typing.DTypeLike] = None, shape: Any = None, unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]: ''' @@ -452,8 +454,8 @@ def zeros_like(a: Union[Quantity, bst.typing.ArrayLike], @set_module_as('brainunit.math') def asarray( - a: Union[Quantity, bst.typing.ArrayLike, Sequence[Quantity]], - dtype: Optional[bst.typing.DTypeLike] = None, + a: Union[Quantity, jax.typing.ArrayLike, Sequence[Quantity]], + dtype: Optional[jax.typing.DTypeLike] = None, order: Optional[str] = None, unit: Optional[Unit] = None, ) -> Union[Quantity, jax.Array]: @@ -606,12 +608,12 @@ def arange(*args, **kwargs): @set_module_as('brainunit.math') def linspace( - start: Union[Quantity, bst.typing.ArrayLike], - stop: Union[Quantity, bst.typing.ArrayLike], + start: Union[Quantity, jax.typing.ArrayLike], + stop: Union[Quantity, jax.typing.ArrayLike], num: int = 50, endpoint: Optional[bool] = True, retstep: Optional[bool] = False, - dtype: Optional[bst.typing.DTypeLike] = None + dtype: Optional[jax.typing.DTypeLike] = None ) -> Union[Quantity, jax.Array]: ''' Return a Quantity of `linspace` and `unit`, with uninitialized values if `unit` is provided. @@ -643,12 +645,12 @@ def linspace( @set_module_as('brainunit.math') -def logspace(start: Union[Quantity, bst.typing.ArrayLike], - stop: Union[Quantity, bst.typing.ArrayLike], +def logspace(start: Union[Quantity, jax.typing.ArrayLike], + stop: Union[Quantity, jax.typing.ArrayLike], num: Optional[int] = 50, endpoint: Optional[bool] = True, base: Optional[float] = 10.0, - dtype: Optional[bst.typing.DTypeLike] = None): + dtype: Optional[jax.typing.DTypeLike] = None): ''' Return a Quantity of `logspace` and `unit`, with uninitialized values if `unit` is provided. @@ -679,8 +681,8 @@ def logspace(start: Union[Quantity, bst.typing.ArrayLike], @set_module_as('brainunit.math') -def fill_diagonal(a: Union[Quantity, bst.typing.ArrayLike], - val: Union[Quantity, bst.typing.ArrayLike], +def fill_diagonal(a: Union[Quantity, jax.typing.ArrayLike], + val: Union[Quantity, jax.typing.ArrayLike], wrap: Optional[bool] = False, inplace: Optional[bool] = False) -> Union[Quantity, jax.Array]: ''' @@ -709,8 +711,8 @@ def fill_diagonal(a: Union[Quantity, bst.typing.ArrayLike], @set_module_as('brainunit.math') -def array_split(ary: Union[Quantity, bst.typing.ArrayLike], - indices_or_sections: Union[int, bst.typing.ArrayLike], +def array_split(ary: Union[Quantity, jax.typing.ArrayLike], + indices_or_sections: Union[int, jax.typing.ArrayLike], axis: Optional[int] = 0) -> Union[list[Quantity], list[Array]]: ''' Split an array into multiple sub-arrays. @@ -732,7 +734,7 @@ def array_split(ary: Union[Quantity, bst.typing.ArrayLike], @set_module_as('brainunit.math') -def meshgrid(*xi: Union[Quantity, bst.typing.ArrayLike], +def meshgrid(*xi: Union[Quantity, jax.typing.ArrayLike], copy: Optional[bool] = True, sparse: Optional[bool] = False, indexing: Optional[str] = 'xy'): @@ -759,7 +761,7 @@ def meshgrid(*xi: Union[Quantity, bst.typing.ArrayLike], @set_module_as('brainunit.math') -def vander(x: Union[Quantity, bst.typing.ArrayLike], +def vander(x: Union[Quantity, jax.typing.ArrayLike], N: Optional[bool] = None, increasing: Optional[bool] = False) -> Union[Quantity, jax.Array]: ''' diff --git a/brainunit/math/_compat_numpy_funcs_accept_unitless.py b/brainunit/math/_compat_numpy_funcs_accept_unitless.py index c87890a..8d06fd9 100644 --- a/brainunit/math/_compat_numpy_funcs_accept_unitless.py +++ b/brainunit/math/_compat_numpy_funcs_accept_unitless.py @@ -15,7 +15,7 @@ from functools import wraps from typing import (Union) -import brainstate as bst +import jax import jax.numpy as jnp from jax import Array @@ -57,12 +57,12 @@ def f(x, *args, **kwargs): @wrap_math_funcs_only_accept_unitless_unary -def exp(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Array, Quantity]: +def exp(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Array, Quantity]: return jnp.exp(x) @wrap_math_funcs_only_accept_unitless_unary -def exp2(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Array, Quantity]: +def exp2(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Array, Quantity]: return jnp.exp2(x) diff --git a/brainunit/math/_compat_numpy_funcs_bit_operation.py b/brainunit/math/_compat_numpy_funcs_bit_operation.py index 1325539..528b41b 100644 --- a/brainunit/math/_compat_numpy_funcs_bit_operation.py +++ b/brainunit/math/_compat_numpy_funcs_bit_operation.py @@ -15,12 +15,10 @@ from functools import wraps from typing import (Union) -import brainstate as bst import jax import jax.numpy as jnp import numpy as np from jax import Array -from numpy import number from .._base import (Quantity, ) @@ -53,12 +51,12 @@ def f(x, *args, **kwargs): @wrap_elementwise_bit_operation_unary -def bitwise_not(x: Union[Quantity, bst.typing.ArrayLike]) -> Array: +def bitwise_not(x: Union[Quantity, jax.typing.ArrayLike]) -> Array: return jnp.bitwise_not(x) @wrap_elementwise_bit_operation_unary -def invert(x: Union[Quantity, bst.typing.ArrayLike]) -> Array: +def invert(x: Union[Quantity, jax.typing.ArrayLike]) -> Array: return jnp.invert(x) @@ -102,27 +100,27 @@ def f(x, y, *args, **kwargs): @wrap_elementwise_bit_operation_binary -def bitwise_and(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Array: +def bitwise_and(x: Union[Quantity, jax.typing.ArrayLike], y: Union[Quantity, jax.typing.ArrayLike]) -> Array: return jnp.bitwise_and(x, y) @wrap_elementwise_bit_operation_binary -def bitwise_or(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Array: +def bitwise_or(x: Union[Quantity, jax.typing.ArrayLike], y: Union[Quantity, jax.typing.ArrayLike]) -> Array: return jnp.bitwise_or(x, y) @wrap_elementwise_bit_operation_binary -def bitwise_xor(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Array: +def bitwise_xor(x: Union[Quantity, jax.typing.ArrayLike], y: Union[Quantity, jax.typing.ArrayLike]) -> Array: return jnp.bitwise_xor(x, y) @wrap_elementwise_bit_operation_binary -def left_shift(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Array: +def left_shift(x: Union[Quantity, jax.typing.ArrayLike], y: Union[Quantity, jax.typing.ArrayLike]) -> Array: return jnp.left_shift(x, y) @wrap_elementwise_bit_operation_binary -def right_shift(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Array: +def right_shift(x: Union[Quantity, jax.typing.ArrayLike], y: Union[Quantity, jax.typing.ArrayLike]) -> Array: return jnp.right_shift(x, y) diff --git a/brainunit/math/_compat_numpy_funcs_change_unit.py b/brainunit/math/_compat_numpy_funcs_change_unit.py index 227234c..65fbde0 100644 --- a/brainunit/math/_compat_numpy_funcs_change_unit.py +++ b/brainunit/math/_compat_numpy_funcs_change_unit.py @@ -16,7 +16,6 @@ from functools import wraps from typing import (Callable, Union, Optional) -import brainstate as bst import jax import jax.numpy as jnp import numpy as np @@ -62,12 +61,12 @@ def f(x, *args, **kwargs): @wrap_math_funcs_change_unit_unary(lambda x: x ** -1) -def reciprocal(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def reciprocal(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: return jnp.reciprocal(x) @wrap_math_funcs_change_unit_unary(lambda x: x ** 2) -def var(x: Union[Quantity, bst.typing.ArrayLike], +def var(x: Union[Quantity, jax.typing.ArrayLike], axis: Optional[Union[int, Sequence[int]]] = None, ddof: int = 0, keepdims: bool = False) -> Union[Quantity, jax.Array]: @@ -75,7 +74,7 @@ def var(x: Union[Quantity, bst.typing.ArrayLike], @wrap_math_funcs_change_unit_unary(lambda x: x ** 2) -def nanvar(x: Union[Quantity, bst.typing.ArrayLike], +def nanvar(x: Union[Quantity, jax.typing.ArrayLike], axis: Optional[Union[int, Sequence[int]]] = None, ddof: int = 0, keepdims: bool = False) -> Union[Quantity, jax.Array]: @@ -83,22 +82,22 @@ def nanvar(x: Union[Quantity, bst.typing.ArrayLike], @wrap_math_funcs_change_unit_unary(lambda x: x * 2 ** -1) -def frexp(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def frexp(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: return jnp.frexp(x) @wrap_math_funcs_change_unit_unary(lambda x: x ** 0.5) -def sqrt(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def sqrt(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: return jnp.sqrt(x) @wrap_math_funcs_change_unit_unary(lambda x: x ** (1 / 3)) -def cbrt(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def cbrt(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: return jnp.cbrt(x) @wrap_math_funcs_change_unit_unary(lambda x: x ** 2) -def square(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def square(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: return jnp.square(x) @@ -176,13 +175,13 @@ def square(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Arra @set_module_as('brainunit.math') -def prod(x: Union[Quantity, bst.typing.ArrayLike], +def prod(x: Union[Quantity, jax.typing.ArrayLike], axis: Optional[int] = None, - dtype: Optional[bst.typing.DTypeLike] = None, + dtype: Optional[jax.typing.DTypeLike] = None, out: None = None, keepdims: Optional[bool] = False, - initial: Union[Quantity, bst.typing.ArrayLike] = None, - where: Union[Quantity, bst.typing.ArrayLike] = None, + initial: Union[Quantity, jax.typing.ArrayLike] = None, + where: Union[Quantity, jax.typing.ArrayLike] = None, promote_integers: bool = True) -> Union[Quantity, jax.Array]: ''' Return the product of array elements over a given axis. @@ -209,13 +208,13 @@ def prod(x: Union[Quantity, bst.typing.ArrayLike], @set_module_as('brainunit.math') -def nanprod(x: Union[Quantity, bst.typing.ArrayLike], +def nanprod(x: Union[Quantity, jax.typing.ArrayLike], axis: Optional[int] = None, - dtype: Optional[bst.typing.DTypeLike] = None, + dtype: Optional[jax.typing.DTypeLike] = None, out: None = None, keepdims: bool = False, - initial: Union[Quantity, bst.typing.ArrayLike] = None, - where: Union[Quantity, bst.typing.ArrayLike] = None): + initial: Union[Quantity, jax.typing.ArrayLike] = None, + where: Union[Quantity, jax.typing.ArrayLike] = None): ''' Return the product of array elements over a given axis treating Not a Numbers (NaNs) as one. @@ -241,10 +240,10 @@ def nanprod(x: Union[Quantity, bst.typing.ArrayLike], @set_module_as('brainunit.math') -def cumprod(x: Union[Quantity, bst.typing.ArrayLike], +def cumprod(x: Union[Quantity, jax.typing.ArrayLike], axis: Optional[int] = None, - dtype: Optional[bst.typing.DTypeLike] = None, - out: None = None) -> Union[Quantity, bst.typing.ArrayLike]: + dtype: Optional[jax.typing.DTypeLike] = None, + out: None = None) -> Union[Quantity, jax.typing.ArrayLike]: ''' Return the cumulative product of elements along a given axis. @@ -264,10 +263,10 @@ def cumprod(x: Union[Quantity, bst.typing.ArrayLike], @set_module_as('brainunit.math') -def nancumprod(x: Union[Quantity, bst.typing.ArrayLike], +def nancumprod(x: Union[Quantity, jax.typing.ArrayLike], axis: Optional[int] = None, - dtype: Optional[bst.typing.DTypeLike] = None, - out: None = None) -> Union[Quantity, bst.typing.ArrayLike]: + dtype: Optional[jax.typing.DTypeLike] = None, + out: None = None) -> Union[Quantity, jax.typing.ArrayLike]: ''' Return the cumulative product of elements along a given axis treating Not a Numbers (NaNs) as one. @@ -317,37 +316,37 @@ def f(x, y, *args, **kwargs): @wrap_math_funcs_change_unit_binary(lambda x, y: x * y) -def multiply(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]): +def multiply(x: Union[Quantity, jax.typing.ArrayLike], y: Union[Quantity, jax.typing.ArrayLike]): return jnp.multiply(x, y) @wrap_math_funcs_change_unit_binary(lambda x, y: x / y) -def divide(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]): +def divide(x: Union[Quantity, jax.typing.ArrayLike], y: Union[Quantity, jax.typing.ArrayLike]): return jnp.divide(x, y) @wrap_math_funcs_change_unit_binary(lambda x, y: x * y) -def cross(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]): +def cross(x: Union[Quantity, jax.typing.ArrayLike], y: Union[Quantity, jax.typing.ArrayLike]): return jnp.cross(x, y) @wrap_math_funcs_change_unit_binary(lambda x, y: x * 2 ** y) -def ldexp(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]): +def ldexp(x: Union[Quantity, jax.typing.ArrayLike], y: Union[Quantity, jax.typing.ArrayLike]): return jnp.ldexp(x, y) @wrap_math_funcs_change_unit_binary(lambda x, y: x / y) -def true_divide(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]): +def true_divide(x: Union[Quantity, jax.typing.ArrayLike], y: Union[Quantity, jax.typing.ArrayLike]): return jnp.true_divide(x, y) @wrap_math_funcs_change_unit_binary(lambda x, y: x / y) -def divmod(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]): +def divmod(x: Union[Quantity, jax.typing.ArrayLike], y: Union[Quantity, jax.typing.ArrayLike]): return jnp.divmod(x, y) @wrap_math_funcs_change_unit_binary(lambda x, y: x * y) -def convolve(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]): +def convolve(x: Union[Quantity, jax.typing.ArrayLike], y: Union[Quantity, jax.typing.ArrayLike]): return jnp.convolve(x, y) @@ -430,8 +429,8 @@ def convolve(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.ty @set_module_as('brainunit.math') -def power(x: Union[Quantity, bst.typing.ArrayLike], - y: Union[Quantity, bst.typing.ArrayLike], ) -> Union[Quantity, jax.Array]: +def power(x: Union[Quantity, jax.typing.ArrayLike], + y: Union[Quantity, jax.typing.ArrayLike], ) -> Union[Quantity, jax.Array]: ''' First array elements raised to powers from second array, element-wise. @@ -455,8 +454,8 @@ def power(x: Union[Quantity, bst.typing.ArrayLike], @set_module_as('brainunit.math') -def floor_divide(x: Union[Quantity, bst.typing.ArrayLike], - y: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def floor_divide(x: Union[Quantity, jax.typing.ArrayLike], + y: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: ''' Return the largest integer smaller or equal to the division of the inputs. @@ -480,8 +479,8 @@ def floor_divide(x: Union[Quantity, bst.typing.ArrayLike], @set_module_as('brainunit.math') -def float_power(x: Union[Quantity, bst.typing.ArrayLike], - y: bst.typing.ArrayLike) -> Union[Quantity, jax.Array]: +def float_power(x: Union[Quantity, jax.typing.ArrayLike], + y: jax.typing.ArrayLike) -> Union[Quantity, jax.Array]: ''' First array elements raised to powers from second array, element-wise. @@ -503,8 +502,8 @@ def float_power(x: Union[Quantity, bst.typing.ArrayLike], @set_module_as('brainunit.math') -def remainder(x: Union[Quantity, bst.typing.ArrayLike], - y: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def remainder(x: Union[Quantity, jax.typing.ArrayLike], + y: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: ''' Return element-wise remainder of division. diff --git a/brainunit/math/_compat_numpy_funcs_indexing.py b/brainunit/math/_compat_numpy_funcs_indexing.py index 7f8d8fc..bedf603 100644 --- a/brainunit/math/_compat_numpy_funcs_indexing.py +++ b/brainunit/math/_compat_numpy_funcs_indexing.py @@ -14,7 +14,6 @@ # ============================================================================== from typing import (Union, Optional) -import brainstate as bst import jax import jax.numpy as jnp import numpy as np @@ -36,8 +35,8 @@ # indexing funcs # -------------- @set_module_as('brainunit.math') -def where(condition: Union[bool, bst.typing.ArrayLike], - *args: Union[Quantity, bst.typing.ArrayLike], +def where(condition: Union[bool, jax.typing.ArrayLike], + *args: Union[Quantity, jax.typing.ArrayLike], **kwds) -> Union[Quantity, jax.Array]: condition = jnp.asarray(condition) if len(args) == 0: @@ -86,7 +85,7 @@ def where(condition: Union[bool, bst.typing.ArrayLike], @set_module_as('brainunit.math') -def tril_indices_from(arr: Union[Quantity, bst.typing.ArrayLike], +def tril_indices_from(arr: Union[Quantity, jax.typing.ArrayLike], k: Optional[int] = 0) -> tuple[jax.Array, jax.Array]: ''' Return the indices for the lower-triangle of an (n, m) array. @@ -119,7 +118,7 @@ def tril_indices_from(arr: Union[Quantity, bst.typing.ArrayLike], @set_module_as('brainunit.math') -def triu_indices_from(arr: Union[Quantity, bst.typing.ArrayLike], +def triu_indices_from(arr: Union[Quantity, jax.typing.ArrayLike], k: Optional[int] = 0) -> tuple[jax.Array, jax.Array]: ''' Return the indices for the upper-triangle of an (n, m) array. @@ -138,8 +137,8 @@ def triu_indices_from(arr: Union[Quantity, bst.typing.ArrayLike], @set_module_as('brainunit.math') -def take(a: Union[Quantity, bst.typing.ArrayLike], - indices: Union[Quantity, bst.typing.ArrayLike], +def take(a: Union[Quantity, jax.typing.ArrayLike], + indices: Union[Quantity, jax.typing.ArrayLike], axis: Optional[int] = None, mode: Optional[str] = None) -> Union[Quantity, jax.Array]: if isinstance(a, Quantity): @@ -149,8 +148,8 @@ def take(a: Union[Quantity, bst.typing.ArrayLike], @set_module_as('brainunit.math') -def select(condlist: list[Union[bst.typing.ArrayLike]], - choicelist: Union[Quantity, bst.typing.ArrayLike], +def select(condlist: list[Union[jax.typing.ArrayLike]], + choicelist: Union[Quantity, jax.typing.ArrayLike], default: int = 0) -> Union[Quantity, jax.Array]: from builtins import all as origin_all from builtins import any as origin_any diff --git a/brainunit/math/_compat_numpy_funcs_keep_unit.py b/brainunit/math/_compat_numpy_funcs_keep_unit.py index 4a6616e..8fa38d5 100644 --- a/brainunit/math/_compat_numpy_funcs_keep_unit.py +++ b/brainunit/math/_compat_numpy_funcs_keep_unit.py @@ -15,7 +15,6 @@ from functools import wraps from typing import (Union) -import brainstate as bst import jax import jax.numpy as jnp import numpy as np @@ -60,172 +59,172 @@ def f(x, *args, **kwargs): @wrap_math_funcs_keep_unit_unary -def real(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def real(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: return jnp.real(x) @wrap_math_funcs_keep_unit_unary -def imag(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def imag(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: return jnp.imag(x) @wrap_math_funcs_keep_unit_unary -def conj(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def conj(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: return jnp.conj(x) @wrap_math_funcs_keep_unit_unary -def conjugate(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def conjugate(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: return jnp.conjugate(x) @wrap_math_funcs_keep_unit_unary -def negative(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def negative(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: return jnp.negative(x) @wrap_math_funcs_keep_unit_unary -def positive(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def positive(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: return jnp.positive(x) @wrap_math_funcs_keep_unit_unary -def abs(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def abs(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: return jnp.abs(x) @wrap_math_funcs_keep_unit_unary -def round_(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def round_(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: return jnp.round(x) @wrap_math_funcs_keep_unit_unary -def around(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def around(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: return jnp.around(x) @wrap_math_funcs_keep_unit_unary -def round(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def round(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: return jnp.round(x) @wrap_math_funcs_keep_unit_unary -def rint(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def rint(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: return jnp.rint(x) @wrap_math_funcs_keep_unit_unary -def floor(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def floor(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: return jnp.floor(x) @wrap_math_funcs_keep_unit_unary -def ceil(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def ceil(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: return jnp.ceil(x) @wrap_math_funcs_keep_unit_unary -def trunc(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def trunc(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: return jnp.trunc(x) @wrap_math_funcs_keep_unit_unary -def fix(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def fix(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: return jnp.fix(x) @wrap_math_funcs_keep_unit_unary -def sum(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def sum(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: return jnp.sum(x) @wrap_math_funcs_keep_unit_unary -def nancumsum(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def nancumsum(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: return jnp.nancumsum(x) @wrap_math_funcs_keep_unit_unary -def nansum(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def nansum(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: return jnp.nansum(x) @wrap_math_funcs_keep_unit_unary -def cumsum(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def cumsum(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: return jnp.cumsum(x) @wrap_math_funcs_keep_unit_unary -def ediff1d(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def ediff1d(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: return jnp.ediff1d(x) @wrap_math_funcs_keep_unit_unary -def absolute(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def absolute(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: return jnp.absolute(x) @wrap_math_funcs_keep_unit_unary -def fabs(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def fabs(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: return jnp.fabs(x) @wrap_math_funcs_keep_unit_unary -def median(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def median(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: return jnp.median(x) @wrap_math_funcs_keep_unit_unary -def nanmin(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def nanmin(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: return jnp.nanmin(x) @wrap_math_funcs_keep_unit_unary -def nanmax(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def nanmax(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: return jnp.nanmax(x) @wrap_math_funcs_keep_unit_unary -def ptp(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def ptp(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: return jnp.ptp(x) @wrap_math_funcs_keep_unit_unary -def average(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def average(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: return jnp.average(x) @wrap_math_funcs_keep_unit_unary -def mean(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def mean(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: return jnp.mean(x) @wrap_math_funcs_keep_unit_unary -def std(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def std(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: return jnp.std(x) @wrap_math_funcs_keep_unit_unary -def nanmedian(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def nanmedian(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: return jnp.nanmedian(x) @wrap_math_funcs_keep_unit_unary -def nanmean(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def nanmean(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: return jnp.nanmean(x) @wrap_math_funcs_keep_unit_unary -def nanstd(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def nanstd(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: return jnp.nanstd(x) @wrap_math_funcs_keep_unit_unary -def diff(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def diff(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: return jnp.diff(x) @wrap_math_funcs_keep_unit_unary -def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def modf(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: return jnp.modf(x) @@ -753,12 +752,12 @@ def gcd(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union # math funcs keep unit (n-ary) # ---------------------------- @set_module_as('brainunit.math') -def interp(x: Union[Quantity, bst.typing.ArrayLike], - xp: Union[Quantity, bst.typing.ArrayLike], - fp: Union[Quantity, bst.typing.ArrayLike], - left: Union[Quantity, bst.typing.ArrayLike] = None, - right: Union[Quantity, bst.typing.ArrayLike] = None, - period: Union[Quantity, bst.typing.ArrayLike] = None) -> Union[Quantity, jax.Array]: +def interp(x: Union[Quantity, jax.typing.ArrayLike], + xp: Union[Quantity, jax.typing.ArrayLike], + fp: Union[Quantity, jax.typing.ArrayLike], + left: Union[Quantity, jax.typing.ArrayLike] = None, + right: Union[Quantity, jax.typing.ArrayLike] = None, + period: Union[Quantity, jax.typing.ArrayLike] = None) -> Union[Quantity, jax.Array]: ''' One-dimensional linear interpolation. @@ -796,9 +795,9 @@ def interp(x: Union[Quantity, bst.typing.ArrayLike], @set_module_as('brainunit.math') -def clip(a: Union[Quantity, bst.typing.ArrayLike], - a_min: Union[Quantity, bst.typing.ArrayLike], - a_max: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def clip(a: Union[Quantity, jax.typing.ArrayLike], + a_min: Union[Quantity, jax.typing.ArrayLike], + a_max: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: ''' Clip (limit) the values in an array. diff --git a/brainunit/math/_compat_numpy_funcs_logic.py b/brainunit/math/_compat_numpy_funcs_logic.py index e7d69e7..f411207 100644 --- a/brainunit/math/_compat_numpy_funcs_logic.py +++ b/brainunit/math/_compat_numpy_funcs_logic.py @@ -15,7 +15,6 @@ from functools import wraps from typing import (Union, Optional) -import brainstate as bst import jax import jax.numpy as jnp import numpy as np @@ -54,21 +53,21 @@ def f(x, *args, **kwargs): @wrap_logic_func_unary -def all(x: Union[Quantity, bst.typing.ArrayLike], axis: Optional[int] = None, +def all(x: Union[Quantity, jax.typing.ArrayLike], axis: Optional[int] = None, out: Optional[Array] = None, keepdims: bool = False, where: Optional[Array] = None) -> Union[bool, Array]: return jnp.all(x, axis=axis, out=out, keepdims=keepdims, where=where) @wrap_logic_func_unary -def any(x: Union[Quantity, bst.typing.ArrayLike], axis: Optional[int] = None, +def any(x: Union[Quantity, jax.typing.ArrayLike], axis: Optional[int] = None, out: Optional[Array] = None, keepdims: bool = False, where: Optional[Array] = None) -> Union[bool, Array]: return jnp.any(x, axis=axis, out=out, keepdims=keepdims, where=where) @wrap_logic_func_unary -def logical_not(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[bool, Array]: +def logical_not(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[bool, Array]: return jnp.logical_not(x) @@ -135,67 +134,67 @@ def f(x, y, *args, **kwargs): @wrap_logic_func_binary -def equal(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[bool, Array]: +def equal(x: Union[Quantity, jax.typing.ArrayLike], y: Union[Quantity, jax.typing.ArrayLike]) -> Union[bool, Array]: return jnp.equal(x, y) @wrap_logic_func_binary -def not_equal(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[bool, Array]: +def not_equal(x: Union[Quantity, jax.typing.ArrayLike], y: Union[Quantity, jax.typing.ArrayLike]) -> Union[bool, Array]: return jnp.not_equal(x, y) @wrap_logic_func_binary -def greater(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[bool, Array]: +def greater(x: Union[Quantity, jax.typing.ArrayLike], y: Union[Quantity, jax.typing.ArrayLike]) -> Union[bool, Array]: return jnp.greater(x, y) @wrap_logic_func_binary -def greater_equal(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[bool, Array]: +def greater_equal(x: Union[Quantity, jax.typing.ArrayLike], y: Union[Quantity, jax.typing.ArrayLike]) -> Union[bool, Array]: return jnp.greater_equal(x, y) @wrap_logic_func_binary -def less(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[bool, Array]: +def less(x: Union[Quantity, jax.typing.ArrayLike], y: Union[Quantity, jax.typing.ArrayLike]) -> Union[bool, Array]: return jnp.less(x, y) @wrap_logic_func_binary -def less_equal(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[bool, Array]: +def less_equal(x: Union[Quantity, jax.typing.ArrayLike], y: Union[Quantity, jax.typing.ArrayLike]) -> Union[bool, Array]: return jnp.less_equal(x, y) @wrap_logic_func_binary -def array_equal(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[ +def array_equal(x: Union[Quantity, jax.typing.ArrayLike], y: Union[Quantity, jax.typing.ArrayLike]) -> Union[ bool, Array]: return jnp.array_equal(x, y) @wrap_logic_func_binary -def isclose(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike], +def isclose(x: Union[Quantity, jax.typing.ArrayLike], y: Union[Quantity, jax.typing.ArrayLike], rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> Union[bool, Array]: return jnp.isclose(x, y, rtol=rtol, atol=atol, equal_nan=equal_nan) @wrap_logic_func_binary -def allclose(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike], +def allclose(x: Union[Quantity, jax.typing.ArrayLike], y: Union[Quantity, jax.typing.ArrayLike], rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> Union[bool, Array]: return jnp.allclose(x, y, rtol=rtol, atol=atol, equal_nan=equal_nan) @wrap_logic_func_binary -def logical_and(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[ +def logical_and(x: Union[Quantity, jax.typing.ArrayLike], y: Union[Quantity, jax.typing.ArrayLike]) -> Union[ bool, Array]: return jnp.logical_and(x, y) @wrap_logic_func_binary -def logical_or(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[ +def logical_or(x: Union[Quantity, jax.typing.ArrayLike], y: Union[Quantity, jax.typing.ArrayLike]) -> Union[ bool, Array]: return jnp.logical_or(x, y) @wrap_logic_func_binary -def logical_xor(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[ +def logical_xor(x: Union[Quantity, jax.typing.ArrayLike], y: Union[Quantity, jax.typing.ArrayLike]) -> Union[ bool, Array]: return jnp.logical_xor(x, y) diff --git a/brainunit/math/_compat_numpy_get_attribute.py b/brainunit/math/_compat_numpy_get_attribute.py index 03bec0d..7a065b1 100644 --- a/brainunit/math/_compat_numpy_get_attribute.py +++ b/brainunit/math/_compat_numpy_get_attribute.py @@ -14,7 +14,6 @@ # ============================================================================== from typing import (Union) -import brainstate as bst import jax import jax.numpy as jnp import numpy as np @@ -31,7 +30,7 @@ @set_module_as('brainunit.math') -def ndim(a: Union[Quantity, bst.typing.ArrayLike]) -> int: +def ndim(a: Union[Quantity, jax.typing.ArrayLike]) -> int: ''' Return the number of dimensions of an array. @@ -48,7 +47,7 @@ def ndim(a: Union[Quantity, bst.typing.ArrayLike]) -> int: @set_module_as('brainunit.math') -def isreal(a: Union[Quantity, bst.typing.ArrayLike]) -> jax.Array: +def isreal(a: Union[Quantity, jax.typing.ArrayLike]) -> jax.Array: ''' Return True if the input array is real. @@ -65,7 +64,7 @@ def isreal(a: Union[Quantity, bst.typing.ArrayLike]) -> jax.Array: @set_module_as('brainunit.math') -def isscalar(a: Union[Quantity, bst.typing.ArrayLike]) -> bool: +def isscalar(a: Union[Quantity, jax.typing.ArrayLike]) -> bool: ''' Return True if the input is a scalar. @@ -82,7 +81,7 @@ def isscalar(a: Union[Quantity, bst.typing.ArrayLike]) -> bool: @set_module_as('brainunit.math') -def isfinite(a: Union[Quantity, bst.typing.ArrayLike]) -> jax.Array: +def isfinite(a: Union[Quantity, jax.typing.ArrayLike]) -> jax.Array: ''' Return each element of the array is finite or not. @@ -99,7 +98,7 @@ def isfinite(a: Union[Quantity, bst.typing.ArrayLike]) -> jax.Array: @set_module_as('brainunit.math') -def isinf(a: Union[Quantity, bst.typing.ArrayLike]) -> jax.Array: +def isinf(a: Union[Quantity, jax.typing.ArrayLike]) -> jax.Array: ''' Return each element of the array is infinite or not. @@ -116,7 +115,7 @@ def isinf(a: Union[Quantity, bst.typing.ArrayLike]) -> jax.Array: @set_module_as('brainunit.math') -def isnan(a: Union[Quantity, bst.typing.ArrayLike]) -> jax.Array: +def isnan(a: Union[Quantity, jax.typing.ArrayLike]) -> jax.Array: ''' Return each element of the array is NaN or not. @@ -133,7 +132,7 @@ def isnan(a: Union[Quantity, bst.typing.ArrayLike]) -> jax.Array: @set_module_as('brainunit.math') -def shape(a: Union[Quantity, bst.typing.ArrayLike]) -> tuple[int, ...]: +def shape(a: Union[Quantity, jax.typing.ArrayLike]) -> tuple[int, ...]: """ Return the shape of an array. @@ -173,7 +172,7 @@ def shape(a: Union[Quantity, bst.typing.ArrayLike]) -> tuple[int, ...]: @set_module_as('brainunit.math') -def size(a: Union[Quantity, bst.typing.ArrayLike], axis: int = None) -> int: +def size(a: Union[Quantity, jax.typing.ArrayLike], axis: int = None) -> int: """ Return the number of elements along a given axis. diff --git a/brainunit/math/_compat_numpy_misc.py b/brainunit/math/_compat_numpy_misc.py index 0deb591..881ba40 100644 --- a/brainunit/math/_compat_numpy_misc.py +++ b/brainunit/math/_compat_numpy_misc.py @@ -15,7 +15,6 @@ from collections.abc import Sequence from typing import (Callable, Union, Tuple) -import brainstate as bst import jax import jax.numpy as jnp import numpy as np @@ -59,7 +58,7 @@ @set_module_as('brainunit.math') -def finfo(a: Union[Quantity, bst.typing.ArrayLike]) -> jnp.finfo: +def finfo(a: Union[Quantity, jax.typing.ArrayLike]) -> jnp.finfo: if isinstance(a, Quantity): return jnp.finfo(a.value) else: @@ -67,7 +66,7 @@ def finfo(a: Union[Quantity, bst.typing.ArrayLike]) -> jnp.finfo: @set_module_as('brainunit.math') -def iinfo(a: Union[Quantity, bst.typing.ArrayLike]) -> jnp.iinfo: +def iinfo(a: Union[Quantity, jax.typing.ArrayLike]) -> jnp.iinfo: if isinstance(a, Quantity): return jnp.iinfo(a.value) else: @@ -77,7 +76,7 @@ def iinfo(a: Union[Quantity, bst.typing.ArrayLike]) -> jnp.iinfo: # more # ---- @set_module_as('brainunit.math') -def broadcast_arrays(*args: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, list[Array]]: +def broadcast_arrays(*args: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, list[Array]]: from builtins import all as origin_all from builtins import any as origin_any if origin_all(isinstance(arg, Quantity) for arg in args): @@ -184,8 +183,8 @@ def einsum( @set_module_as('brainunit.math') def gradient( - f: Union[bst.typing.ArrayLike, Quantity], - *varargs: Union[bst.typing.ArrayLike, Quantity], + f: Union[jax.typing.ArrayLike, Quantity], + *varargs: Union[jax.typing.ArrayLike, Quantity], axis: Union[int, Sequence[int], None] = None, edge_order: Union[int, None] = None, ) -> Union[jax.Array, list[jax.Array], Quantity, list[Quantity]]: @@ -225,8 +224,8 @@ def gradient( @set_module_as('brainunit.math') def intersect1d( - ar1: Union[bst.typing.ArrayLike], - ar2: Union[bst.typing.ArrayLike], + ar1: Union[jax.typing.ArrayLike], + ar2: Union[jax.typing.ArrayLike], assume_unique: bool = False, return_indices: bool = False ) -> Union[jax.Array, Quantity, tuple[Union[jax.Array, Quantity], jax.Array, jax.Array]]: @@ -262,30 +261,30 @@ def intersect1d( @wrap_math_funcs_keep_unit_unary -def nan_to_num(x: Union[bst.typing.ArrayLike, Quantity], nan: float = 0.0, posinf: float = jnp.inf, +def nan_to_num(x: Union[jax.typing.ArrayLike, Quantity], nan: float = 0.0, posinf: float = jnp.inf, neginf: float = -jnp.inf) -> Union[jax.Array, Quantity]: return jnp.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf) @wrap_math_funcs_keep_unit_unary -def rot90(m: Union[bst.typing.ArrayLike, Quantity], k: int = 1, axes: Tuple[int, int] = (0, 1)) -> Union[ +def rot90(m: Union[jax.typing.ArrayLike, Quantity], k: int = 1, axes: Tuple[int, int] = (0, 1)) -> Union[ jax.Array, Quantity]: return jnp.rot90(m, k=k, axes=axes) @wrap_math_funcs_change_unit_binary(lambda x, y: x * y) -def tensordot(a: Union[bst.typing.ArrayLike, Quantity], b: Union[bst.typing.ArrayLike, Quantity], +def tensordot(a: Union[jax.typing.ArrayLike, Quantity], b: Union[jax.typing.ArrayLike, Quantity], axes: Union[int, Tuple[int, int]] = 2) -> Union[jax.Array, Quantity]: return jnp.tensordot(a, b, axes=axes) @_compatible_with_quantity(return_quantity=False) -def nanargmax(a: Union[bst.typing.ArrayLike, Quantity], axis: int = None) -> jax.Array: +def nanargmax(a: Union[jax.typing.ArrayLike, Quantity], axis: int = None) -> jax.Array: return jnp.nanargmax(a, axis=axis) @_compatible_with_quantity(return_quantity=False) -def nanargmin(a: Union[bst.typing.ArrayLike, Quantity], axis: int = None) -> jax.Array: +def nanargmin(a: Union[jax.typing.ArrayLike, Quantity], axis: int = None) -> jax.Array: return jnp.nanargmin(a, axis=axis) diff --git a/brainunit/math/_compat_numpy_test.py b/brainunit/math/_compat_numpy_test.py index 615d720..489e140 100644 --- a/brainunit/math/_compat_numpy_test.py +++ b/brainunit/math/_compat_numpy_test.py @@ -23,8 +23,8 @@ import brainunit as bu from brainunit import DimensionMismatchError from brainunit._base import Quantity -from brainunit._unit_shortcuts import ms, mV from brainunit._unit_common import second +from brainunit._unit_shortcuts import ms, mV bst.environ.set(precision=64) @@ -162,7 +162,6 @@ def test_asarray(self): result_q = bu.math.asarray([1, 2, 3], unit=bu.second) assert_quantity(result_q, jnp.asarray([1, 2, 3]), bu.second) - def test_arange(self): result = bu.math.arange(5) self.assertEqual(result.shape, (5,)) diff --git a/pyproject.toml b/pyproject.toml index 4dc911f..c05ae5a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools", "numpy", 'jax', 'jaxlib', 'brainstate'] +requires = ["setuptools", "numpy", 'jax', 'jaxlib'] build-backend = "setuptools.build_meta" @@ -48,7 +48,6 @@ dependencies = [ 'jax', 'jaxlib', 'numpy', - 'brainstate', ] dynamic = ['version'] diff --git a/requirements-dev.txt b/requirements-dev.txt index 85a3f7a..fddbe06 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,5 +1,6 @@ -r requirements.txt brainpy +brainstate # test requirements pytest diff --git a/requirements-doc.txt b/requirements-doc.txt index 0eab13b..e8ecbc4 100644 --- a/requirements-doc.txt +++ b/requirements-doc.txt @@ -1,5 +1,6 @@ -r requirements.txt matplotlib +brainstate # document requirements pandoc diff --git a/requirements.txt b/requirements.txt index c82c553..fc7c970 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ numpy jax jaxlib -brainstate + diff --git a/setup.py b/setup.py index 4eb2d59..f0d8274 100644 --- a/setup.py +++ b/setup.py @@ -56,7 +56,7 @@ author_email='chao.brain@qq.com', packages=packages, python_requires='>=3.9', - install_requires=['numpy>=1.15', 'jax', 'brainstate'], + install_requires=['numpy>=1.15', 'jax'], url='https://github.com/brainpy/brainunit', project_urls={ "Bug Tracker": "https://github.com/brainpy/brainunit/issues",