Skip to content

Commit

Permalink
remove the default unit register (#14)
Browse files Browse the repository at this point in the history
* remove the default unit register

* remove useless random module import

* `to_jax()` and `to_numpy()` support to receive a `unit` argument
  • Loading branch information
chaoming0625 authored Jun 13, 2024
1 parent eac93bd commit 3d25cce
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 58 deletions.
69 changes: 37 additions & 32 deletions brainunit/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,22 +49,10 @@
]

_all_slice = slice(None, None, None)
random = None
_unit_checking = True
_automatically_register_units = True
_allow_python_scalar_value = False


@contextmanager
def turn_off_unit_register():
try:
global _automatically_register_units
_automatically_register_units = False
yield
finally:
_automatically_register_units = True


@contextmanager
def allow_python_scalar():
try:
Expand All @@ -85,13 +73,6 @@ def turn_off_unit_checking():
_unit_checking = True


def _get_random_module():
global random
if random is None:
from brainstate import random
return random


def _to_quantity(array):
if isinstance(array, Quantity):
return array
Expand Down Expand Up @@ -2222,18 +2203,45 @@ def view(self, *args, dtype=None) -> 'Quantity':
# NumPy support
# ------------------

def to_numpy(self, dtype=None) -> np.ndarray:
"""Convert to numpy.ndarray."""
return np.asarray(self.value, dtype=dtype)
def to_numpy(self,
dtype: Optional[jax.typing.DTypeLike] = None,
unit: Optional['Unit'] = None) -> np.ndarray:
"""
Remove the unit and convert to ``numpy.ndarray``.
def to_jax(self, dtype=None) -> jax.Array:
"""Convert to jax.numpy.ndarray."""
if dtype is None:
return self.value
Args:
dtype: The data type of the output array.
unit: The unit of the output array.
Returns:
The numpy.ndarray.
"""
if unit is None:
return np.asarray(self.value, dtype=dtype)
else:
assert isinstance(unit, Unit), f"unit must be a Unit object, but got {type(unit)}"
return np.asarray(self / unit, dtype=dtype)

def to_jax(self,
dtype: Optional[jax.typing.DTypeLike] = None,
unit: Optional['Unit'] = None) -> jax.Array:
"""
Remove the unit and convert to ``jax.Array``.
Args:
dtype: The data type of the output array.
unit: The unit of the output array.
Returns:
The jax.Array.
"""
if unit is None:
return jnp.asarray(self.value, dtype=dtype)
else:
assert isinstance(unit, Unit), f"unit must be a Unit object, but got {type(unit)}"
return jnp.asarray(self / unit, dtype=dtype)

def __array__(self, dtype=None) -> np.ndarray:
def __array__(self, dtype: Optional[jax.typing.DTypeLike] = None) -> np.ndarray:
"""Support ``numpy.array()`` and ``numpy.asarray()`` functions."""
return np.asarray(self.value, dtype=dtype)

Expand All @@ -2247,15 +2255,15 @@ def __index__(self):
# PyTorch compatibility
# ----------------------

def unsqueeze(self, dim: int) -> 'Quantity':
def unsqueeze(self, axis: int) -> 'Quantity':
"""
Array.unsqueeze(dim) -> Array, or so called Tensor
equals
Array.expand_dims(dim)
See :func:`brainstate.math.unsqueeze`
"""
return Quantity(jnp.expand_dims(self.value, dim), dim=self.dim)
return Quantity(jnp.expand_dims(self.value, axis), dim=self.dim)

def expand_dims(self, axis: Union[int, Sequence[int]]) -> 'Quantity':
"""
Expand Down Expand Up @@ -2513,9 +2521,6 @@ def __init__(

super().__init__(value, dtype=dtype, dim=dim)

if _automatically_register_units:
register_new_unit(self)

@staticmethod
def create(unit: Dimension, name: str, dispname: str, scale: int = 0):
"""
Expand Down
3 changes: 1 addition & 2 deletions brainunit/_unit_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
additional_unit_register,
get_or_create_dimension,
standard_unit_register,
turn_off_unit_register,
allow_python_scalar
)

Expand Down Expand Up @@ -2097,7 +2096,7 @@
"celsius" # Dummy object raising an error
]

with turn_off_unit_register(), allow_python_scalar():
with allow_python_scalar():
#### FUNDAMENTAL UNITS
metre = Unit.create(get_or_create_dimension(m=1), "metre", "m")
meter = Unit.create(get_or_create_dimension(m=1), "meter", "m")
Expand Down
43 changes: 20 additions & 23 deletions brainunit/_unit_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,6 @@

