Skip to content

Commit

Permalink
add brainunit.get_magnitude(), .magnitude attribute, and tests (#53)
Browse files Browse the repository at this point in the history
* add `brainunit.get_mantissa()`

* add `brainunit.get_magnitude()`, `.magnitude` attribute, and tests

* update
  • Loading branch information
chaoming0625 authored Sep 16, 2024
1 parent a98fd07 commit 5418cca
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 18 deletions.
74 changes: 61 additions & 13 deletions brainunit/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,15 @@
'is_unitless',
'get_dim',
'get_unit',
'get_mantissa',
'get_magnitude',
'display_in_unit',
'maybe_decimal',

# helper classes
'DIMENSIONLESS',
'UNITLESS',

# functions for checking
'check_dims',
'check_units',
Expand Down Expand Up @@ -336,10 +342,6 @@ def dim(self):
"""
return self

@property
def unit(self):
return self

# ---- REPRESENTATION ---- #
def _str_representation(self, python_code: bool = False):
"""
Expand Down Expand Up @@ -671,6 +673,37 @@ def get_unit(obj) -> Unit:
raise TypeError(f"Object of type {type(obj)} does not have a unit")


@set_module_as('brainunit')
def get_mantissa(obj):
"""
Return the mantissa of a Quantity or a number.
Parameters
----------
obj : `object`
The object to check.
Returns
-------
mantissa : `float` or `array_like`
The mantissa of the `obj`.
See Also
--------
get_dim
get_unit
"""
try:
return obj.mantissa
except AttributeError:
return obj


get_magnitude = get_mantissa



@set_module_as('brainunit')
def have_same_dim(obj1, obj2) -> bool:
"""Test if two values have the same dimensions.
Expand Down Expand Up @@ -2078,7 +2111,22 @@ def mantissa(self) -> jax.typing.ArrayLike:
"""
return self._mantissa

def update_value(self, mantissa: PyTree):
@property
def magnitude(self) -> jax.typing.ArrayLike:
"""
The magnitude of the array.
Same as :py:meth:`mantissa`.
In the scientific notation, :math:`x = a * 10^b`, the magnitude :math:`b` is the exponent
of the power of ten. For example, in the number :math:`3.14 * 10^5`, the magnitude is :math:`5`.
Returns:
The magnitude of the array.
"""
return self.mantissa

def update_mantissa(self, mantissa: PyTree):
"""
Set the mantissa of the array.
Expand Down Expand Up @@ -2433,7 +2481,7 @@ def __setitem__(self, index, value: 'Quantity' | jax.typing.ArrayLike):
# update
self_value = jnp.asarray(self._check_tracer())
self_value = self_value.at[index].set(value.mantissa)
self.update_value(self_value)
self.update_mantissa(self_value)

def scatter_add(
self,
Expand Down Expand Up @@ -2736,7 +2784,7 @@ def _binary_operation(

# update the mantissa in-place or not
if inplace:
self.update_value(r.mantissa)
self.update_mantissa(r.mantissa)
return self
else:
return r
Expand Down Expand Up @@ -2920,7 +2968,7 @@ def __rlshift__(self, oc) -> 'Quantity' | jax.typing.ArrayLike:
def __ilshift__(self, oc) -> 'Quantity':
# self <<= oc
r = self.__lshift__(oc)
self.update_value(r.mantissa)
self.update_mantissa(r.mantissa)
return self

def __rshift__(self, oc) -> 'Quantity':
Expand All @@ -2939,7 +2987,7 @@ def __rrshift__(self, oc) -> 'Quantity' | jax.typing.ArrayLike:
def __irshift__(self, oc) -> 'Quantity':
# self >>= oc
r = self.__rshift__(oc)
self.update_value(r.mantissa)
self.update_mantissa(r.mantissa)
return self

def __round__(self, ndigits: int = None) -> 'Quantity':
Expand Down Expand Up @@ -3144,7 +3192,7 @@ def reshape(self, *shape, order='C') -> 'Quantity':

def resize(self, new_shape) -> 'Quantity':
"""Change shape and size of array in-place."""
self.update_value(jnp.resize(self.mantissa, new_shape))
self.update_mantissa(jnp.resize(self.mantissa, new_shape))
return self

def sort(self, axis=-1, stable=True, order=None) -> 'Quantity':
Expand All @@ -3164,7 +3212,7 @@ def sort(self, axis=-1, stable=True, order=None) -> 'Quantity':
but unspecified fields will still be used, in the order in which
they come up in the dtype, to break ties.
"""
self.update_value(jnp.sort(self.mantissa, axis=axis, stable=stable, order=order))
self.update_mantissa(jnp.sort(self.mantissa, axis=axis, stable=stable, order=order))
return self

def squeeze(self, axis=None) -> 'Quantity':
Expand Down Expand Up @@ -3582,12 +3630,12 @@ def tree_unflatten(cls, unit, values) -> 'Quantity':

def cuda(self, deice=None) -> 'Quantity':
deice = jax.devices('cuda')[0] if deice is None else deice
self.update_value(jax.device_put(self.mantissa, deice))
self.update_mantissa(jax.device_put(self.mantissa, deice))
return self

def cpu(self, device=None) -> 'Quantity':
device = jax.devices('cpu')[0] if device is None else device
self.update_value(jax.device_put(self.mantissa, device))
self.update_mantissa(jax.device_put(self.mantissa, device))
return self

# dtype exchanging #
Expand Down
86 changes: 86 additions & 0 deletions brainunit/_base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import pytest
from numpy.testing import assert_equal

import brainunit as u
import brainunit as bu
from brainunit._base import (
DIMENSIONLESS,
Expand Down Expand Up @@ -1439,3 +1440,88 @@ def test_str_repr():
]:
assert len(str(error))
assert len(repr(error))


class TestGetMethod(unittest.TestCase):
def test_get_dim(self):
assert u.get_dim(1) == u.DIMENSIONLESS
assert u.get_dim(1.0) == u.DIMENSIONLESS
assert u.get_dim(1 * u.mV) == u.volt.dim
assert u.get_dim(1 * u.mV / u.mV) == u.DIMENSIONLESS
assert u.get_dim(1 * u.mV / u.second) == u.volt.dim / u.second.dim
assert u.get_dim(1 * u.mV / u.second ** 2) == u.volt.dim / u.second.dim ** 2
assert u.get_dim(1 * u.mV ** 2 / u.second ** 2) == u.volt.dim ** 2 / u.second.dim ** 2

assert u.get_dim(object()) == u.DIMENSIONLESS
assert u.get_dim("string") == u.DIMENSIONLESS
assert u.get_dim([1, 2, 3]) == u.DIMENSIONLESS
assert u.get_dim(np.array([1, 2, 3])) == u.DIMENSIONLESS
assert u.get_dim(np.array([1, 2, 3]) * u.mV) == u.volt.dim

assert u.get_dim(u.mV) == u.volt.dim
assert u.get_dim(u.mV / u.mV) == u.DIMENSIONLESS
assert u.get_dim(u.mV / u.second) == u.volt.dim / u.second.dim
assert u.get_dim(u.mV / u.second ** 2) == u.volt.dim / u.second.dim ** 2
assert u.get_dim(u.mV ** 2 / u.second ** 2) == u.volt.dim ** 2 / u.second.dim ** 2

assert u.get_dim(u.mV.dim) == u.volt.dim
assert u.get_dim(u.mV.dim / u.mV.dim) == u.DIMENSIONLESS
assert u.get_dim(u.mV.dim / u.second.dim) == u.volt.dim / u.second.dim
assert u.get_dim(u.mV.dim / u.second.dim ** 2) == u.volt.dim / u.second.dim ** 2
assert u.get_dim(u.mV.dim ** 2 / u.second.dim ** 2) == u.volt.dim ** 2 / u.second.dim ** 2

def test_unit(self):
assert u.get_unit(1) == u.UNITLESS
assert u.get_unit(1.0) == u.UNITLESS
assert u.get_unit(1 * u.mV) == u.mV
assert u.get_unit(1 * u.mV / u.mV) == u.UNITLESS
assert u.get_unit(1 * u.mV / u.second) == u.mV / u.second
assert u.get_unit(1 * u.mV / u.second ** 2) == u.mV / u.second ** 2
assert u.get_unit(1 * u.mV ** 2 / u.second ** 2) == u.mV ** 2 / u.second ** 2

assert u.get_unit(object()) == u.UNITLESS
assert u.get_unit("string") == u.UNITLESS
assert u.get_unit([1, 2, 3]) == u.UNITLESS
assert u.get_unit(np.array([1, 2, 3])) == u.UNITLESS
assert u.get_unit(np.array([1, 2, 3]) * u.mV) == u.mV

assert u.get_unit(u.mV) == u.mV
assert u.get_unit(u.mV / u.mV) == u.UNITLESS
assert u.get_unit(u.mV / u.second) == u.mV / u.second
assert u.get_unit(u.mV / u.second ** 2) == u.mV / u.second ** 2
assert u.get_unit(u.mV ** 2 / u.second ** 2) == u.mV ** 2 / u.second ** 2

assert u.get_unit(u.mV.dim) == u.UNITLESS
assert u.get_unit(u.mV.dim / u.mV.dim) == u.UNITLESS
assert u.get_unit(u.mV.dim / u.second.dim) == u.UNITLESS
assert u.get_unit(u.mV.dim / u.second.dim ** 2) == u.UNITLESS
assert u.get_unit(u.mV.dim ** 2 / u.second.dim ** 2) == u.UNITLESS

def test_get_mantissa(self):
assert u.get_mantissa(1) == 1
assert u.get_mantissa(1.0) == 1.0
assert u.get_mantissa(1 * u.mV) == 1
assert u.get_mantissa(1 * u.mV / u.mV) == 1
assert u.get_mantissa(1 * u.mV / u.second) == 1
assert u.get_mantissa(1 * u.mV / u.second ** 2) == 1
assert u.get_mantissa(1 * u.mV ** 2 / u.second ** 2) == 1

obj = object()
assert u.get_mantissa(obj) == obj
assert u.get_mantissa("string") == "string"
assert u.get_mantissa([1, 2, 3]) == [1, 2, 3]
assert np.allclose(u.get_mantissa(np.array([1, 2, 3])), np.array([1, 2, 3]))
assert np.allclose(u.get_mantissa(np.array([1, 2, 3]) * u.mV), np.array([1, 2, 3]))

assert u.get_mantissa(u.mV) == u.mV
assert u.get_mantissa(u.mV / u.mV) == u.mV / u.mV
assert u.get_mantissa(u.mV / u.second) == u.mV / u.second
assert u.get_mantissa(u.mV / u.second ** 2) == u.mV / u.second ** 2
assert u.get_mantissa(u.mV ** 2 / u.second ** 2) == u.mV ** 2 / u.second ** 2

assert u.get_mantissa(u.mV.dim) == u.mV.dim
assert u.get_mantissa(u.mV.dim / u.mV.dim) == u.mV.dim / u.mV.dim
assert u.get_mantissa(u.mV.dim / u.second.dim) == u.mV.dim / u.second.dim
assert u.get_mantissa(u.mV.dim / u.second.dim ** 2) == u.mV.dim / u.second.dim ** 2
assert u.get_mantissa(u.mV.dim ** 2 / u.second.dim ** 2) == u.mV.dim ** 2 / u.second.dim ** 2

10 changes: 5 additions & 5 deletions brainunit/math/_fun_array_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,14 +610,14 @@ def asarray(
else:
unit = leaf_unit

#
a_mantissa = treedef.unflatten([leaf.mantissa for leaf in leaves])
a_mantissa = jnp.asarray(a_mantissa, dtype=dtype, order=order)
# reconstruct mantissa
a = treedef.unflatten([leaf.mantissa for leaf in leaves])
a = jnp.asarray(a, dtype=dtype, order=order)

# returns
if unit.is_unitless:
return a_mantissa
return Quantity(a_mantissa, unit=unit)
return a
return Quantity(a, unit=unit)


array = asarray
Expand Down

0 comments on commit 5418cca

Please sign in to comment.