Skip to content

Commit

Permalink
add maybe_decimal() (#51)
Browse files Browse the repository at this point in the history
* add `maybe_decimal()`

* fix bug

* fix tests

* fix tests
  • Loading branch information
chaoming0625 authored Sep 9, 2024
1 parent 0c9526a commit da89eda
Show file tree
Hide file tree
Showing 8 changed files with 114 additions and 110 deletions.
139 changes: 78 additions & 61 deletions brainunit/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
'get_dim',
'get_unit',
'display_in_unit',
'maybe_decimal',

# functions for checking
'check_dims',
Expand Down Expand Up @@ -601,12 +602,12 @@ def __str__(self):
elif len(self.units) == 2:
d1, d2 = self.units
s += (
f" (units are {d1} and {d2}"
f" (units are {d1} and {d2} "
)
else:
s += (
" (units are"
f" {' '.join([f'({d})' for d in self.units])}"
f" {' '.join([f'({d})' for d in self.units])} "
)
if len(self.units):
s += ")."
Expand Down Expand Up @@ -918,6 +919,35 @@ def display_in_unit(
return x.repr_in_unit(precision=precision, python_code=python_code)


@set_module_as('brainunit')
def maybe_decimal(
val: Union['Quantity', jax.typing.ArrayLike],
unit: Optional['Unit'] = None
) -> Union[jax.Array, 'Quantity']:
"""
Convert a quantity to a decimal number if it is a dimensionless quantity.
Parameters
----------
val : {`Array`, array-like, number}
The value to convert.
unit: `Unit`, optional
The base unit maybe used to convert the value to.
Returns
-------
decimal : `float`
The value as a decimal number.
"""
valq = _to_quantity(val)
if valq.dim.is_dimensionless:
return valq.to_decimal()
if unit is not None:
return valq.to_decimal(unit)
else:
return val


@set_module_as('brainunit')
def unit_scale_align_to_first(*args) -> List['Quantity']:
"""
Expand Down Expand Up @@ -1098,7 +1128,7 @@ def _wrap_function_change_unit(func, unit_fun):

def f(x, *args, **kwds): # pylint: disable=C0111
assert isinstance(x, Quantity), "Only Quantity objects can be passed to this function"
return remove_unitless(Quantity(func(x.mantissa, *args, **kwds), unit=unit_fun(x.unit, x.unit)))
return maybe_decimal(Quantity(func(x.mantissa, *args, **kwds), unit=unit_fun(x.unit, x.unit)))

f._arg_units = [None]
f._return_unit = unit_fun
Expand Down Expand Up @@ -1131,13 +1161,6 @@ def f(x, *args, **kwds): # pylint: disable=C0111
return f


def remove_unitless(q: 'Quantity') -> jax.typing.ArrayLike | 'Quantity':
if q.is_unitless:
return q.mantissa
else:
return q


def _assert_same_base(u1, u2):
assert u1.has_same_base(u2), (f"Currently, we only support units have different bases. "
f"But we got {u1.base} != {u1.base}.")
Expand Down Expand Up @@ -2138,14 +2161,17 @@ def to_decimal(self, unit: Unit = UNITLESS) -> jax.typing.ArrayLike:
else:
return self._mantissa

def in_unit(self, unit: Unit):
def in_unit(self, unit: Unit, err_msg: str = None) -> 'Quantity':
"""
Convert the given :py:class:`Quantity` into the given unit.
"""
assert isinstance(unit, Unit), f"Expected a Unit, but got {unit}."
if not unit.has_same_dim(self.unit):
raise UnitMismatchError(f"Cannot convert to a unit with different dimensions.", self.unit, unit)
if err_msg is None:
raise UnitMismatchError(f"Cannot convert to a unit with different dimensions.", self.unit, unit)
else:
raise UnitMismatchError(err_msg)
if unit.has_same_scale(self.unit):
u = Quantity(self._mantissa, unit=unit)
else:
Expand Down Expand Up @@ -2241,27 +2267,27 @@ def repr_in_unit(
>>> x.repr_in_unit(mV, 3)
'25.123 mV'
"""
value = jnp.asarray(self._mantissa)
value = jnp.asarray(self.mantissa)
if _is_tracer(value):
# in the JIT mode
s = str(value)
else:
if value.shape == ():
s = jnp.array_str(jnp.array([value]), precision=precision)
s = np.array_str(np.array([value]), precision=precision)
s = s.replace("[", "").replace("]", "").strip()
else:
if value.size > 100:
if python_code:
s = jnp.array_repr(value, precision=precision)[:100]
s = np.array_repr(value, precision=precision)[:100]
s += "..."
else:
s = jnp.array_str(value, precision=precision)[:100]
s = np.array_str(value, precision=precision)[:100]
s += "..."
else:
if python_code:
s = jnp.array_repr(value, precision=precision)
s = np.array_repr(value, precision=precision)
else:
s = jnp.array_str(value, precision=precision)
s = np.array_str(value, precision=precision)

if not self.unit.is_unitless:
if python_code:
Expand Down Expand Up @@ -2625,12 +2651,12 @@ def __invert__(self) -> 'Quantity':

def _comparison(self, other: Any, operator_str: str, operation: Callable):
other = _to_quantity(other)
message = "Cannot perform comparison {value1} %s {value2}, units do not match" % operator_str
fail_for_unit_mismatch(self, other, message, value1=self, value2=other)
oth_value = other._mantissa
if not other.unit.has_same_scale(self.unit):
oth_value = oth_value * (other.unit.value / self.unit.value)
return operation(self._mantissa, oth_value)
try:
other_value = other.in_unit(self.unit).mantissa
except UnitMismatchError as e:
raise UnitMismatchError(f"Cannot compare {self} {operator_str} {other}, "
f"since units do not match: {self.unit} != {other.unit}") from e
return operation(self.mantissa, other_value)

def __eq__(self, oc) -> jax.typing.ArrayLike:
return self._comparison(oc, "==", operator.eq)
Expand Down Expand Up @@ -2687,28 +2713,23 @@ def _binary_operation(
other = _to_quantity(other)

# format the unit and mantissa of "other"
other_unit = None
other_value = other._mantissa
if fail_for_mismatch:
if inplace:
message = "Cannot calculate ... %s {value}, units do not match" % operator_str
_, other_unit = fail_for_unit_mismatch(self, other, message, value=other)
else:
message = "Cannot calculate {value1} %s {value2}, units do not match" % operator_str
_, other_unit = fail_for_unit_mismatch(self, other, message, value1=self, value2=other)
if not other_unit.has_same_scale(self.unit):
other_value = other_value * (other_unit.value / self.unit.value)
if other_unit is None:
other_unit = get_unit(other)
other = other.in_unit(self.unit,
err_msg=f"Cannot calculate \n"
f"{self} {operator_str} {other}"
f"because units do not match: {self.unit} != {other.unit}")
other_value = other.mantissa
other_unit = other.unit

# calculate the new unit and mantissa
new_unit = unit_operation(self.unit, other_unit)
result = value_operation(self._mantissa, other_value)
r = Quantity(result, unit=new_unit)
r = Quantity(
value_operation(self.mantissa, other_value),
unit=unit_operation(self.unit, other_unit)
)

# update the mantissa in-place or not
if inplace:
self.update_value(r._mantissa)
self.update_value(r.mantissa)
return self
else:
return r
Expand All @@ -2735,7 +2756,7 @@ def __isub__(self, oc):

def __mul__(self, oc):
r = self._binary_operation(oc, operator.mul, operator.mul)
return remove_unitless(r)
return maybe_decimal(r)

def __rmul__(self, oc):
return self.__mul__(oc)
Expand All @@ -2747,7 +2768,7 @@ def __imul__(self, oc):
def __div__(self, oc):
# self / oc
r = self._binary_operation(oc, operator.truediv, operator.truediv)
return remove_unitless(r)
return maybe_decimal(r)

def __idiv__(self, oc):
raise NotImplementedError("In-place division is not supported, since it changes the unit.")
Expand All @@ -2761,7 +2782,7 @@ def __rdiv__(self, oc):
# division with swapped arguments
rdiv = lambda a, b: operator.truediv(b, a)
r = self._binary_operation(oc, rdiv, rdiv)
return remove_unitless(r)
return maybe_decimal(r)

def __rtruediv__(self, oc):
# oc / self
Expand All @@ -2774,14 +2795,14 @@ def __itruediv__(self, oc):
def __floordiv__(self, oc):
# self // oc
r = self._binary_operation(oc, operator.floordiv, operator.truediv)
return remove_unitless(r)
return maybe_decimal(r)

def __rfloordiv__(self, oc):
# oc // self
rdiv = lambda a, b: operator.truediv(b, a)
rfloordiv = lambda a, b: operator.floordiv(b, a)
r = self._binary_operation(oc, rfloordiv, rdiv)
return remove_unitless(r)
return maybe_decimal(r)

def __ifloordiv__(self, oc):
# a //= b
Expand All @@ -2790,13 +2811,13 @@ def __ifloordiv__(self, oc):
def __mod__(self, oc):
# self % oc
r = self._binary_operation(oc, operator.mod, lambda ua, ub: ua, fail_for_mismatch=True, operator_str=r"%")
return remove_unitless(r)
return maybe_decimal(r)

def __rmod__(self, oc):
# oc % self
oc = _to_quantity(oc)
r = oc._binary_operation(self, operator.mod, lambda ua, ub: ua, fail_for_mismatch=True, operator_str=r"%")
return remove_unitless(r)
return maybe_decimal(r)

def __imod__(self, oc):
raise NotImplementedError("In-place mod is not supported, since it changes the unit.")
Expand All @@ -2809,12 +2830,12 @@ def __rdivmod__(self, oc):

def __matmul__(self, oc):
r = self._binary_operation(oc, operator.matmul, operator.mul, operator_str="@")
return remove_unitless(r)
return maybe_decimal(r)

def __rmatmul__(self, oc):
oc = _to_quantity(oc)
r = oc._binary_operation(self, operator.matmul, operator.mul, operator_str="@")
return remove_unitless(r)
return maybe_decimal(r)

def __imatmul__(self, oc):
# a @= b
Expand All @@ -2827,7 +2848,7 @@ def __pow__(self, oc):
assert oc.is_unitless, f"Cannot calculate {self} ** {oc}, the exponent has to be dimensionless"
oc = oc.mantissa
r = Quantity(jnp.array(self.mantissa) ** oc, unit=self.unit ** oc)
return remove_unitless(r)
return maybe_decimal(r)

def __rpow__(self, oc):
# oc ** self
Expand Down Expand Up @@ -2882,7 +2903,7 @@ def __lshift__(self, oc) -> 'Quantity':
assert oc.is_unitless, "The shift amount must be dimensionless"
oc = oc._mantissa
r = Quantity(self._mantissa << oc, unit=self.unit)
return remove_unitless(r)
return maybe_decimal(r)

def __rlshift__(self, oc) -> 'Quantity' | jax.typing.ArrayLike:
# oc << self
Expand All @@ -2901,7 +2922,7 @@ def __rshift__(self, oc) -> 'Quantity':
assert oc.is_unitless, "The shift amount must be dimensionless"
oc = oc._mantissa
r = Quantity(self._mantissa >> oc, unit=self.unit)
return remove_unitless(r)
return maybe_decimal(r)

def __rrshift__(self, oc) -> 'Quantity' | jax.typing.ArrayLike:
# oc >> self
Expand Down Expand Up @@ -3016,7 +3037,7 @@ def copy(self) -> 'Quantity':
def dot(self, b) -> 'Quantity':
"""Dot product of two arrays."""
r = self._binary_operation(b, jnp.dot, operator.mul, operator_str="@")
return remove_unitless(r)
return maybe_decimal(r)

def fill(self, value: Quantity) -> 'Quantity':
"""Fill the array with a scalar mantissa."""
Expand Down Expand Up @@ -3048,7 +3069,7 @@ def prod(self, *args, **kwds) -> 'Quantity': # TODO: check error when axis is n
if dim_exponent.size > 1:
dim_exponent = dim_exponent[-1]
r = Quantity(jnp.array(prod_res), unit=self.unit ** dim_exponent)
return remove_unitless(r)
return maybe_decimal(r)

def nanprod(self, *args, **kwds) -> 'Quantity': # TODO: check error when axis is not None
"""Return the product of array elements over a given axis treating Not a Numbers (NaNs) as ones."""
Expand All @@ -3058,15 +3079,15 @@ def nanprod(self, *args, **kwds) -> 'Quantity': # TODO: check error when axis i
if dim_exponent.size > 1:
dim_exponent = dim_exponent[-1]
r = Quantity(jnp.array(prod_res), unit=self.unit ** dim_exponent)
return remove_unitless(r)
return maybe_decimal(r)

def cumprod(self, *args, **kwds): # TODO: check error when axis is not None
prod_res = jnp.cumprod(self._mantissa, *args, **kwds)
dim_exponent = jnp.ones_like(self._mantissa).cumsum(*args, **kwds)
if dim_exponent.size > 1:
dim_exponent = dim_exponent[-1]
r = Quantity(jnp.array(prod_res), unit=self.unit ** dim_exponent)
return remove_unitless(r)
return maybe_decimal(r)

def nancumprod(self, *args, **kwds): # TODO: check error when axis is not None
prod_res = jnp.nancumprod(self._mantissa, *args, **kwds)
Expand All @@ -3075,7 +3096,7 @@ def nancumprod(self, *args, **kwds): # TODO: check error when axis is not None
if dim_exponent.size > 1:
dim_exponent = dim_exponent[-1]
r = Quantity(jnp.array(prod_res), unit=self.unit ** dim_exponent)
return remove_unitless(r)
return maybe_decimal(r)

def put(self, indices, values) -> 'Quantity':
"""Replaces specified elements of an array with given values.
Expand Down Expand Up @@ -3628,23 +3649,19 @@ def set(
indices_are_sorted: bool = False,
unique_indices: bool = False,
mode: str | None = None,
fill_value: StaticScalar | None = None
) -> Quantity:
"""Pure equivalent of ``x[idx] = y``.
Returns the value of ``x`` that would result from the NumPy-style
:mod:`indexed assignment <numpy.doc.indexing>` ``x[idx] = y``.
"""
values = Quantity(values).in_unit(self.unit).mantissa
if fill_value is not None:
fill_value = Quantity(fill_value).in_unit(self.unit).mantissa.item()
return Quantity(
self.mantissa_at[self.index].set(
values,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices,
mode=mode,
fill_value=fill_value
),
unit=self.unit
)
Expand Down
Loading

0 comments on commit da89eda

Please sign in to comment.