Skip to content

Commit

Permalink
Merge branch 'main' into docs
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Jun 18, 2024
2 parents 9797532 + 1f8b170 commit 820bcc5
Show file tree
Hide file tree
Showing 5 changed files with 263 additions and 382 deletions.
158 changes: 119 additions & 39 deletions brainunit/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def _short_str(arr):
Return a short string representation of an array, suitable for use in
error messages.
"""
arr = arr.value if isinstance(arr, Quantity) else arr
arr = np.asanyarray(arr)
old_printoptions = jnp.get_printoptions()
jnp.set_printoptions(edgeitems=2, threshold=5)
Expand All @@ -112,7 +113,7 @@ def _short_str(arr):
return arr_string


def get_unit_for_display(d):
def get_dim_for_display(d):
"""
Return a string representation of an appropriate unscaled unit or ``'1'``
for a dimensionless array.
Expand Down Expand Up @@ -181,6 +182,13 @@ def get_unit_for_display(d):
"cd": 6,
}

# Length (meter)
# Mass (kilogram)
# Time (second)
# Current (ampere)
# Temperature (Kelvin)
# Amount of substance (mole)
# Luminous intensity (candela)
_ilabel = ["m", "kg", "s", "A", "K", "mol", "cd"]

# The same labels with the names used for constructing them in Python code
Expand Down Expand Up @@ -453,6 +461,8 @@ def get_or_create_dimension(*args, **kwds):

'''The dimensionless unit, used for quantities without a unit.'''
DIMENSIONLESS = Dimension((0, 0, 0, 0, 0, 0, 0))

'''The dictionary of all existing Dimension objects.'''
_dimensions = {(0, 0, 0, 0, 0, 0, 0): DIMENSIONLESS}


Expand Down Expand Up @@ -492,16 +502,16 @@ def __str__(self):
if len(self.dims) == 0:
pass
elif len(self.dims) == 1:
s += f" (unit is {get_unit_for_display(self.dims[0])}"
s += f" (unit is {get_dim_for_display(self.dims[0])}"
elif len(self.dims) == 2:
d1, d2 = self.dims
s += (
f" (units are {get_unit_for_display(d1)} and {get_unit_for_display(d2)}"
f" (units are {get_dim_for_display(d1)} and {get_dim_for_display(d2)}"
)
else:
s += (
" (units are"
f" {' '.join([f'({get_unit_for_display(d)})' for d in self.dims])}"
f" {' '.join([f'({get_dim_for_display(d)})' for d in self.dims])}"
)
if len(self.dims):
s += ")."
Expand All @@ -510,7 +520,7 @@ def __str__(self):

