diff --git a/brainunit/_base.py b/brainunit/_base.py index 8640d87..95df305 100644 --- a/brainunit/_base.py +++ b/brainunit/_base.py @@ -49,22 +49,10 @@ ] _all_slice = slice(None, None, None) -random = None _unit_checking = True -_automatically_register_units = True _allow_python_scalar_value = False -@contextmanager -def turn_off_unit_register(): - try: - global _automatically_register_units - _automatically_register_units = False - yield - finally: - _automatically_register_units = True - - @contextmanager def allow_python_scalar(): try: @@ -85,13 +73,6 @@ def turn_off_unit_checking(): _unit_checking = True -def _get_random_module(): - global random - if random is None: - from brainstate import random - return random - - def _to_quantity(array): if isinstance(array, Quantity): return array @@ -2222,18 +2203,45 @@ def view(self, *args, dtype=None) -> 'Quantity': # NumPy support # ------------------ - def to_numpy(self, dtype=None) -> np.ndarray: - """Convert to numpy.ndarray.""" - return np.asarray(self.value, dtype=dtype) + def to_numpy(self, + dtype: Optional[jax.typing.DTypeLike] = None, + unit: Optional['Unit'] = None) -> np.ndarray: + """ + Remove the unit and convert to ``numpy.ndarray``. - def to_jax(self, dtype=None) -> jax.Array: - """Convert to jax.numpy.ndarray.""" - if dtype is None: - return self.value + Args: + dtype: The data type of the output array. + unit: The unit of the output array. + + Returns: + The numpy.ndarray. + """ + if unit is None: + return np.asarray(self.value, dtype=dtype) else: + assert isinstance(unit, Unit), f"unit must be a Unit object, but got {type(unit)}" + return np.asarray(self / unit, dtype=dtype) + + def to_jax(self, + dtype: Optional[jax.typing.DTypeLike] = None, + unit: Optional['Unit'] = None) -> jax.Array: + """ + Remove the unit and convert to ``jax.Array``. + + Args: + dtype: The data type of the output array. + unit: The unit of the output array. + + Returns: + The jax.Array. + """ + if unit is None: return jnp.asarray(self.value, dtype=dtype) + else: + assert isinstance(unit, Unit), f"unit must be a Unit object, but got {type(unit)}" + return jnp.asarray(self / unit, dtype=dtype) - def __array__(self, dtype=None) -> np.ndarray: + def __array__(self, dtype: Optional[jax.typing.DTypeLike] = None) -> np.ndarray: """Support ``numpy.array()`` and ``numpy.asarray()`` functions.""" return np.asarray(self.value, dtype=dtype) @@ -2247,7 +2255,7 @@ def __index__(self): # PyTorch compatibility # ---------------------- - def unsqueeze(self, dim: int) -> 'Quantity': + def unsqueeze(self, axis: int) -> 'Quantity': """ Array.unsqueeze(dim) -> Array, or so called Tensor equals @@ -2255,7 +2263,7 @@ def unsqueeze(self, dim: int) -> 'Quantity': See :func:`brainstate.math.unsqueeze` """ - return Quantity(jnp.expand_dims(self.value, dim), dim=self.dim) + return Quantity(jnp.expand_dims(self.value, axis), dim=self.dim) def expand_dims(self, axis: Union[int, Sequence[int]]) -> 'Quantity': """ @@ -2513,9 +2521,6 @@ def __init__( super().__init__(value, dtype=dtype, dim=dim) - if _automatically_register_units: - register_new_unit(self) - @staticmethod def create(unit: Dimension, name: str, dispname: str, scale: int = 0): """ diff --git a/brainunit/_unit_common.py b/brainunit/_unit_common.py index d21b728..5626463 100644 --- a/brainunit/_unit_common.py +++ b/brainunit/_unit_common.py @@ -20,7 +20,6 @@ additional_unit_register, get_or_create_dimension, standard_unit_register, - turn_off_unit_register, allow_python_scalar ) @@ -2097,7 +2096,7 @@ "celsius" # Dummy object raising an error ] -with turn_off_unit_register(), allow_python_scalar(): +with allow_python_scalar(): #### FUNDAMENTAL UNITS metre = Unit.create(get_or_create_dimension(m=1), "metre", "m") meter = Unit.create(get_or_create_dimension(m=1), "meter", "m") diff --git a/brainunit/_unit_constants.py b/brainunit/_unit_constants.py index 914c33d..85136f5 100644 --- a/brainunit/_unit_constants.py +++ b/brainunit/_unit_constants.py @@ -41,8 +41,6 @@ import numpy as np -from ._base import (turn_off_unit_register, - allow_python_scalar) from ._unit_common import ( amp, coulomb, @@ -69,24 +67,23 @@ 'zero_celsius', ] -with turn_off_unit_register(), allow_python_scalar(): - #: Avogadro constant (http://physics.nist.gov/cgi-bin/cuu/Value?na) - avogadro_constant = 6.022140857e23 / mole - #: Boltzmann constant (physics.nist.gov/cgi-bin/cuu/Value?k) - boltzmann_constant = 1.38064852e-23 * joule / kelvin - #: electric constant (http://physics.nist.gov/cgi-bin/cuu/Value?ep0) - electric_constant = 8.854187817e-12 * farad / meter - #: Electron rest mass (physics.nist.gov/cgi-bin/cuu/Value?me) - electron_mass = 9.10938356e-31 * kilogram - #: Elementary charge (physics.nist.gov/cgi-bin/cuu/Value?e) - elementary_charge = 1.6021766208e-19 * coulomb - #: Faraday constant (http://physics.nist.gov/cgi-bin/cuu/Value?f) - faraday_constant = 96485.33289 * coulomb / mole - #: gas constant (http://physics.nist.gov/cgi-bin/cuu/Value?r) - gas_constant = 8.3144598 * joule / mole / kelvin - #: Magnetic constant (http://physics.nist.gov/cgi-bin/cuu/Value?mu0) - magnetic_constant = 4 * np.pi * 1e-7 * newton / amp ** 2 - #: Molar mass constant (http://physics.nist.gov/cgi-bin/cuu/Value?mu) - molar_mass_constant = 1 * gram / mole - #: zero degree Celsius - zero_celsius = 273.15 * kelvin +#: Avogadro constant (http://physics.nist.gov/cgi-bin/cuu/Value?na) +avogadro_constant = 6.022140857e23 / mole +#: Boltzmann constant (physics.nist.gov/cgi-bin/cuu/Value?k) +boltzmann_constant = 1.38064852e-23 * joule / kelvin +#: electric constant (http://physics.nist.gov/cgi-bin/cuu/Value?ep0) +electric_constant = 8.854187817e-12 * farad / meter +#: Electron rest mass (physics.nist.gov/cgi-bin/cuu/Value?me) +electron_mass = 9.10938356e-31 * kilogram +#: Elementary charge (physics.nist.gov/cgi-bin/cuu/Value?e) +elementary_charge = 1.6021766208e-19 * coulomb +#: Faraday constant (http://physics.nist.gov/cgi-bin/cuu/Value?f) +faraday_constant = 96485.33289 * coulomb / mole +#: gas constant (http://physics.nist.gov/cgi-bin/cuu/Value?r) +gas_constant = 8.3144598 * joule / mole / kelvin +#: Magnetic constant (http://physics.nist.gov/cgi-bin/cuu/Value?mu0) +magnetic_constant = 4 * np.pi * 1e-7 * newton / amp ** 2 +#: Molar mass constant (http://physics.nist.gov/cgi-bin/cuu/Value?mu) +molar_mass_constant = 1 * gram / mole +#: zero degree Celsius +zero_celsius = 273.15 * kelvin diff --git a/brainunit/_unit_test.py b/brainunit/_unit_test.py index 60c6dd3..fe5c2d7 100644 --- a/brainunit/_unit_test.py +++ b/brainunit/_unit_test.py @@ -17,12 +17,14 @@ import warnings import brainstate as bst +import jax import jax.numpy as jnp import numpy as np import pytest from numpy.testing import assert_equal -array = np.array +import brainunit as bu + bst.environ.set(precision=64) from brainunit._unit_common import * @@ -1623,6 +1625,7 @@ def test_constants(): (constants.avogadro_constant * constants.elementary_charge).value, ) + # if __name__ == "__main__": # test_construction() # test_get_dimensions() @@ -1659,3 +1662,23 @@ def test_constants(): # test_units_vs_quantities() # test_all_units_list() # test_constants() + + +def test_jit_array(): + @jax.jit + def f1(a): + b = a * bu.siemens / bu.cm ** 2 + return b + + val = np.random.rand(3) + r = f1(val) + bu.math.allclose(val * bu.siemens / bu.cm ** 2, r) + + @jax.jit + def f2(a): + a = a + 1. * bu.siemens / bu.cm ** 2 + return a + + val = np.random.rand(3) * bu.siemens / bu.cm ** 2 + r = f2(val) + bu.math.allclose(val + 1 * bu.siemens / bu.cm ** 2, r)