Skip to content

Commit

Permalink
Enable automatic unit register and python scalar is allowed to be ``Q…
Browse files Browse the repository at this point in the history
…uantity.value`` by default (#19)

* enable automatic unit register and allow python scalar as Quantitu value

* fix test

* fix ``Quantity.__reduce__``

* fix dtype and shape error
  • Loading branch information
chaoming0625 authored Jun 14, 2024
1 parent 0e54af2 commit 97f8114
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 72 deletions.
113 changes: 63 additions & 50 deletions brainunit/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,17 @@
_all_slice = slice(None, None, None)
_unit_checking = True
_allow_python_scalar_value = False
_auto_register_unit = True


@contextmanager
def turn_off_auto_unit_register():
try:
global _auto_register_unit
_auto_register_unit = False
yield
finally:
_auto_register_unit = True


@contextmanager
Expand Down Expand Up @@ -945,7 +956,8 @@ def __init__(
):
scale, dim = _get_dim(dim, unit)

if isinstance(value, numbers.Number) and _allow_python_scalar_value:
# always allow python scalar
if isinstance(value, numbers.Number):
self._dim = dim
self._value = (value if scale is None else (value * scale))
return
Expand Down Expand Up @@ -1021,11 +1033,11 @@ def update_value(self, value):
else:
value = jnp.asarray(value, dtype=self.dtype)
# check
if value.shape != self_value.shape:
raise ValueError(f"The shape of the original data is {self_value.shape}, "
if value.shape != jnp.shape(self_value):
raise ValueError(f"The shape of the original data is {jnp.shape(self_value)}, "
f"while we got {value.shape}.")
if value.dtype != self_value.dtype:
raise ValueError(f"The dtype of the original data is {self_value.dtype}, "
if value.dtype != jax.dtypes.result_type(self_value):
raise ValueError(f"The dtype of the original data is {jax.dtypes.result_type(self_value)}, "
f"while we got {value.dtype}.")
self._value = value

Expand Down Expand Up @@ -1174,22 +1186,25 @@ def repr_in_unit(
"""
fail_for_dimension_mismatch(self, u, 'Non-matching unit for method "in_unit"')
value = jnp.asarray(self.value / u.value)
if value.shape == ():
s = jnp.array_str(jnp.array([value]), precision=precision)
s = s.replace("[", "").replace("]", "").strip()
if isinstance(value, (jax.ShapeDtypeStruct, jax.core.ShapedArray, DynamicJaxprTracer)):
s = str(value)
else:
if value.size > 100:
if python_code:
s = jnp.array_repr(value, precision=precision)[:100]
s += "..."
else:
s = jnp.array_str(value, precision=precision)[:100]
s += "..."
if value.shape == ():
s = jnp.array_str(jnp.array([value]), precision=precision)
s = s.replace("[", "").replace("]", "").strip()
else:
if python_code:
s = jnp.array_repr(value, precision=precision)
if value.size > 100:
if python_code:
s = jnp.array_repr(value, precision=precision)[:100]
s += "..."
else:
s = jnp.array_str(value, precision=precision)[:100]
s += "..."
else:
s = jnp.array_str(value, precision=precision)
if python_code:
s = jnp.array_repr(value, precision=precision)
else:
s = jnp.array_str(value, precision=precision)

if not u.is_unitless:
if isinstance(u, Unit):
Expand Down Expand Up @@ -1292,7 +1307,7 @@ def size(self) -> int:

@property
def T(self) -> 'Quantity':
return Quantity(self.value.T, dim=self.dim)
return Quantity(jnp.asarray(self.value).T, dim=self.dim)

@property
def isreal(self) -> jax.Array:
Expand Down Expand Up @@ -1328,8 +1343,6 @@ def __repr__(self) -> str:
return self.repr_in_best_unit(python_code=True)

def __str__(self) -> str:
if isinstance(self.value, (jax.ShapeDtypeStruct, jax.core.ShapedArray, DynamicJaxprTracer)):
return f'{self.value} * {Quantity(1, dim=self.dim)}'
return self.repr_in_best_unit()

def __format__(self, format_spec: str) -> str:
Expand Down Expand Up @@ -1718,7 +1731,7 @@ def __round__(self, ndigits: int = None) -> 'Quantity':
return Quantity(self.value.__round__(ndigits), dim=self.dim)

def __reduce__(self):
return array_with_unit, (self.value, self.dim, self.value.dtype)
return array_with_unit, (self.value, self.dim, None)

# ----------------------- #
# NumPy methods #
Expand Down Expand Up @@ -1756,7 +1769,7 @@ def astype(self, dtype) -> 'Quantity':
if dtype is None:
return Quantity(self.value, dim=self.dim)
else:
return Quantity(self.value.astype(dtype), dim=self.dim)
return Quantity(jnp.astype(self.value, dtype), dim=self.dim)

def clip(self, min: Quantity = None, max: Quantity = None, *args, **kwds) -> 'Quantity':
"""Return an array whose values are limited to [min, max]. One of max or min must be given."""
Expand Down Expand Up @@ -1802,11 +1815,10 @@ def flatten(self) -> 'Quantity':

def item(self, *args) -> 'Quantity':
"""Copy an element of an array to a standard Python scalar and return it."""
with allow_python_scalar():
if isinstance(self.value, jax.Array):
return Quantity(self.value.item(*args), dim=self.dim)
else:
return Quantity(self.value, dim=self.dim)
if isinstance(self.value, jax.Array):
return Quantity(self.value.item(*args), dim=self.dim)
else:
return Quantity(self.value, dim=self.dim)

def prod(self, *args, **kwds) -> 'Quantity':
"""Return the product of the array elements over the given axis."""
Expand Down Expand Up @@ -1991,8 +2003,7 @@ def top_replace(s):
if isinstance(self.value, jax.Array):
return replace_with_array(self.value.tolist(), self.dim)
else:
with allow_python_scalar():
return Quantity(self.value, dim=self.dim)
return Quantity(self.value, dim=self.dim)

def transpose(self, *axes) -> 'Quantity':
"""Returns a view of the array with axes transposed.
Expand Down Expand Up @@ -2087,7 +2098,7 @@ def view(self, *args, dtype=None) -> 'Quantity':
Example::
>>> import brainstate
>>> import brainstate, brainunit
>>> x = brainstate.random.randn(4, 4)
>>> x.size
[4, 4]
Expand All @@ -2107,7 +2118,7 @@ def view(self, *args, dtype=None) -> 'Quantity':
>>> c = a.view(1, 3, 2, 4) # Does not change tensor layout in memory
>>> c.size
[1, 3, 2, 4]
>>> brainstate.math.equal(b, c)
>>> brainunit.math.equal(b, c)
False
Expand Down Expand Up @@ -2346,8 +2357,7 @@ def clamp(
def clone(self) -> 'Quantity':
if isinstance(self.value, jax.Array):
return self.copy()
with allow_python_scalar():
return type(self)(self.value, dim=self.dim)
return type(self)(self.value, dim=self.dim)

def tree_flatten(self) -> Tuple[jax.Array | numbers.Number, Any]:
"""
Expand Down Expand Up @@ -2533,6 +2543,9 @@ def __init__(

super().__init__(value, dtype=dtype, dim=dim)

if _auto_register_unit:
register_new_unit(self)

@staticmethod
def create(unit: Dimension, name: str, dispname: str, scale: int = 0):
"""
Expand Down Expand Up @@ -2763,13 +2776,14 @@ class UnitRegistry:
__module__ = "brainunit"

def __init__(self):
self.units = collections.OrderedDict()
self.units_for_dimensions = collections.defaultdict(dict)

def add(self, u):
def add(self, u: Unit):
"""Add a unit to the registry"""
self.units[repr(u)] = u
self.units_for_dimensions[u.dim][float(u)] = u
if isinstance(u.value, (jax.ShapeDtypeStruct, jax.core.ShapedArray, DynamicJaxprTracer)):
self.units_for_dimensions[u.dim][1.] = u
else:
self.units_for_dimensions[u.dim][float(u)] = u

def __getitem__(self, x):
"""Returns the best unit for array x
Expand All @@ -2787,8 +2801,10 @@ def __getitem__(self, x):
if len(matching) == 0:
raise KeyError("Unit not found in registry.")

matching_values = jnp.array(list(matching.keys()))
print_opts = jnp.get_printoptions()
matching_values = np.array(list(matching.keys()))
if isinstance(x.value, (jax.ShapeDtypeStruct, jax.core.ShapedArray, DynamicJaxprTracer)):
return matching[1.0]
print_opts = np.get_printoptions()
edgeitems, threshold = print_opts["edgeitems"], print_opts["threshold"]
if x.size > threshold:
# Only care about optimizing the units for the values that will
Expand All @@ -2801,19 +2817,16 @@ def __getitem__(self, x):
slices.append((slice(0, edgeitems), slice(-edgeitems, None)))
else:
slices.append((slice(None),))
x_flat = jnp.hstack(
[jnp.array(x[use_slices].flatten().value) for use_slices in itertools.product(*slices)]
)
x_flat = np.hstack([np.array(x[use_slices].flatten().value)
for use_slices in itertools.product(*slices)])
else:
x_flat = jnp.array(x.value).flatten()
floatreps = jnp.tile(jnp.abs(x_flat), (len(matching), 1)).T / matching_values
x_flat = np.array(x.value).flatten()
floatreps = np.tile(np.abs(x_flat), (len(matching), 1)).T / matching_values
# ignore zeros, they are well represented in any unit
floatreps = floatreps.at[floatreps == 0].set(jnp.nan)
# floatreps[floatreps == 0] = jnp.nan
if jnp.all(jnp.isnan(floatreps)):
floatreps[floatreps == 0] = np.nan
if np.all(np.isnan(floatreps)):
return matching[1.0] # all zeros, use the base unit

deviations = jnp.nansum((jnp.log10(floatreps) - 1) ** 2, axis=0)
deviations = np.nansum((np.log10(floatreps) - 1) ** 2, axis=0)
return list(matching.values())[deviations.argmin()]


Expand Down
4 changes: 2 additions & 2 deletions brainunit/_unit_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
additional_unit_register,
get_or_create_dimension,
standard_unit_register,
allow_python_scalar
turn_off_auto_unit_register,
)

__all__ = [
Expand Down Expand Up @@ -2096,7 +2096,7 @@
"celsius" # Dummy object raising an error
]

with allow_python_scalar():
with turn_off_auto_unit_register():
#### FUNDAMENTAL UNITS
metre = Unit.create(get_or_create_dimension(m=1), "metre", "m")
meter = Unit.create(get_or_create_dimension(m=1), "meter", "m")
Expand Down
43 changes: 23 additions & 20 deletions brainunit/_unit_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
meter,
mole,
newton,
turn_off_auto_unit_register,
)

__all__ = [
Expand All @@ -67,23 +68,25 @@
'zero_celsius',
]

#: Avogadro constant (http://physics.nist.gov/cgi-bin/cuu/Value?na)
avogadro_constant = 6.022140857e23 / mole
#: Boltzmann constant (physics.nist.gov/cgi-bin/cuu/Value?k)
boltzmann_constant = 1.38064852e-23 * joule / kelvin
#: electric constant (http://physics.nist.gov/cgi-bin/cuu/Value?ep0)
electric_constant = 8.854187817e-12 * farad / meter
#: Electron rest mass (physics.nist.gov/cgi-bin/cuu/Value?me)
electron_mass = 9.10938356e-31 * kilogram
#: Elementary charge (physics.nist.gov/cgi-bin/cuu/Value?e)
elementary_charge = 1.6021766208e-19 * coulomb
#: Faraday constant (http://physics.nist.gov/cgi-bin/cuu/Value?f)
faraday_constant = 96485.33289 * coulomb / mole
#: gas constant (http://physics.nist.gov/cgi-bin/cuu/Value?r)
gas_constant = 8.3144598 * joule / mole / kelvin
#: Magnetic constant (http://physics.nist.gov/cgi-bin/cuu/Value?mu0)
magnetic_constant = 4 * np.pi * 1e-7 * newton / amp ** 2
#: Molar mass constant (http://physics.nist.gov/cgi-bin/cuu/Value?mu)
molar_mass_constant = 1 * gram / mole
#: zero degree Celsius
zero_celsius = 273.15 * kelvin

with turn_off_auto_unit_register():
#: Avogadro constant (http://physics.nist.gov/cgi-bin/cuu/Value?na)
avogadro_constant = 6.022140857e23 / mole
#: Boltzmann constant (physics.nist.gov/cgi-bin/cuu/Value?k)
boltzmann_constant = 1.38064852e-23 * joule / kelvin
#: electric constant (http://physics.nist.gov/cgi-bin/cuu/Value?ep0)
electric_constant = 8.854187817e-12 * farad / meter
#: Electron rest mass (physics.nist.gov/cgi-bin/cuu/Value?me)
electron_mass = 9.10938356e-31 * kilogram
#: Elementary charge (physics.nist.gov/cgi-bin/cuu/Value?e)
elementary_charge = 1.6021766208e-19 * coulomb
#: Faraday constant (http://physics.nist.gov/cgi-bin/cuu/Value?f)
faraday_constant = 96485.33289 * coulomb / mole
#: gas constant (http://physics.nist.gov/cgi-bin/cuu/Value?r)
gas_constant = 8.3144598 * joule / mole / kelvin
#: Magnetic constant (http://physics.nist.gov/cgi-bin/cuu/Value?mu0)
magnetic_constant = 4 * np.pi * 1e-7 * newton / amp ** 2
#: Molar mass constant (http://physics.nist.gov/cgi-bin/cuu/Value?mu)
molar_mass_constant = 1 * gram / mole
#: zero degree Celsius
zero_celsius = 273.15 * kelvin
13 changes: 13 additions & 0 deletions brainunit/_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1668,6 +1668,7 @@ def test_jit_array():
@jax.jit
def f1(a):
b = a * bu.siemens / bu.cm ** 2
print(b)
return b

val = np.random.rand(3)
Expand All @@ -1682,3 +1683,15 @@ def f2(a):
val = np.random.rand(3) * bu.siemens / bu.cm ** 2
r = f2(val)
bu.math.allclose(val + 1 * bu.siemens / bu.cm ** 2, r)


def test_jit_array2():
a = 2.0 * (bu.farad / bu.metre ** 2)
print(a)

@jax.jit
def f(b):
print(b)
return b

f(a)

0 comments on commit 97f8114

Please sign in to comment.