Skip to content

Commit

Permalink
remove the dependency of brainstate (#10)
Browse files Browse the repository at this point in the history
* remove the dependency of `brainstate`

* update readme
chaoming0625 authored Jun 12, 2024
1 parent 4cc0a8e commit 4c1a377
Showing 19 changed files with 196 additions and 236 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -19,7 +19,7 @@
</p>


[``brainunit``](https://github.com/brainpy/brainunit) provides common toolboxes for brain dynamics programming (BDP).
[``brainunit``](https://github.com/brainpy/brainunit) provides a unit-aware mathematical system for brain dynamics programming (BDP).


## Installation
31 changes: 18 additions & 13 deletions brainunit/_base.py
Original file line number Diff line number Diff line change
@@ -22,16 +22,11 @@
from contextlib import contextmanager
from typing import Union, Optional, Sequence, Callable, Tuple, Any, List

import brainstate as bst
import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import register_pytree_node_class
from jax.interpreters.partial_eval import DynamicJaxprTracer

from ._misc import get_dtype


from jax.tree_util import register_pytree_node_class

__all__ = [
'Quantity',
@@ -755,7 +750,7 @@ def in_best_unit(x, precision=None):
def array_with_unit(
floatval,
unit: Dimension,
dtype: bst.typing.DTypeLike = None
dtype: jax.typing.DTypeLike = None
) -> 'Quantity':
"""
Create a new `Array` with the given dimensions. Calls
@@ -961,7 +956,7 @@ class Quantity(object):
def __init__(
self,
value: Any,
dtype: Optional[bst.typing.DTypeLike] = None,
dtype: Optional[jax.typing.DTypeLike] = None,
dim: Dimension = DIMENSIONLESS,
unit: Optional['Unit'] = None,
):
@@ -987,17 +982,14 @@ def __init__(

# array value
if isinstance(value, Quantity):
dtype = dtype or get_dtype(value)
self._dim = value.dim
self._value = jnp.array(value.value, dtype=dtype)
return

elif isinstance(value, (np.ndarray, jax.Array)):
dtype = dtype or get_dtype(value)
value = jnp.array(value, dtype=dtype)

elif isinstance(value, (jnp.number, numbers.Number)):
dtype = dtype or get_dtype(value)
value = jnp.array(value, dtype=dtype)

elif isinstance(value, (jax.core.ShapedArray, jax.ShapeDtypeStruct)):
@@ -1279,7 +1271,20 @@ def _check_tracer(self):
@property
def dtype(self):
"""Variable dtype."""
return get_dtype(self._value)
a = self._value
if hasattr(a, 'dtype'):
return a.dtype
else:
if isinstance(a, bool):
return bool
elif isinstance(a, int):
return jax.dtypes.canonicalize_dtype(int)
elif isinstance(a, float):
return jax.dtypes.canonicalize_dtype(float)
elif isinstance(a, complex):
return jax.dtypes.canonicalize_dtype(complex)
else:
raise TypeError(f'Can not get dtype of {a}.')

@property
def shape(self) -> Tuple[int, ...]:
@@ -2480,7 +2485,7 @@ def __init__(
name: str = None,
dispname: str = None,
iscompound: bool = None,
dtype: bst.typing.DTypeLike = None,
dtype: jax.typing.DTypeLike = None,
):
if dim is None:
dim = DIMENSIONLESS
38 changes: 0 additions & 38 deletions brainunit/_misc.py

This file was deleted.

3 changes: 1 addition & 2 deletions brainunit/_unit_test.py
Original file line number Diff line number Diff line change
@@ -16,13 +16,12 @@
import itertools
import warnings

import brainstate as bst
import jax.numpy as jnp
import numpy as np
import pytest
from numpy.testing import assert_equal

import brainstate as bst

array = np.array
bst.environ.set(precision=64)

74 changes: 38 additions & 36 deletions brainunit/math/_compat_numpy_array_creation.py
Original file line number Diff line number Diff line change
@@ -12,23 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from collections.abc import Sequence
from functools import wraps
from typing import (Callable, Union, Optional, Any)
from typing import (Union, Optional, Any)

import brainstate as bst
import jax
import jax.numpy as jnp
import numpy as np
from brainstate._utils import set_module_as
from jax import Array

from .._base import (DIMENSIONLESS,
Quantity,
Unit,
fail_for_dimension_mismatch,
is_unitless,
)
from .._base import (
DIMENSIONLESS,
Quantity,
Unit,
fail_for_dimension_mismatch,
is_unitless,
)

__all__ = [
# array creation
@@ -231,10 +231,12 @@ def zeros(


@set_module_as('brainunit.math')
def full_like(a: Union[Quantity, bst.typing.ArrayLike],
fill_value: Union[Quantity, bst.typing.ArrayLike],
dtype: Optional[bst.typing.DTypeLike] = None,
shape: Any = None) -> Union[Quantity, jax.Array]:
def full_like(
a: Union[Quantity, jax.typing.ArrayLike],
fill_value: Union[Quantity, jax.typing.ArrayLike],
dtype: Optional[jax.typing.DTypeLike] = None,
shape: Any = None
) -> Union[Quantity, jax.Array]:
'''
Return a Quantity if `a` and `fill_value` are Quantities that have the same unit or only `fill_value` is a Quantity.
else return an array of `a` filled with `fill_value`.
@@ -262,7 +264,7 @@ def full_like(a: Union[Quantity, bst.typing.ArrayLike],


@set_module_as('brainunit.math')
def diag(a: Union[Quantity, bst.typing.ArrayLike],
def diag(a: Union[Quantity, jax.typing.ArrayLike],
k: int = 0,
unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]:
'''
@@ -292,7 +294,7 @@ def diag(a: Union[Quantity, bst.typing.ArrayLike],


@set_module_as('brainunit.math')
def tril(a: Union[Quantity, bst.typing.ArrayLike],
def tril(a: Union[Quantity, jax.typing.ArrayLike],
k: int = 0,
unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]:
'''
@@ -322,7 +324,7 @@ def tril(a: Union[Quantity, bst.typing.ArrayLike],


@set_module_as('brainunit.math')
def triu(a: Union[Quantity, bst.typing.ArrayLike],
def triu(a: Union[Quantity, jax.typing.ArrayLike],
k: int = 0,
unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]:
'''
@@ -352,8 +354,8 @@ def triu(a: Union[Quantity, bst.typing.ArrayLike],


@set_module_as('brainunit.math')
def empty_like(a: Union[Quantity, bst.typing.ArrayLike],
dtype: Optional[bst.typing.DTypeLike] = None,
def empty_like(a: Union[Quantity, jax.typing.ArrayLike],
dtype: Optional[jax.typing.DTypeLike] = None,
shape: Any = None,
unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]:
'''
@@ -385,8 +387,8 @@ def empty_like(a: Union[Quantity, bst.typing.ArrayLike],


@set_module_as('brainunit.math')
def ones_like(a: Union[Quantity, bst.typing.ArrayLike],
dtype: Optional[bst.typing.DTypeLike] = None,
def ones_like(a: Union[Quantity, jax.typing.ArrayLike],
dtype: Optional[jax.typing.DTypeLike] = None,
shape: Any = None,
unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]:
'''
@@ -418,8 +420,8 @@ def ones_like(a: Union[Quantity, bst.typing.ArrayLike],


@set_module_as('brainunit.math')
def zeros_like(a: Union[Quantity, bst.typing.ArrayLike],
dtype: Optional[bst.typing.DTypeLike] = None,
def zeros_like(a: Union[Quantity, jax.typing.ArrayLike],
dtype: Optional[jax.typing.DTypeLike] = None,
shape: Any = None,
unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]:
'''
@@ -452,8 +454,8 @@ def zeros_like(a: Union[Quantity, bst.typing.ArrayLike],

@set_module_as('brainunit.math')
def asarray(
a: Union[Quantity, bst.typing.ArrayLike, Sequence[Quantity]],
dtype: Optional[bst.typing.DTypeLike] = None,
a: Union[Quantity, jax.typing.ArrayLike, Sequence[Quantity]],
dtype: Optional[jax.typing.DTypeLike] = None,
order: Optional[str] = None,
unit: Optional[Unit] = None,
) -> Union[Quantity, jax.Array]:
@@ -606,12 +608,12 @@ def arange(*args, **kwargs):

@set_module_as('brainunit.math')
def linspace(
start: Union[Quantity, bst.typing.ArrayLike],
stop: Union[Quantity, bst.typing.ArrayLike],
start: Union[Quantity, jax.typing.ArrayLike],
stop: Union[Quantity, jax.typing.ArrayLike],
num: int = 50,
endpoint: Optional[bool] = True,
retstep: Optional[bool] = False,
dtype: Optional[bst.typing.DTypeLike] = None
dtype: Optional[jax.typing.DTypeLike] = None
) -> Union[Quantity, jax.Array]:
'''
Return a Quantity of `linspace` and `unit`, with uninitialized values if `unit` is provided.
@@ -643,12 +645,12 @@ def linspace(


@set_module_as('brainunit.math')
def logspace(start: Union[Quantity, bst.typing.ArrayLike],
stop: Union[Quantity, bst.typing.ArrayLike],
def logspace(start: Union[Quantity, jax.typing.ArrayLike],
stop: Union[Quantity, jax.typing.ArrayLike],
num: Optional[int] = 50,
endpoint: Optional[bool] = True,
base: Optional[float] = 10.0,
dtype: Optional[bst.typing.DTypeLike] = None):
dtype: Optional[jax.typing.DTypeLike] = None):
'''
Return a Quantity of `logspace` and `unit`, with uninitialized values if `unit` is provided.
@@ -679,8 +681,8 @@ def logspace(start: Union[Quantity, bst.typing.ArrayLike],


@set_module_as('brainunit.math')
def fill_diagonal(a: Union[Quantity, bst.typing.ArrayLike],
val: Union[Quantity, bst.typing.ArrayLike],
def fill_diagonal(a: Union[Quantity, jax.typing.ArrayLike],
val: Union[Quantity, jax.typing.ArrayLike],
wrap: Optional[bool] = False,
inplace: Optional[bool] = False) -> Union[Quantity, jax.Array]:
'''
@@ -709,8 +711,8 @@ def fill_diagonal(a: Union[Quantity, bst.typing.ArrayLike],


@set_module_as('brainunit.math')
def array_split(ary: Union[Quantity, bst.typing.ArrayLike],
indices_or_sections: Union[int, bst.typing.ArrayLike],
def array_split(ary: Union[Quantity, jax.typing.ArrayLike],
indices_or_sections: Union[int, jax.typing.ArrayLike],
axis: Optional[int] = 0) -> Union[list[Quantity], list[Array]]:
'''
Split an array into multiple sub-arrays.
@@ -732,7 +734,7 @@ def array_split(ary: Union[Quantity, bst.typing.ArrayLike],


@set_module_as('brainunit.math')
def meshgrid(*xi: Union[Quantity, bst.typing.ArrayLike],
def meshgrid(*xi: Union[Quantity, jax.typing.ArrayLike],
copy: Optional[bool] = True,
sparse: Optional[bool] = False,
indexing: Optional[str] = 'xy'):
@@ -759,7 +761,7 @@ def meshgrid(*xi: Union[Quantity, bst.typing.ArrayLike],


@set_module_as('brainunit.math')
def vander(x: Union[Quantity, bst.typing.ArrayLike],
def vander(x: Union[Quantity, jax.typing.ArrayLike],
N: Optional[bool] = None,
increasing: Optional[bool] = False) -> Union[Quantity, jax.Array]:
'''
6 changes: 3 additions & 3 deletions brainunit/math/_compat_numpy_funcs_accept_unitless.py
Original file line number Diff line number Diff line change
@@ -15,7 +15,7 @@
from functools import wraps
from typing import (Union)

import brainstate as bst
import jax
import jax.numpy as jnp
from jax import Array

@@ -57,12 +57,12 @@ def f(x, *args, **kwargs):


@wrap_math_funcs_only_accept_unitless_unary
def exp(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Array, Quantity]:
def exp(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Array, Quantity]:
return jnp.exp(x)


@wrap_math_funcs_only_accept_unitless_unary
def exp2(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Array, Quantity]:
def exp2(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Array, Quantity]:
return jnp.exp2(x)


Loading

0 comments on commit 4c1a377

Please sign in to comment.