Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove the dependency of brainstate #10

Merged
merged 2 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
</p>


[``brainunit``](https://github.com/brainpy/brainunit) provides common toolboxes for brain dynamics programming (BDP).
[``brainunit``](https://github.com/brainpy/brainunit) provides a unit-aware mathematical system for brain dynamics programming (BDP).


## Installation
Expand Down
31 changes: 18 additions & 13 deletions brainunit/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,11 @@
from contextlib import contextmanager
from typing import Union, Optional, Sequence, Callable, Tuple, Any, List

import brainstate as bst
import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import register_pytree_node_class
from jax.interpreters.partial_eval import DynamicJaxprTracer

from ._misc import get_dtype


from jax.tree_util import register_pytree_node_class

__all__ = [
'Quantity',
Expand Down Expand Up @@ -755,7 +750,7 @@ def in_best_unit(x, precision=None):
def array_with_unit(
floatval,
unit: Dimension,
dtype: bst.typing.DTypeLike = None
dtype: jax.typing.DTypeLike = None
) -> 'Quantity':
"""
Create a new `Array` with the given dimensions. Calls
Expand Down Expand Up @@ -961,7 +956,7 @@ class Quantity(object):
def __init__(
self,
value: Any,
dtype: Optional[bst.typing.DTypeLike] = None,
dtype: Optional[jax.typing.DTypeLike] = None,
dim: Dimension = DIMENSIONLESS,
unit: Optional['Unit'] = None,
):
Expand All @@ -987,17 +982,14 @@ def __init__(

# array value
if isinstance(value, Quantity):
dtype = dtype or get_dtype(value)
self._dim = value.dim
self._value = jnp.array(value.value, dtype=dtype)
return

elif isinstance(value, (np.ndarray, jax.Array)):
dtype = dtype or get_dtype(value)
value = jnp.array(value, dtype=dtype)

elif isinstance(value, (jnp.number, numbers.Number)):
dtype = dtype or get_dtype(value)
value = jnp.array(value, dtype=dtype)

elif isinstance(value, (jax.core.ShapedArray, jax.ShapeDtypeStruct)):
Expand Down Expand Up @@ -1279,7 +1271,20 @@ def _check_tracer(self):
@property
def dtype(self):
"""Variable dtype."""
return get_dtype(self._value)
a = self._value
if hasattr(a, 'dtype'):
return a.dtype
else:
if isinstance(a, bool):
return bool
elif isinstance(a, int):
return jax.dtypes.canonicalize_dtype(int)
elif isinstance(a, float):
return jax.dtypes.canonicalize_dtype(float)
elif isinstance(a, complex):
return jax.dtypes.canonicalize_dtype(complex)
else:
raise TypeError(f'Can not get dtype of {a}.')

@property
def shape(self) -> Tuple[int, ...]:
Expand Down Expand Up @@ -2480,7 +2485,7 @@ def __init__(
name: str = None,
dispname: str = None,
iscompound: bool = None,
dtype: bst.typing.DTypeLike = None,
dtype: jax.typing.DTypeLike = None,
):
if dim is None:
dim = DIMENSIONLESS
Expand Down
38 changes: 0 additions & 38 deletions brainunit/_misc.py

This file was deleted.

3 changes: 1 addition & 2 deletions brainunit/_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,12 @@
import itertools
import warnings

import brainstate as bst
import jax.numpy as jnp
import numpy as np
import pytest
from numpy.testing import assert_equal

import brainstate as bst

array = np.array
bst.environ.set(precision=64)

Expand Down
74 changes: 38 additions & 36 deletions brainunit/math/_compat_numpy_array_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from collections.abc import Sequence
from functools import wraps
from typing import (Callable, Union, Optional, Any)
from typing import (Union, Optional, Any)

import brainstate as bst
import jax
import jax.numpy as jnp
import numpy as np
from brainstate._utils import set_module_as
from jax import Array

from .._base import (DIMENSIONLESS,
Quantity,
Unit,
fail_for_dimension_mismatch,
is_unitless,
)
from .._base import (
DIMENSIONLESS,
Quantity,
Unit,
fail_for_dimension_mismatch,
is_unitless,
)

__all__ = [
# array creation
Expand Down Expand Up @@ -231,10 +231,12 @@ def zeros(


@set_module_as('brainunit.math')
def full_like(a: Union[Quantity, bst.typing.ArrayLike],
fill_value: Union[Quantity, bst.typing.ArrayLike],
dtype: Optional[bst.typing.DTypeLike] = None,
shape: Any = None) -> Union[Quantity, jax.Array]:
def full_like(
a: Union[Quantity, jax.typing.ArrayLike],
fill_value: Union[Quantity, jax.typing.ArrayLike],
dtype: Optional[jax.typing.DTypeLike] = None,
shape: Any = None
) -> Union[Quantity, jax.Array]:
'''
Return a Quantity if `a` and `fill_value` are Quantities that have the same unit or only `fill_value` is a Quantity.
else return an array of `a` filled with `fill_value`.
Expand Down Expand Up @@ -262,7 +264,7 @@ def full_like(a: Union[Quantity, bst.typing.ArrayLike],


@set_module_as('brainunit.math')
def diag(a: Union[Quantity, bst.typing.ArrayLike],
def diag(a: Union[Quantity, jax.typing.ArrayLike],
k: int = 0,
unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]:
'''
Expand Down Expand Up @@ -292,7 +294,7 @@ def diag(a: Union[Quantity, bst.typing.ArrayLike],


@set_module_as('brainunit.math')
def tril(a: Union[Quantity, bst.typing.ArrayLike],
def tril(a: Union[Quantity, jax.typing.ArrayLike],
k: int = 0,
unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]:
'''
Expand Down Expand Up @@ -322,7 +324,7 @@ def tril(a: Union[Quantity, bst.typing.ArrayLike],


@set_module_as('brainunit.math')
def triu(a: Union[Quantity, bst.typing.ArrayLike],
def triu(a: Union[Quantity, jax.typing.ArrayLike],
k: int = 0,
unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]:
'''
Expand Down Expand Up @@ -352,8 +354,8 @@ def triu(a: Union[Quantity, bst.typing.ArrayLike],


@set_module_as('brainunit.math')
def empty_like(a: Union[Quantity, bst.typing.ArrayLike],
dtype: Optional[bst.typing.DTypeLike] = None,
def empty_like(a: Union[Quantity, jax.typing.ArrayLike],
dtype: Optional[jax.typing.DTypeLike] = None,
shape: Any = None,
unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]:
'''
Expand Down Expand Up @@ -385,8 +387,8 @@ def empty_like(a: Union[Quantity, bst.typing.ArrayLike],


@set_module_as('brainunit.math')
def ones_like(a: Union[Quantity, bst.typing.ArrayLike],
dtype: Optional[bst.typing.DTypeLike] = None,
def ones_like(a: Union[Quantity, jax.typing.ArrayLike],
dtype: Optional[jax.typing.DTypeLike] = None,
shape: Any = None,
unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]:
'''
Expand Down Expand Up @@ -418,8 +420,8 @@ def ones_like(a: Union[Quantity, bst.typing.ArrayLike],


@set_module_as('brainunit.math')
def zeros_like(a: Union[Quantity, bst.typing.ArrayLike],
dtype: Optional[bst.typing.DTypeLike] = None,
def zeros_like(a: Union[Quantity, jax.typing.ArrayLike],
dtype: Optional[jax.typing.DTypeLike] = None,
shape: Any = None,
unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]:
'''
Expand Down Expand Up @@ -452,8 +454,8 @@ def zeros_like(a: Union[Quantity, bst.typing.ArrayLike],

@set_module_as('brainunit.math')
def asarray(
a: Union[Quantity, bst.typing.ArrayLike, Sequence[Quantity]],
dtype: Optional[bst.typing.DTypeLike] = None,
a: Union[Quantity, jax.typing.ArrayLike, Sequence[Quantity]],
dtype: Optional[jax.typing.DTypeLike] = None,
order: Optional[str] = None,
unit: Optional[Unit] = None,
) -> Union[Quantity, jax.Array]:
Expand Down Expand Up @@ -606,12 +608,12 @@ def arange(*args, **kwargs):

@set_module_as('brainunit.math')
def linspace(
start: Union[Quantity, bst.typing.ArrayLike],
stop: Union[Quantity, bst.typing.ArrayLike],
start: Union[Quantity, jax.typing.ArrayLike],
stop: Union[Quantity, jax.typing.ArrayLike],
num: int = 50,
endpoint: Optional[bool] = True,
retstep: Optional[bool] = False,
dtype: Optional[bst.typing.DTypeLike] = None
dtype: Optional[jax.typing.DTypeLike] = None
) -> Union[Quantity, jax.Array]:
'''
Return a Quantity of `linspace` and `unit`, with uninitialized values if `unit` is provided.
Expand Down Expand Up @@ -643,12 +645,12 @@ def linspace(


@set_module_as('brainunit.math')
def logspace(start: Union[Quantity, bst.typing.ArrayLike],
stop: Union[Quantity, bst.typing.ArrayLike],
def logspace(start: Union[Quantity, jax.typing.ArrayLike],
stop: Union[Quantity, jax.typing.ArrayLike],
num: Optional[int] = 50,
endpoint: Optional[bool] = True,
base: Optional[float] = 10.0,
dtype: Optional[bst.typing.DTypeLike] = None):
dtype: Optional[jax.typing.DTypeLike] = None):
'''
Return a Quantity of `logspace` and `unit`, with uninitialized values if `unit` is provided.

Expand Down Expand Up @@ -679,8 +681,8 @@ def logspace(start: Union[Quantity, bst.typing.ArrayLike],


@set_module_as('brainunit.math')
def fill_diagonal(a: Union[Quantity, bst.typing.ArrayLike],
val: Union[Quantity, bst.typing.ArrayLike],
def fill_diagonal(a: Union[Quantity, jax.typing.ArrayLike],
val: Union[Quantity, jax.typing.ArrayLike],
wrap: Optional[bool] = False,
inplace: Optional[bool] = False) -> Union[Quantity, jax.Array]:
'''
Expand Down Expand Up @@ -709,8 +711,8 @@ def fill_diagonal(a: Union[Quantity, bst.typing.ArrayLike],


@set_module_as('brainunit.math')
def array_split(ary: Union[Quantity, bst.typing.ArrayLike],
indices_or_sections: Union[int, bst.typing.ArrayLike],
def array_split(ary: Union[Quantity, jax.typing.ArrayLike],
indices_or_sections: Union[int, jax.typing.ArrayLike],
axis: Optional[int] = 0) -> Union[list[Quantity], list[Array]]:
'''
Split an array into multiple sub-arrays.
Expand All @@ -732,7 +734,7 @@ def array_split(ary: Union[Quantity, bst.typing.ArrayLike],


@set_module_as('brainunit.math')
def meshgrid(*xi: Union[Quantity, bst.typing.ArrayLike],
def meshgrid(*xi: Union[Quantity, jax.typing.ArrayLike],
copy: Optional[bool] = True,
sparse: Optional[bool] = False,
indexing: Optional[str] = 'xy'):
Expand All @@ -759,7 +761,7 @@ def meshgrid(*xi: Union[Quantity, bst.typing.ArrayLike],


@set_module_as('brainunit.math')
def vander(x: Union[Quantity, bst.typing.ArrayLike],
def vander(x: Union[Quantity, jax.typing.ArrayLike],
N: Optional[bool] = None,
increasing: Optional[bool] = False) -> Union[Quantity, jax.Array]:
'''
Expand Down
6 changes: 3 additions & 3 deletions brainunit/math/_compat_numpy_funcs_accept_unitless.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from functools import wraps
from typing import (Union)

import brainstate as bst
import jax
import jax.numpy as jnp
from jax import Array

Expand Down Expand Up @@ -57,12 +57,12 @@ def f(x, *args, **kwargs):


@wrap_math_funcs_only_accept_unitless_unary
def exp(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Array, Quantity]:
def exp(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Array, Quantity]:
return jnp.exp(x)


@wrap_math_funcs_only_accept_unitless_unary
def exp2(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Array, Quantity]:
def exp2(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Array, Quantity]:
return jnp.exp2(x)


Expand Down
Loading