def get_dim(obj) -> Dimension:
"""
Return the unit of any object that has them.
Return the dimension of any object that has them.
Slightly more general than `Array.dimensions` because it will
return `DIMENSIONLESS` if the object is of number type but not a `Array`
Expand Down Expand Up @@ -741,9 +751,9 @@ def in_best_unit(x, precision=None):
return x.repr_in_unit(u, precision=precision)


def array_with_unit(
def array_with_dim(
floatval,
unit: Dimension,
dim: Dimension,
dtype: jax.typing.DTypeLike = None
) -> 'Quantity':
"""
Expand All @@ -757,8 +767,8 @@ def array_with_unit(
----------
floatval : `float`
The floating point value of the array.
unit: Dimension
The unit dimensions of the array.
dim: Dimension
The dim dimensions of the array.
dtype: `dtype`, optional
The data type of the array.
Expand All @@ -770,10 +780,10 @@ def array_with_unit(
Examples
--------
>>> from brainunit import *
>>> array_with_unit(0.001, volt.dim)
>>> array_with_dim(0.001, volt.dim)
1. * mvolt
"""
return Quantity(floatval, dim=get_or_create_dimension(unit._dims), dtype=dtype)
return Quantity(floatval, dim=get_or_create_dimension(dim._dims), dtype=dtype)


def is_unitless(obj) -> bool:
Expand Down Expand Up @@ -1054,6 +1064,34 @@ def dim(self, *args):
raise NotImplementedError("Cannot set the dimension of a Quantity object directly,"
"Please create a new Quantity object with the value you want.")

@property
def unit(self) -> 'Unit':
return Unit(1., self.dim, register=False)

@unit.setter
def unit(self, *args):
# Do not support setting the unit directly
raise NotImplementedError("Cannot set the unit of a Quantity object directly,"
"Please create a new Quantity object with the unit you want.")

def to_value(self, unit: 'Unit') -> jax.Array | numbers.Number:
"""
Convert the value of the array to a new unit.
Examples::
>>> a = jax.numpy.array([1, 2, 3]) * mV
>>> a.to_value(volt)
array([0.001, 0.002, 0.003])
Args:
unit: The new unit to convert the value of the array to.
Returns:
The value of the array in the new unit.
"""
return self.value / unit.value

@staticmethod
def with_units(value, *args, **keywords):
"""
Expand Down Expand Up @@ -1506,9 +1544,7 @@ def __radd__(self, oc):

def __iadd__(self, oc):
# a += b
r = self._binary_operation(oc, operator.add, fail_for_mismatch=True, operator_str="+=", inplace=True)
self.update_value(r.value)
return self
return self._binary_operation(oc, operator.add, fail_for_mismatch=True, operator_str="+=", inplace=True)

def __sub__(self, oc):
return self._binary_operation(oc, operator.sub, fail_for_mismatch=True, operator_str="-")
Expand All @@ -1518,9 +1554,7 @@ def __rsub__(self, oc):

def __isub__(self, oc):
# a -= b
r = self._binary_operation(oc, operator.sub, fail_for_mismatch=True, operator_str="-=", inplace=True)
self.update_value(r.value)
return self
return self._binary_operation(oc, operator.sub, fail_for_mismatch=True, operator_str="-=", inplace=True)

def __mul__(self, oc):
r = self._binary_operation(oc, operator.mul, operator.mul)
Expand Down Expand Up @@ -1731,7 +1765,7 @@ def __round__(self, ndigits: int = None) -> 'Quantity':
return Quantity(self.value.__round__(ndigits), dim=self.dim)

def __reduce__(self):
return array_with_unit, (self.value, self.dim, None)
return array_with_dim, (self.value, self.dim, None)

# ----------------------- #
# NumPy methods #
Expand Down Expand Up @@ -1963,10 +1997,19 @@ def take(
) -> 'Quantity':
"""Return an array formed from the elements of a at the given indices."""
if isinstance(fill_value, Quantity):
fail_for_dimension_mismatch(self, fill_value, "take")
fill_value = fill_value.value
return Quantity(jnp.take(self.value, indices=indices, axis=axis, mode=mode,
unique_indices=unique_indices, indices_are_sorted=indices_are_sorted,
fill_value=fill_value), dim=self.dim)
elif fill_value is not None:
if not self.is_unitless:
raise TypeError(f"fill_value must be a Quantity when the unit {self.unit}. But got {fill_value}")
return Quantity(
jnp.take(self.value,
indices=indices, axis=axis, mode=mode,
unique_indices=unique_indices,
indices_are_sorted=indices_are_sorted,
fill_value=fill_value),
dim=self.dim
)

def tolist(self):
"""Return the array as an ``a.ndim``-levels deep nested list of Python scalars.
Expand Down Expand Up @@ -2226,9 +2269,11 @@ def view(self, *args, dtype=None) -> 'Quantity':
# NumPy support
# ------------------

def to_numpy(self,
dtype: Optional[jax.typing.DTypeLike] = None,
unit: Optional['Unit'] = None) -> np.ndarray:
def to_numpy(
self,
unit: Optional['Unit'] = None,
dtype: Optional[jax.typing.DTypeLike] = None,
) -> np.ndarray:
"""
Remove the unit and convert to ``numpy.ndarray``.
Expand All @@ -2240,14 +2285,19 @@ def to_numpy(self,
The numpy.ndarray.
"""
if unit is None:
assert self.dim == DIMENSIONLESS, (f"only dimensionless quantities can be converted to "
f"NumPy arrays when 'unit' is not provided. But got {self}")
return np.asarray(self.value, dtype=dtype)
else:
fail_for_dimension_mismatch(self, unit, "to_numpy")
assert isinstance(unit, Unit), f"unit must be a Unit object, but got {type(unit)}"
return np.asarray(self / unit, dtype=dtype)

def to_jax(self,
dtype: Optional[jax.typing.DTypeLike] = None,
unit: Optional['Unit'] = None) -> jax.Array:
def to_jax(
self,
unit: Optional['Unit'] = None,
dtype: Optional[jax.typing.DTypeLike] = None,
) -> jax.Array:
"""
Remove the unit and convert to ``jax.Array``.
Expand All @@ -2259,20 +2309,50 @@ def to_jax(self,
The jax.Array.
"""
if unit is None:
assert self.dim == DIMENSIONLESS, (f"only dimensionless quantities can be converted to "
f"JAX arrays when 'unit' is not provided. But got {self}")
return jnp.asarray(self.value, dtype=dtype)
else:
fail_for_dimension_mismatch(self, unit, "to_jax")
assert isinstance(unit, Unit), f"unit must be a Unit object, but got {type(unit)}"
return jnp.asarray(self / unit, dtype=dtype)

def __array__(self, dtype: Optional[jax.typing.DTypeLike] = None) -> np.ndarray:
"""Support ``numpy.array()`` and ``numpy.asarray()`` functions."""
return np.asarray(self.value, dtype=dtype)
if self.dim == DIMENSIONLESS:
return np.asarray(self.value, dtype=dtype)
else:
raise TypeError(
f"only dimensionless quantities can be "
f"converted to NumPy arrays. But got {self}"
)

def __float__(self):
return self.value.__float__()
if self.dim == DIMENSIONLESS and self.ndim == 0:
return float(self.value)
else:
raise TypeError(
"only dimensionless scalar quantities can be "
f"converted to Python scalars. But got {self}"
)

def __int__(self):
if self.dim == DIMENSIONLESS and self.ndim == 0:
return int(self.value)
else:
raise TypeError(
"only dimensionless scalar quantities can be "
f"converted to Python scalars. But got {self}"
)

def __index__(self):
return operator.index(self.value)
if self.dim == DIMENSIONLESS:
return operator.index(self.value)
else:
raise TypeError(
"only dimensionless quantities can be "
f"converted to a Python index. But got {self}"
)

# ----------------------
# PyTorch compatibility
Expand Down Expand Up @@ -2518,6 +2598,7 @@ def __init__(
dispname: str = None,
iscompound: bool = None,
dtype: jax.typing.DTypeLike = None,
register: bool = True,
):
if dim is None:
dim = DIMENSIONLESS
Expand All @@ -2543,7 +2624,7 @@ def __init__(

super().__init__(value, dtype=dtype, dim=dim)

if _auto_register_unit:
if _auto_register_unit and register:
register_new_unit(self)

@staticmethod
Expand Down Expand Up @@ -2783,10 +2864,11 @@ def add(self, u: Unit):
if isinstance(u.value, (jax.ShapeDtypeStruct, jax.core.ShapedArray, DynamicJaxprTracer)):
self.units_for_dimensions[u.dim][1.] = u
else:
self.units_for_dimensions[u.dim][float(u)] = u
self.units_for_dimensions[u.dim][float(u.value)] = u

def __getitem__(self, x):
"""Returns the best unit for array x
"""
Returns the best unit for array x
The algorithm is to consider the value:
Expand Down Expand Up @@ -3005,9 +3087,7 @@ def new_f(*args, **kwds):
v = Quantity(v)
except TypeError:
if have_same_unit(au[n], 1):
raise TypeError(
f"Argument {n} is not a unitless value/array."
)
raise TypeError(f"Argument {n} is not a unitless value/array.")
else:
raise TypeError(
f"Argument '{n}' is not a array, "
Expand Down Expand Up @@ -3053,9 +3133,9 @@ def new_f(*args, **kwds):
f"the argument '{k}' to have the same "
f"units as argument '{au[k]}', but "
f"argument '{k}' has "
f"unit {get_unit_for_display(d1)}, "
f"unit {get_dim_for_display(d1)}, "
f"while argument '{au[k]}' "
f"has unit {get_unit_for_display(d2)}."
f"has unit {get_dim_for_display(d2)}."
)
raise DimensionMismatchError(error_message)
elif not have_same_unit(newkeyset[k], au[k]):
Expand Down Expand Up @@ -3087,7 +3167,7 @@ def new_f(*args, **kwds):
)
raise TypeError(error_message)
elif not have_same_unit(result, expected_result):
unit = get_unit_for_display(expected_result)
unit = get_dim_for_display(expected_result)
error_message = (
"The return value of function "
f"'{f.__name__}' was expected to have "
Expand Down
28 changes: 28 additions & 0 deletions brainunit/_base_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import unittest

import brainunit as bu


class TestQuantity(unittest.TestCase):
def test_dim(self):
a = [1, 2.] * bu.ms

with self.assertRaises(NotImplementedError):
a.dim = bu.mV.dim


Loading

0 comments on commit 820bcc5

Please sign in to comment.