import numpy as np

from ._base import (turn_off_unit_register,
allow_python_scalar)
from ._unit_common import (
amp,
coulomb,
Expand All @@ -69,24 +67,23 @@
'zero_celsius',
]

with turn_off_unit_register(), allow_python_scalar():
#: Avogadro constant (http://physics.nist.gov/cgi-bin/cuu/Value?na)
avogadro_constant = 6.022140857e23 / mole
#: Boltzmann constant (physics.nist.gov/cgi-bin/cuu/Value?k)
boltzmann_constant = 1.38064852e-23 * joule / kelvin
#: electric constant (http://physics.nist.gov/cgi-bin/cuu/Value?ep0)
electric_constant = 8.854187817e-12 * farad / meter
#: Electron rest mass (physics.nist.gov/cgi-bin/cuu/Value?me)
electron_mass = 9.10938356e-31 * kilogram
#: Elementary charge (physics.nist.gov/cgi-bin/cuu/Value?e)
elementary_charge = 1.6021766208e-19 * coulomb
#: Faraday constant (http://physics.nist.gov/cgi-bin/cuu/Value?f)
faraday_constant = 96485.33289 * coulomb / mole
#: gas constant (http://physics.nist.gov/cgi-bin/cuu/Value?r)
gas_constant = 8.3144598 * joule / mole / kelvin
#: Magnetic constant (http://physics.nist.gov/cgi-bin/cuu/Value?mu0)
magnetic_constant = 4 * np.pi * 1e-7 * newton / amp ** 2
#: Molar mass constant (http://physics.nist.gov/cgi-bin/cuu/Value?mu)
molar_mass_constant = 1 * gram / mole
#: zero degree Celsius
zero_celsius = 273.15 * kelvin
#: Avogadro constant (http://physics.nist.gov/cgi-bin/cuu/Value?na)
avogadro_constant = 6.022140857e23 / mole
#: Boltzmann constant (physics.nist.gov/cgi-bin/cuu/Value?k)
boltzmann_constant = 1.38064852e-23 * joule / kelvin
#: electric constant (http://physics.nist.gov/cgi-bin/cuu/Value?ep0)
electric_constant = 8.854187817e-12 * farad / meter
#: Electron rest mass (physics.nist.gov/cgi-bin/cuu/Value?me)
electron_mass = 9.10938356e-31 * kilogram
#: Elementary charge (physics.nist.gov/cgi-bin/cuu/Value?e)
elementary_charge = 1.6021766208e-19 * coulomb
#: Faraday constant (http://physics.nist.gov/cgi-bin/cuu/Value?f)
faraday_constant = 96485.33289 * coulomb / mole
#: gas constant (http://physics.nist.gov/cgi-bin/cuu/Value?r)
gas_constant = 8.3144598 * joule / mole / kelvin
#: Magnetic constant (http://physics.nist.gov/cgi-bin/cuu/Value?mu0)
magnetic_constant = 4 * np.pi * 1e-7 * newton / amp ** 2
#: Molar mass constant (http://physics.nist.gov/cgi-bin/cuu/Value?mu)
molar_mass_constant = 1 * gram / mole
#: zero degree Celsius
zero_celsius = 273.15 * kelvin
25 changes: 24 additions & 1 deletion brainunit/_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
import warnings

import brainstate as bst
import jax
import jax.numpy as jnp
import numpy as np
import pytest
from numpy.testing import assert_equal

array = np.array
import brainunit as bu

bst.environ.set(precision=64)

from brainunit._unit_common import *
Expand Down Expand Up @@ -1623,6 +1625,7 @@ def test_constants():
(constants.avogadro_constant * constants.elementary_charge).value,
)


# if __name__ == "__main__":
# test_construction()
# test_get_dimensions()
Expand Down Expand Up @@ -1659,3 +1662,23 @@ def test_constants():
# test_units_vs_quantities()
# test_all_units_list()
# test_constants()


def test_jit_array():
@jax.jit
def f1(a):
b = a * bu.siemens / bu.cm ** 2
return b

val = np.random.rand(3)
r = f1(val)
bu.math.allclose(val * bu.siemens / bu.cm ** 2, r)

@jax.jit
def f2(a):
a = a + 1. * bu.siemens / bu.cm ** 2
return a

val = np.random.rand(3) * bu.siemens / bu.cm ** 2
r = f2(val)
bu.math.allclose(val + 1 * bu.siemens / bu.cm ** 2, r)

0 comments on commit 3d25cce

Please sign in to comment.