Skip to content

Commit

Permalink
remove the dependency of brainstate
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Jun 12, 2024
1 parent 4cc0a8e commit dd87871
Show file tree
Hide file tree
Showing 18 changed files with 195 additions and 235 deletions.
31 changes: 18 additions & 13 deletions brainunit/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,11 @@
from contextlib import contextmanager
from typing import Union, Optional, Sequence, Callable, Tuple, Any, List

import brainstate as bst
import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import register_pytree_node_class
from jax.interpreters.partial_eval import DynamicJaxprTracer

from ._misc import get_dtype


from jax.tree_util import register_pytree_node_class

__all__ = [
'Quantity',
Expand Down Expand Up @@ -755,7 +750,7 @@ def in_best_unit(x, precision=None):
def array_with_unit(
floatval,
unit: Dimension,
dtype: bst.typing.DTypeLike = None
dtype: jax.typing.DTypeLike = None
) -> 'Quantity':
"""
Create a new `Array` with the given dimensions. Calls
Expand Down Expand Up @@ -961,7 +956,7 @@ class Quantity(object):
def __init__(
self,
value: Any,
dtype: Optional[bst.typing.DTypeLike] = None,
dtype: Optional[jax.typing.DTypeLike] = None,
dim: Dimension = DIMENSIONLESS,
unit: Optional['Unit'] = None,
):
Expand All @@ -987,17 +982,14 @@ def __init__(

# array value
if isinstance(value, Quantity):
dtype = dtype or get_dtype(value)
self._dim = value.dim
self._value = jnp.array(value.value, dtype=dtype)
return

elif isinstance(value, (np.ndarray, jax.Array)):
dtype = dtype or get_dtype(value)
value = jnp.array(value, dtype=dtype)

elif isinstance(value, (jnp.number, numbers.Number)):
dtype = dtype or get_dtype(value)
value = jnp.array(value, dtype=dtype)

elif isinstance(value, (jax.core.ShapedArray, jax.ShapeDtypeStruct)):
Expand Down Expand Up @@ -1279,7 +1271,20 @@ def _check_tracer(self):
@property
def dtype(self):
"""Variable dtype."""
return get_dtype(self._value)
a = self._value
if hasattr(a, 'dtype'):
return a.dtype
else:
if isinstance(a, bool):
return bool
elif isinstance(a, int):
return jax.dtypes.canonicalize_dtype(int)
elif isinstance(a, float):
return jax.dtypes.canonicalize_dtype(float)
elif isinstance(a, complex):
return jax.dtypes.canonicalize_dtype(complex)
else:
raise TypeError(f'Can not get dtype of {a}.')

@property
def shape(self) -> Tuple[int, ...]:
Expand Down Expand Up @@ -2480,7 +2485,7 @@ def __init__(
name: str = None,
dispname: str = None,
iscompound: bool = None,
dtype: bst.typing.DTypeLike = None,
dtype: jax.typing.DTypeLike = None,
):
if dim is None:
dim = DIMENSIONLESS
Expand Down
38 changes: 0 additions & 38 deletions brainunit/_misc.py

This file was deleted.

3 changes: 1 addition & 2 deletions brainunit/_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,12 @@
import itertools
import warnings

import brainstate as bst
import jax.numpy as jnp
import numpy as np
import pytest
from numpy.testing import assert_equal

import brainstate as bst

array = np.array
bst.environ.set(precision=64)

Expand Down
74 changes: 38 additions & 36 deletions brainunit/math/_compat_numpy_array_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from collections.abc import Sequence
from functools import wraps
from typing import (Callable, Union, Optional, Any)
from typing import (Union, Optional, Any)

import brainstate as bst
import jax
import jax.numpy as jnp
import numpy as np
from brainstate._utils import set_module_as
from jax import Array

from .._base import (DIMENSIONLESS,
Quantity,
Unit,
fail_for_dimension_mismatch,
is_unitless,
)
from .._base import (
DIMENSIONLESS,
Quantity,
Unit,
fail_for_dimension_mismatch,
is_unitless,
)

__all__ = [
# array creation
Expand Down Expand Up @@ -231,10 +231,12 @@ def zeros(


@set_module_as('brainunit.math')
def full_like(a: Union[Quantity, bst.typing.ArrayLike],
fill_value: Union[Quantity, bst.typing.ArrayLike],
dtype: Optional[bst.typing.DTypeLike] = None,
shape: Any = None) -> Union[Quantity, jax.Array]:
def full_like(
a: Union[Quantity, jax.typing.ArrayLike],
fill_value: Union[Quantity, jax.typing.ArrayLike],
dtype: Optional[jax.typing.DTypeLike] = None,
shape: Any = None
) -> Union[Quantity, jax.Array]:
'''
Return a Quantity if `a` and `fill_value` are Quantities that have the same unit or only `fill_value` is a Quantity.
else return an array of `a` filled with `fill_value`.
Expand Down Expand Up @@ -262,7 +264,7 @@ def full_like(a: Union[Quantity, bst.typing.ArrayLike],


@set_module_as('brainunit.math')
def diag(a: Union[Quantity, bst.typing.ArrayLike],
def diag(a: Union[Quantity, jax.typing.ArrayLike],
k: int = 0,
unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]:
'''
Expand Down Expand Up @@ -292,7 +294,7 @@ def diag(a: Union[Quantity, bst.typing.ArrayLike],


@set_module_as('brainunit.math')
def tril(a: Union[Quantity, bst.typing.ArrayLike],
def tril(a: Union[Quantity, jax.typing.ArrayLike],
k: int = 0,
unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]:
'''
Expand Down Expand Up @@ -322,7 +324,7 @@ def tril(a: Union[Quantity, bst.typing.ArrayLike],


@set_module_as('brainunit.math')
def triu(a: Union[Quantity, bst.typing.ArrayLike],
def triu(a: Union[Quantity, jax.typing.ArrayLike],
k: int = 0,
unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]:
'''
Expand Down Expand Up @@ -352,8 +354,8 @@ def triu(a: Union[Quantity, bst.typing.ArrayLike],


@set_module_as('brainunit.math')
def empty_like(a: Union[Quantity, bst.typing.ArrayLike],
dtype: Optional[bst.typing.DTypeLike] = None,
def empty_like(a: Union[Quantity, jax.typing.ArrayLike],
dtype: Optional[jax.typing.DTypeLike] = None,
shape: Any = None,
unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]:
'''
Expand Down Expand Up @@ -385,8 +387,8 @@ def empty_like(a: Union[Quantity, bst.typing.ArrayLike],


@set_module_as('brainunit.math')
def ones_like(a: Union[Quantity, bst.typing.ArrayLike],
dtype: Optional[bst.typing.DTypeLike] = None,
def ones_like(a: Union[Quantity, jax.typing.ArrayLike],
dtype: Optional[jax.typing.DTypeLike] = None,
shape: Any = None,
unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]:
'''
Expand Down Expand Up @@ -418,8 +420,8 @@ def ones_like(a: Union[Quantity, bst.typing.ArrayLike],


@set_module_as('brainunit.math')
def zeros_like(a: Union[Quantity, bst.typing.ArrayLike],
dtype: Optional[bst.typing.DTypeLike] = None,
def zeros_like(a: Union[Quantity, jax.typing.ArrayLike],
dtype: Optional[jax.typing.DTypeLike] = None,
shape: Any = None,
unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]:
'''
Expand Down Expand Up @@ -452,8 +454,8 @@ def zeros_like(a: Union[Quantity, bst.typing.ArrayLike],

@set_module_as('brainunit.math')
def asarray(
a: Union[Quantity, bst.typing.ArrayLike, Sequence[Quantity]],
dtype: Optional[bst.typing.DTypeLike] = None,
a: Union[Quantity, jax.typing.ArrayLike, Sequence[Quantity]],
dtype: Optional[jax.typing.DTypeLike] = None,
order: Optional[str] = None,
unit: Optional[Unit] = None,
) -> Union[Quantity, jax.Array]:
Expand Down Expand Up @@ -606,12 +608,12 @@ def arange(*args, **kwargs):

@set_module_as('brainunit.math')
def linspace(
start: Union[Quantity, bst.typing.ArrayLike],
stop: Union[Quantity, bst.typing.ArrayLike],
start: Union[Quantity, jax.typing.ArrayLike],
stop: Union[Quantity, jax.typing.ArrayLike],
num: int = 50,
endpoint: Optional[bool] = True,
retstep: Optional[bool] = False,
dtype: Optional[bst.typing.DTypeLike] = None
dtype: Optional[jax.typing.DTypeLike] = None
) -> Union[Quantity, jax.Array]:
'''
Return a Quantity of `linspace` and `unit`, with uninitialized values if `unit` is provided.
Expand Down Expand Up @@ -643,12 +645,12 @@ def linspace(


@set_module_as('brainunit.math')
def logspace(start: Union[Quantity, bst.typing.ArrayLike],
stop: Union[Quantity, bst.typing.ArrayLike],
def logspace(start: Union[Quantity, jax.typing.ArrayLike],
stop: Union[Quantity, jax.typing.ArrayLike],
num: Optional[int] = 50,
endpoint: Optional[bool] = True,
base: Optional[float] = 10.0,
dtype: Optional[bst.typing.DTypeLike] = None):
dtype: Optional[jax.typing.DTypeLike] = None):
'''
Return a Quantity of `logspace` and `unit`, with uninitialized values if `unit` is provided.
Expand Down Expand Up @@ -679,8 +681,8 @@ def logspace(start: Union[Quantity, bst.typing.ArrayLike],


@set_module_as('brainunit.math')
def fill_diagonal(a: Union[Quantity, bst.typing.ArrayLike],
val: Union[Quantity, bst.typing.ArrayLike],
def fill_diagonal(a: Union[Quantity, jax.typing.ArrayLike],
val: Union[Quantity, jax.typing.ArrayLike],
wrap: Optional[bool] = False,
inplace: Optional[bool] = False) -> Union[Quantity, jax.Array]:
'''
Expand Down Expand Up @@ -709,8 +711,8 @@ def fill_diagonal(a: Union[Quantity, bst.typing.ArrayLike],


@set_module_as('brainunit.math')
def array_split(ary: Union[Quantity, bst.typing.ArrayLike],
indices_or_sections: Union[int, bst.typing.ArrayLike],
def array_split(ary: Union[Quantity, jax.typing.ArrayLike],
indices_or_sections: Union[int, jax.typing.ArrayLike],
axis: Optional[int] = 0) -> Union[list[Quantity], list[Array]]:
'''
Split an array into multiple sub-arrays.
Expand All @@ -732,7 +734,7 @@ def array_split(ary: Union[Quantity, bst.typing.ArrayLike],


@set_module_as('brainunit.math')
def meshgrid(*xi: Union[Quantity, bst.typing.ArrayLike],
def meshgrid(*xi: Union[Quantity, jax.typing.ArrayLike],
copy: Optional[bool] = True,
sparse: Optional[bool] = False,
indexing: Optional[str] = 'xy'):
Expand All @@ -759,7 +761,7 @@ def meshgrid(*xi: Union[Quantity, bst.typing.ArrayLike],


@set_module_as('brainunit.math')
def vander(x: Union[Quantity, bst.typing.ArrayLike],
def vander(x: Union[Quantity, jax.typing.ArrayLike],
N: Optional[bool] = None,
increasing: Optional[bool] = False) -> Union[Quantity, jax.Array]:
'''
Expand Down
6 changes: 3 additions & 3 deletions brainunit/math/_compat_numpy_funcs_accept_unitless.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from functools import wraps
from typing import (Union)

import brainstate as bst
import jax
import jax.numpy as jnp
from jax import Array

Expand Down Expand Up @@ -57,12 +57,12 @@ def f(x, *args, **kwargs):


@wrap_math_funcs_only_accept_unitless_unary
def exp(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Array, Quantity]:
def exp(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Array, Quantity]:
return jnp.exp(x)


@wrap_math_funcs_only_accept_unitless_unary
def exp2(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Array, Quantity]:
def exp2(x: Union[Quantity, jax.typing.ArrayLike]) -> Union[Array, Quantity]:
return jnp.exp2(x)


Expand Down
Loading

0 comments on commit dd87871

Please sign in to comment.