From 1f8b1706ba1f76715c68b51181fd4defd725734f Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Tue, 18 Jun 2024 16:59:06 +0800 Subject: [PATCH] Fix bugs and incorrect logics in ``Qauntity`` and its related math functions (#21) * 1. add `Quantity.unit` attribute 2. ``__array__``, ``__float__``, ``__int__``, and ``__index__`` function dimension and value checking * fix the wrong logic of several numpy apis * update unit tests * fix `Quantity.take()` function * 1. rename `get_unit_for_display()` to `get_dim_for_display()` 2. rename `array_with_unit()` with `array_with_dim()` --- brainunit/_base.py | 158 +++++-- brainunit/_base_test.py | 28 ++ brainunit/_unit_test.py | 438 +++++------------- .../math/_compat_numpy_funcs_keep_unit.py | 17 +- brainunit/math/_compat_numpy_funcs_logic.py | 4 +- 5 files changed, 263 insertions(+), 382 deletions(-) create mode 100644 brainunit/_base_test.py diff --git a/brainunit/_base.py b/brainunit/_base.py index e7fb14e..2f7c94d 100644 --- a/brainunit/_base.py +++ b/brainunit/_base.py @@ -104,6 +104,7 @@ def _short_str(arr): Return a short string representation of an array, suitable for use in error messages. """ + arr = arr.value if isinstance(arr, Quantity) else arr arr = np.asanyarray(arr) old_printoptions = jnp.get_printoptions() jnp.set_printoptions(edgeitems=2, threshold=5) @@ -112,7 +113,7 @@ def _short_str(arr): return arr_string -def get_unit_for_display(d): +def get_dim_for_display(d): """ Return a string representation of an appropriate unscaled unit or ``'1'`` for a dimensionless array. @@ -181,6 +182,13 @@ def get_unit_for_display(d): "cd": 6, } +# Length (meter) +# Mass (kilogram) +# Time (second) +# Current (ampere) +# Temperature (Kelvin) +# Amount of substance (mole) +# Luminous intensity (candela) _ilabel = ["m", "kg", "s", "A", "K", "mol", "cd"] # The same labels with the names used for constructing them in Python code @@ -453,6 +461,8 @@ def get_or_create_dimension(*args, **kwds): '''The dimensionless unit, used for quantities without a unit.''' DIMENSIONLESS = Dimension((0, 0, 0, 0, 0, 0, 0)) + +'''The dictionary of all existing Dimension objects.''' _dimensions = {(0, 0, 0, 0, 0, 0, 0): DIMENSIONLESS} @@ -492,16 +502,16 @@ def __str__(self): if len(self.dims) == 0: pass elif len(self.dims) == 1: - s += f" (unit is {get_unit_for_display(self.dims[0])}" + s += f" (unit is {get_dim_for_display(self.dims[0])}" elif len(self.dims) == 2: d1, d2 = self.dims s += ( - f" (units are {get_unit_for_display(d1)} and {get_unit_for_display(d2)}" + f" (units are {get_dim_for_display(d1)} and {get_dim_for_display(d2)}" ) else: s += ( " (units are" - f" {' '.join([f'({get_unit_for_display(d)})' for d in self.dims])}" + f" {' '.join([f'({get_dim_for_display(d)})' for d in self.dims])}" ) if len(self.dims): s += ")." @@ -510,7 +520,7 @@ def __str__(self): def get_dim(obj) -> Dimension: """ - Return the unit of any object that has them. + Return the dimension of any object that has them. Slightly more general than `Array.dimensions` because it will return `DIMENSIONLESS` if the object is of number type but not a `Array` @@ -741,9 +751,9 @@ def in_best_unit(x, precision=None): return x.repr_in_unit(u, precision=precision) -def array_with_unit( +def array_with_dim( floatval, - unit: Dimension, + dim: Dimension, dtype: jax.typing.DTypeLike = None ) -> 'Quantity': """ @@ -757,8 +767,8 @@ def array_with_unit( ---------- floatval : `float` The floating point value of the array. - unit: Dimension - The unit dimensions of the array. + dim: Dimension + The dim dimensions of the array. dtype: `dtype`, optional The data type of the array. @@ -770,10 +780,10 @@ def array_with_unit( Examples -------- >>> from brainunit import * - >>> array_with_unit(0.001, volt.dim) + >>> array_with_dim(0.001, volt.dim) 1. * mvolt """ - return Quantity(floatval, dim=get_or_create_dimension(unit._dims), dtype=dtype) + return Quantity(floatval, dim=get_or_create_dimension(dim._dims), dtype=dtype) def is_unitless(obj) -> bool: @@ -1054,6 +1064,34 @@ def dim(self, *args): raise NotImplementedError("Cannot set the dimension of a Quantity object directly," "Please create a new Quantity object with the value you want.") + @property + def unit(self) -> 'Unit': + return Unit(1., self.dim, register=False) + + @unit.setter + def unit(self, *args): + # Do not support setting the unit directly + raise NotImplementedError("Cannot set the unit of a Quantity object directly," + "Please create a new Quantity object with the unit you want.") + + def to_value(self, unit: 'Unit') -> jax.Array | numbers.Number: + """ + Convert the value of the array to a new unit. + + Examples:: + + >>> a = jax.numpy.array([1, 2, 3]) * mV + >>> a.to_value(volt) + array([0.001, 0.002, 0.003]) + + Args: + unit: The new unit to convert the value of the array to. + + Returns: + The value of the array in the new unit. + """ + return self.value / unit.value + @staticmethod def with_units(value, *args, **keywords): """ @@ -1506,9 +1544,7 @@ def __radd__(self, oc): def __iadd__(self, oc): # a += b - r = self._binary_operation(oc, operator.add, fail_for_mismatch=True, operator_str="+=", inplace=True) - self.update_value(r.value) - return self + return self._binary_operation(oc, operator.add, fail_for_mismatch=True, operator_str="+=", inplace=True) def __sub__(self, oc): return self._binary_operation(oc, operator.sub, fail_for_mismatch=True, operator_str="-") @@ -1518,9 +1554,7 @@ def __rsub__(self, oc): def __isub__(self, oc): # a -= b - r = self._binary_operation(oc, operator.sub, fail_for_mismatch=True, operator_str="-=", inplace=True) - self.update_value(r.value) - return self + return self._binary_operation(oc, operator.sub, fail_for_mismatch=True, operator_str="-=", inplace=True) def __mul__(self, oc): r = self._binary_operation(oc, operator.mul, operator.mul) @@ -1731,7 +1765,7 @@ def __round__(self, ndigits: int = None) -> 'Quantity': return Quantity(self.value.__round__(ndigits), dim=self.dim) def __reduce__(self): - return array_with_unit, (self.value, self.dim, None) + return array_with_dim, (self.value, self.dim, None) # ----------------------- # # NumPy methods # @@ -1963,10 +1997,19 @@ def take( ) -> 'Quantity': """Return an array formed from the elements of a at the given indices.""" if isinstance(fill_value, Quantity): + fail_for_dimension_mismatch(self, fill_value, "take") fill_value = fill_value.value - return Quantity(jnp.take(self.value, indices=indices, axis=axis, mode=mode, - unique_indices=unique_indices, indices_are_sorted=indices_are_sorted, - fill_value=fill_value), dim=self.dim) + elif fill_value is not None: + if not self.is_unitless: + raise TypeError(f"fill_value must be a Quantity when the unit {self.unit}. But got {fill_value}") + return Quantity( + jnp.take(self.value, + indices=indices, axis=axis, mode=mode, + unique_indices=unique_indices, + indices_are_sorted=indices_are_sorted, + fill_value=fill_value), + dim=self.dim + ) def tolist(self): """Return the array as an ``a.ndim``-levels deep nested list of Python scalars. @@ -2226,9 +2269,11 @@ def view(self, *args, dtype=None) -> 'Quantity': # NumPy support # ------------------ - def to_numpy(self, - dtype: Optional[jax.typing.DTypeLike] = None, - unit: Optional['Unit'] = None) -> np.ndarray: + def to_numpy( + self, + unit: Optional['Unit'] = None, + dtype: Optional[jax.typing.DTypeLike] = None, + ) -> np.ndarray: """ Remove the unit and convert to ``numpy.ndarray``. @@ -2240,14 +2285,19 @@ def to_numpy(self, The numpy.ndarray. """ if unit is None: + assert self.dim == DIMENSIONLESS, (f"only dimensionless quantities can be converted to " + f"NumPy arrays when 'unit' is not provided. But got {self}") return np.asarray(self.value, dtype=dtype) else: + fail_for_dimension_mismatch(self, unit, "to_numpy") assert isinstance(unit, Unit), f"unit must be a Unit object, but got {type(unit)}" return np.asarray(self / unit, dtype=dtype) - def to_jax(self, - dtype: Optional[jax.typing.DTypeLike] = None, - unit: Optional['Unit'] = None) -> jax.Array: + def to_jax( + self, + unit: Optional['Unit'] = None, + dtype: Optional[jax.typing.DTypeLike] = None, + ) -> jax.Array: """ Remove the unit and convert to ``jax.Array``. @@ -2259,20 +2309,50 @@ def to_jax(self, The jax.Array. """ if unit is None: + assert self.dim == DIMENSIONLESS, (f"only dimensionless quantities can be converted to " + f"JAX arrays when 'unit' is not provided. But got {self}") return jnp.asarray(self.value, dtype=dtype) else: + fail_for_dimension_mismatch(self, unit, "to_jax") assert isinstance(unit, Unit), f"unit must be a Unit object, but got {type(unit)}" return jnp.asarray(self / unit, dtype=dtype) def __array__(self, dtype: Optional[jax.typing.DTypeLike] = None) -> np.ndarray: """Support ``numpy.array()`` and ``numpy.asarray()`` functions.""" - return np.asarray(self.value, dtype=dtype) + if self.dim == DIMENSIONLESS: + return np.asarray(self.value, dtype=dtype) + else: + raise TypeError( + f"only dimensionless quantities can be " + f"converted to NumPy arrays. But got {self}" + ) def __float__(self): - return self.value.__float__() + if self.dim == DIMENSIONLESS and self.ndim == 0: + return float(self.value) + else: + raise TypeError( + "only dimensionless scalar quantities can be " + f"converted to Python scalars. But got {self}" + ) + + def __int__(self): + if self.dim == DIMENSIONLESS and self.ndim == 0: + return int(self.value) + else: + raise TypeError( + "only dimensionless scalar quantities can be " + f"converted to Python scalars. But got {self}" + ) def __index__(self): - return operator.index(self.value) + if self.dim == DIMENSIONLESS: + return operator.index(self.value) + else: + raise TypeError( + "only dimensionless quantities can be " + f"converted to a Python index. But got {self}" + ) # ---------------------- # PyTorch compatibility @@ -2518,6 +2598,7 @@ def __init__( dispname: str = None, iscompound: bool = None, dtype: jax.typing.DTypeLike = None, + register: bool = True, ): if dim is None: dim = DIMENSIONLESS @@ -2543,7 +2624,7 @@ def __init__( super().__init__(value, dtype=dtype, dim=dim) - if _auto_register_unit: + if _auto_register_unit and register: register_new_unit(self) @staticmethod @@ -2783,10 +2864,11 @@ def add(self, u: Unit): if isinstance(u.value, (jax.ShapeDtypeStruct, jax.core.ShapedArray, DynamicJaxprTracer)): self.units_for_dimensions[u.dim][1.] = u else: - self.units_for_dimensions[u.dim][float(u)] = u + self.units_for_dimensions[u.dim][float(u.value)] = u def __getitem__(self, x): - """Returns the best unit for array x + """ + Returns the best unit for array x The algorithm is to consider the value: @@ -3005,9 +3087,7 @@ def new_f(*args, **kwds): v = Quantity(v) except TypeError: if have_same_unit(au[n], 1): - raise TypeError( - f"Argument {n} is not a unitless value/array." - ) + raise TypeError(f"Argument {n} is not a unitless value/array.") else: raise TypeError( f"Argument '{n}' is not a array, " @@ -3053,9 +3133,9 @@ def new_f(*args, **kwds): f"the argument '{k}' to have the same " f"units as argument '{au[k]}', but " f"argument '{k}' has " - f"unit {get_unit_for_display(d1)}, " + f"unit {get_dim_for_display(d1)}, " f"while argument '{au[k]}' " - f"has unit {get_unit_for_display(d2)}." + f"has unit {get_dim_for_display(d2)}." ) raise DimensionMismatchError(error_message) elif not have_same_unit(newkeyset[k], au[k]): @@ -3087,7 +3167,7 @@ def new_f(*args, **kwds): ) raise TypeError(error_message) elif not have_same_unit(result, expected_result): - unit = get_unit_for_display(expected_result) + unit = get_dim_for_display(expected_result) error_message = ( "The return value of function " f"'{f.__name__}' was expected to have " diff --git a/brainunit/_base_test.py b/brainunit/_base_test.py new file mode 100644 index 0000000..a39fd38 --- /dev/null +++ b/brainunit/_base_test.py @@ -0,0 +1,28 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import unittest + +import brainunit as bu + + +class TestQuantity(unittest.TestCase): + def test_dim(self): + a = [1, 2.] * bu.ms + + with self.assertRaises(NotImplementedError): + a.dim = bu.mV.dim + + diff --git a/brainunit/_unit_test.py b/brainunit/_unit_test.py index f7b171a..992cd91 100644 --- a/brainunit/_unit_test.py +++ b/brainunit/_unit_test.py @@ -17,16 +17,16 @@ import warnings import brainstate as bst + +bst.environ.set(precision=64) + import jax import jax.numpy as jnp import numpy as np import pytest from numpy.testing import assert_equal - import brainunit as bu -bst.environ.set(precision=64) - from brainunit._unit_common import * from brainunit._base import ( DIMENSIONLESS, @@ -72,7 +72,7 @@ def assert_allclose(actual, desired, rtol=4.5e8, atol=0, **kwds): def assert_quantity(q, values, unit): - values = jnp.asarray(values) + values = values.value if isinstance(values, Quantity) else np.asarray(values) if isinstance(q, Quantity): assert have_same_unit(q.dim, unit), f"Dimension mismatch: ({get_dim(q)}) ({get_dim(unit)})" if not jnp.allclose(q.value, values): @@ -481,19 +481,21 @@ def test_format_quantity(): with bst.environ.context(precision=64): q = 0.5 * ms assert f"{q}" == f"{q!s}" == str(q) - assert f"{q:g}" == f"{float(q)}" + print(f"{q:g}") + assert f"{q:g}" == f"{float(q / bu.second)}" def test_slicing(): # Slicing and indexing, setting items - Array = np.reshape(np.arange(6), (2, 3)) * mV - assert_allclose(Array[:].value, Array.value) - assert_allclose(Array[0].value, (np.asarray(Array)[0] * volt).value) - assert_allclose(Array[0:1].value, (np.asarray(Array)[0:1] * volt).value) - assert_allclose(Array[0, 1].value, (np.asarray(Array)[0, 1] * volt).value) - assert_allclose(Array[0:1, 1:].value, (np.asarray(Array)[0:1, 1:] * volt).value) + a = np.reshape(np.arange(6), (2, 3)) + q = a * mV + assert_allclose(q[:].value, q.value) + assert_allclose(q[0].value, (a[0] * volt).value) + assert_allclose(q[0:1].value, (a[0:1] * volt).value) + assert_allclose(q[0, 1].value, (a[0, 1] * volt).value) + assert_allclose(q[0:1, 1:].value, (a[0:1, 1:] * volt).value) bool_matrix = np.array([[True, False, False], [False, False, True]]) - assert_allclose(Array[bool_matrix].value, (np.asarray(Array)[bool_matrix] * volt).value) + assert_allclose(q[bool_matrix].value, (a[bool_matrix] * volt).value) def test_setting(): @@ -526,31 +528,31 @@ def test_multiplication_division(): for q in quantities: # Scalars and array scalars - assert_quantity(q / 3, np.asarray(q) / 3, volt) - assert_quantity(3 / q, 3 / np.asarray(q), 1 / volt) - assert_quantity(q * 3, np.asarray(q) * 3, volt) - assert_quantity(3 * q, 3 * np.asarray(q), volt) - assert_quantity(q / np.float64(3), np.asarray(q) / 3, volt) - assert_quantity(np.float64(3) / q, 3 / np.asarray(q), 1 / volt) - assert_quantity(q * np.float64(3), np.asarray(q) * 3, volt) - assert_quantity(np.float64(3) * q, 3 * np.asarray(q), volt) - assert_quantity(q / jnp.array(3), np.asarray(q) / 3, volt) - assert_quantity(np.array(3) / q, 3 / np.asarray(q), 1 / volt) - assert_quantity(q * jnp.array(3), np.asarray(q) * 3, volt) - assert_quantity(np.array(3) * q, 3 * np.asarray(q), volt) + assert_quantity(q / 3, q.value / 3, volt) + assert_quantity(3 / q, 3 / q.value, 1 / volt) + assert_quantity(q * 3, q.value * 3, volt) + assert_quantity(3 * q, 3 * q.value, volt) + assert_quantity(q / np.float64(3), q.value / 3, volt) + assert_quantity(np.float64(3) / q, 3 / q.value, 1 / volt) + assert_quantity(q * np.float64(3), q.value * 3, volt) + assert_quantity(np.float64(3) * q, 3 * q.value, volt) + assert_quantity(q / jnp.array(3), q.value / 3, volt) + assert_quantity(np.array(3) / q, 3 / q.value, 1 / volt) + assert_quantity(q * jnp.array(3), q.value * 3, volt) + assert_quantity(np.array(3) * q, 3 * q.value, volt) # (unitless) arrays - assert_quantity(q / np.array([3]), np.asarray(q) / 3, volt) - assert_quantity(np.array([3]) / q, 3 / np.asarray(q), 1 / volt) - assert_quantity(q * np.array([3]), np.asarray(q) * 3, volt) - assert_quantity(np.array([3]) * q, 3 * np.asarray(q), volt) + assert_quantity(q / np.array([3]), q.value / 3, volt) + assert_quantity(np.array([3]) / q, 3 / q.value, 1 / volt) + assert_quantity(q * np.array([3]), q.value * 3, volt) + assert_quantity(np.array([3]) * q, 3 * q.value, volt) # arrays with units - assert_quantity(q / q, np.asarray(q) / np.asarray(q), 1) - assert_quantity(q * q, np.asarray(q) ** 2, volt ** 2) - assert_quantity(q / q2, np.asarray(q) / np.asarray(q2), volt / second) - assert_quantity(q2 / q, np.asarray(q2) / np.asarray(q), second / volt) - assert_quantity(q * q2, np.asarray(q) * np.asarray(q2), volt * second) + assert_quantity(q / q, q.value / q.value, 1) + assert_quantity(q * q, q.value ** 2, volt ** 2) + assert_quantity(q / q2, q.value / q2.value, volt / second) + assert_quantity(q2 / q, q2.value / q.value, second / volt) + assert_quantity(q * q2, q.value * q2.value, volt * second) # # using unsupported objects should fail # with pytest.raises(TypeError): @@ -569,12 +571,12 @@ def test_addition_subtraction(): for q in quantities: # arrays with units - assert_quantity(q + q, np.asarray(q) + np.asarray(q), volt) + assert_quantity(q + q, q.value + q.value, volt) assert_quantity(q - q, 0, volt) - assert_quantity(q + q2, np.asarray(q) + np.asarray(q2), volt) - assert_quantity(q2 + q, np.asarray(q2) + np.asarray(q), volt) - assert_quantity(q - q2, np.asarray(q) - np.asarray(q2), volt) - assert_quantity(q2 - q, np.asarray(q2) - np.asarray(q), volt) + assert_quantity(q + q2, q.value + q2.value, volt) + assert_quantity(q2 + q, q2.value + q.value, volt) + assert_quantity(q - q2, q.value - q2.value, volt) + assert_quantity(q2 - q, q2.value - q.value, volt) # mismatching units with pytest.raises(DimensionMismatchError): @@ -623,15 +625,15 @@ def test_addition_subtraction(): np.array([5], dtype=np.float64) - q # Check that operations with 0 work - assert_quantity(q + 0, np.asarray(q), volt) - assert_quantity(0 + q, np.asarray(q), volt) - assert_quantity(q - 0, np.asarray(q), volt) + assert_quantity(q + 0, q.value, volt) + assert_quantity(0 + q, q.value, volt) + assert_quantity(q - 0, q.value, volt) # Doesn't support 0 - Quantity - # assert_quantity(0 - q, -np.asarray(q), volt) - assert_quantity(q + np.float64(0), np.asarray(q), volt) - assert_quantity(np.float64(0) + q, np.asarray(q), volt) - assert_quantity(q - np.float64(0), np.asarray(q), volt) - # assert_quantity(np.float64(0) - q, -np.asarray(q), volt) + # assert_quantity(0 - q, -q.value, volt) + assert_quantity(q + np.float64(0), q.value, volt) + assert_quantity(np.float64(0) + q, q.value, volt) + assert_quantity(q - np.float64(0), q.value, volt) + # assert_quantity(np.float64(0) - q, -q.value, volt) # # using unsupported objects should fail # with pytest.raises(TypeError): @@ -669,16 +671,16 @@ def assert_operations_work(a, b): # Test equivalent numpy functions numpy_funcs = [ - np.add, - np.subtract, - np.less, - np.less_equal, - np.greater, - np.greater_equal, - np.equal, - np.not_equal, - np.maximum, - np.minimum, + bu.math.add, + bu.math.subtract, + bu.math.less, + bu.math.less_equal, + bu.math.greater, + bu.math.greater_equal, + bu.math.equal, + bu.math.not_equal, + bu.math.maximum, + bu.math.minimum, ] for numpy_func in numpy_funcs: numpy_func(a, b) @@ -784,17 +786,15 @@ def test_power(): """ Test raising quantities to a power. """ - values = [2 * kilogram, np.array([2]) * kilogram, np.array([1, 2]) * kilogram] - for value in values: - assert_quantity(value ** 3, np.asarray(value) ** 3, kilogram ** 3) + arrs = [2 * kilogram, np.array([2]) * kilogram, np.array([1, 2]) * kilogram] + for a in arrs: + assert_quantity(a ** 3, a.value ** 3, kilogram ** 3) # Test raising to a dimensionless Array - assert_quantity( - value ** (3 * volt / volt), np.asarray(value) ** 3, kilogram ** 3 - ) + assert_quantity(a ** (3 * volt / volt), a.value ** 3, kilogram ** 3) with pytest.raises(DimensionMismatchError): - value ** (2 * volt) + a ** (2 * volt) with pytest.raises(TypeError): - value ** np.array([2, 3]) + a ** np.array([2, 3]) def test_inplace_operations(): @@ -889,16 +889,15 @@ def test_unit_discarding_functions(): """ Test functions that discard units. """ - from brainunit.math import ones_like, zeros_like values = [3 * mV, np.array([1, 2]) * mV, np.arange(12).reshape(3, 4) * mV] - for value in values: - assert_equal(np.sign(value.value), np.sign(np.asarray(value.value))) - assert_equal(zeros_like(value), np.zeros_like(np.asarray(value.value))) - assert_equal(ones_like(value), np.ones_like(np.asarray(value.value))) + for a in values: + assert_equal(np.sign(a.value), np.sign(np.asarray(a.value))) + assert_equal(bu.math.zeros_like(a).value, np.zeros_like(np.asarray(a.value))) + assert_equal(bu.math.ones_like(a).value, np.ones_like(np.asarray(a.value))) # Calling non-zero on a 0d array is deprecated, don't test it: - if value.ndim > 0: - assert_equal(np.nonzero(value.value), np.nonzero(np.asarray(value.value))) + if a.ndim > 0: + assert_equal(np.nonzero(a.value), np.nonzero(np.asarray(a.value))) def test_unitsafe_functions(): @@ -985,45 +984,58 @@ def test_special_case_numpy_functions(): assert_allclose(ravel(quadratic_matrix).value, quadratic_matrix.ravel().value) # Check that function gives the same result as on unitless arrays assert_allclose( - np.asarray(ravel(quadratic_matrix).value), ravel(np.asarray(quadratic_matrix)) + np.asarray(ravel(quadratic_matrix).value), + ravel(np.asarray(quadratic_matrix.value)) ) # Check that the function gives the same results as the original numpy # function assert_allclose( - np.ravel(np.asarray(quadratic_matrix.value)), ravel(np.asarray(quadratic_matrix.value)) + np.ravel(np.asarray(quadratic_matrix.value)), + ravel(np.asarray(quadratic_matrix.value)) ) # Do the same checks for diagonal, trace and dot assert_allclose(diagonal(quadratic_matrix).value, quadratic_matrix.diagonal().value) assert_allclose( - np.asarray(diagonal(quadratic_matrix).value), diagonal(np.asarray(quadratic_matrix.value)) + np.asarray(diagonal(quadratic_matrix).value), + diagonal(np.asarray(quadratic_matrix.value)) ) assert_allclose( np.diagonal(np.asarray(quadratic_matrix.value)), diagonal(np.asarray(quadratic_matrix.value)), ) - assert_allclose(trace(quadratic_matrix).value, quadratic_matrix.trace().value) assert_allclose( - np.asarray(trace(quadratic_matrix).value), trace(np.asarray(quadratic_matrix.value)) + trace(quadratic_matrix).value, + quadratic_matrix.trace().value ) assert_allclose( - np.trace(np.asarray(quadratic_matrix.value)), trace(np.asarray(quadratic_matrix.value)) + np.asarray(trace(quadratic_matrix).value), + trace(np.asarray(quadratic_matrix.value)) + ) + assert_allclose( + np.trace(np.asarray(quadratic_matrix.value)), + trace(np.asarray(quadratic_matrix.value)) ) assert_allclose( - dot(quadratic_matrix, quadratic_matrix).value, quadratic_matrix.dot(quadratic_matrix).value + dot(quadratic_matrix, quadratic_matrix).value, + quadratic_matrix.dot(quadratic_matrix).value ) assert_allclose( np.asarray(dot(quadratic_matrix, quadratic_matrix).value), - dot(np.asarray(quadratic_matrix.value), np.asarray(quadratic_matrix.value)), + dot(np.asarray(quadratic_matrix.value), + np.asarray(quadratic_matrix.value)), ) assert_allclose( - np.dot(np.asarray(quadratic_matrix.value), np.asarray(quadratic_matrix.value)), - dot(np.asarray(quadratic_matrix.value), np.asarray(quadratic_matrix.value)), + np.dot(np.asarray(quadratic_matrix.value), + np.asarray(quadratic_matrix.value)), + dot(np.asarray(quadratic_matrix.value), + np.asarray(quadratic_matrix.value)), ) assert_allclose( - np.asarray(quadratic_matrix.prod().value), np.asarray(quadratic_matrix.value).prod() + np.asarray(quadratic_matrix.prod().value), + np.asarray(quadratic_matrix.value).prod() ) assert_allclose( np.asarray(quadratic_matrix.prod(axis=0).value), @@ -1035,10 +1047,12 @@ def test_special_case_numpy_functions(): assert have_same_unit(quadratic_matrix, trace(quadratic_matrix)) assert have_same_unit(quadratic_matrix, diagonal(quadratic_matrix)) assert have_same_unit( - quadratic_matrix[0] ** 2, dot(quadratic_matrix, quadratic_matrix) + quadratic_matrix[0] ** 2, + dot(quadratic_matrix, quadratic_matrix) ) assert have_same_unit( - quadratic_matrix.prod(axis=0), quadratic_matrix[0] ** quadratic_matrix.shape[0] + quadratic_matrix.prod(axis=0), + quadratic_matrix[0] ** quadratic_matrix.shape[0] ) # check the where function @@ -1051,15 +1065,16 @@ def test_special_case_numpy_functions(): # dimensionless Array assert_allclose( - np.where(cond, ar1, ar2), np.asarray(where(cond, ar1 * mV / mV, ar2 * mV / mV)) + np.where(cond, ar1, ar2), + np.asarray(where(cond, ar1 * mV / mV, ar2 * mV / mV)) ) # Array with dimensions ar1 = ar1 * mV ar2 = ar2 * mV assert_allclose( - np.where(cond, np.asarray(ar1), np.asarray(ar2)), - np.asarray(where(cond, ar1, ar2)), + np.where(cond, ar1.value, ar2.value), + np.asarray(where(cond, ar1, ar2).value), ) # Check some error cases @@ -1083,8 +1098,7 @@ def test_special_case_numpy_functions(): # Check cumprod a = np.arange(1, 10) * mV / mV assert_allclose(a.cumprod(), np.asarray(a).cumprod()) - # with pytest.raises(TypeError): - # (np.arange(1, 5) * mV).cumprod() + (np.arange(1, 5) * mV).cumprod() # Functions that should not change units @@ -1162,215 +1176,6 @@ def test_numpy_functions_indices(): ) -# Do not support numpy functions -# def test_numpy_functions_dimensionless(): -# """ -# Test that numpy functions that should work on dimensionless quantities only -# work dimensionless arrays and return the correct result. -# """ -# unitless_values = [3, np.array([-4, 3, -1, 2]), np.ones((3, 3))] -# unit_values = [3 * mV, np.array([-4, 3, -1, 2]) * mV, np.ones((3, 3)) * mV] -# with warnings.catch_warnings(): -# # ignore division by 0 warnings -# warnings.simplefilter("ignore", RuntimeWarning) -# for value in unitless_values: -# for ufunc in ufuncs_dimensionless: -# result_unitless = eval(f"np.{ufunc}(value)") -# result_array = eval(f"np.{ufunc}(np.array(value))") -# assert isinstance( -# result_unitless, (np.ndarray, np.number) -# ) and not isinstance(result_unitless, Quantity) -# assert_equal(result_unitless, result_array) -# for ufunc in ufuncs_dimensionless_twoargs: -# result_unitless = eval(f"np.{ufunc}(value, value)") -# result_array = eval(f"np.{ufunc}(np.array(value), np.array(value))") -# assert isinstance( -# result_unitless, (np.ndarray, np.number) -# ) and not isinstance(result_unitless, Quantity) -# assert_equal(result_unitless, result_array) -# -# for value, unitless_value in zip(unit_values, unitless_values): -# for ufunc in ufuncs_dimensionless: -# with pytest.raises(DimensionMismatchError): -# eval(f"np.{ufunc}(value)", globals(), {"value": value}) -# for ufunc in ufuncs_dimensionless_twoargs: -# with pytest.raises(DimensionMismatchError): -# eval( -# f"np.{ufunc}(value1, value2)", -# globals(), -# {"value1": value, "value2": unitless_value}, -# ) -# with pytest.raises(DimensionMismatchError): -# eval( -# f"np.{ufunc}(value2, value1)", -# globals(), -# {"value1": value, "value2": unitless_value}, -# ) -# with pytest.raises(DimensionMismatchError): -# eval(f"np.{ufunc}(value, value)", globals(), {"value": value}) - - -# Do not support numpy functions -# def test_numpy_functions_change_dimensions(): -# """ -# Test some numpy functions that change the dimensions of the Array. -# """ -# unit_values = [np.array([1, 2]) * mV, np.ones((3, 3)) * 2 * mV] -# for value in unit_values: -# assert_quantity(np.var(value), np.var(np.array(value)), volt ** 2) -# assert_quantity(np.square(value), np.square(np.array(value)), volt ** 2) -# assert_quantity(np.sqrt(value), np.sqrt(np.array(value)), volt ** 0.5) -# assert_quantity( -# np.reciprocal(value), np.reciprocal(np.array(value)), 1.0 / volt -# ) - - -# Do not support numpy functions -# def test_numpy_functions_matmul(): -# """ -# Check support for matmul and the ``@`` operator. -# """ -# no_units_eye = np.eye(3) -# with_units_eye = no_units_eye * Mohm -# matrix_no_units = np.arange(9).reshape((3, 3)) -# matrix_units = matrix_no_units * nA -# -# # First operand with units -# assert_allclose((no_units_eye @ matrix_units).value, matrix_units.value) -# assert have_same_unit(no_units_eye @ matrix_units, matrix_units) -# assert_allclose(np.matmul(no_units_eye, matrix_units.value), matrix_units.value) -# assert have_same_unit(np.matmul(no_units_eye, matrix_units.value), matrix_units.value) -# -# # Second operand with units -# assert_allclose((with_units_eye @ matrix_no_units).value, (matrix_no_units * Mohm).value) -# assert have_same_unit( -# with_units_eye @ matrix_no_units, matrix_no_units * Mohm -# ) -# assert_allclose(np.matmul(with_units_eye.value, matrix_no_units), (matrix_no_units * Mohm).value) -# assert have_same_unit( -# np.matmul(with_units_eye, matrix_no_units), matrix_no_units * Mohm -# ) -# -# # Both operands with units -# assert_allclose( -# (with_units_eye @ matrix_units).value, (no_units_eye @ matrix_no_units * nA * Mohm).value -# ) -# assert have_same_unit(with_units_eye @ matrix_units, nA * Mohm) -# assert_allclose( -# np.matmul(with_units_eye.value, matrix_units.value), -# (np.matmul(no_units_eye, matrix_no_units) * nA * Mohm).value, -# ) -# assert have_same_unit(np.matmul(with_units_eye, matrix_units), nA * Mohm) - - -# def test_numpy_functions_typeerror(): -# """ -# Assures that certain numpy functions raise a TypeError when called on -# quantities. -# """ -# unitless_values = [ -# 3 * mV / mV, -# np.array([1, 2]) * mV / mV, -# np.ones((3, 3)) * mV / mV, -# ] -# unit_values = [3 * mV, np.array([1, 2]) * mV, np.ones((3, 3)) * mV] -# for value in unitless_values + unit_values: -# for ufunc in ufuncs_integers: -# if ufunc == "invert": -# # only takes one argument -# with pytest.raises(TypeError): -# eval(f"np.{ufunc}(value)", globals(), {"value": value}) -# else: -# with pytest.raises(TypeError): -# eval(f"np.{ufunc}(value, value)", globals(), {"value": value}) - - -# Doesn't support logical functions -# -# def test_numpy_functions_logical(): -# """ -# Assure that logical numpy functions work on all quantities and return -# unitless boolean arrays. -# """ -# unit_values1 = [3 * mV, np.array([1, 2]) * mV, np.ones((3, 3)) * mV] -# unit_values2 = [3 * second, np.array([1, 2]) * second, np.ones((3, 3)) * second] -# for ufunc in ufuncs_logical: -# for value1, value2 in zip(unit_values1, unit_values2): -# try: -# # one argument -# result_units = eval(f"np.{ufunc}(value1)") -# result_array = eval(f"np.{ufunc}(np.array(value1))") -# except (ValueError, TypeError): -# # two arguments -# result_units = eval(f"np.{ufunc}(value1, value2)") -# result_array = eval(f"np.{ufunc}(np.array(value1), np.array(value2))") -# # assert that comparing to a string results in "NotImplemented" or an error -# try: -# result = eval(f'np.{ufunc}(value1, "a string")') -# assert result == NotImplemented -# except (ValueError, TypeError): -# pass # raised on numpy >= 0.10 -# try: -# result = eval(f'np.{ufunc}("a string", value1)') -# assert result == NotImplemented -# except (ValueError, TypeError): -# pass # raised on numpy >= 0.10 -# assert not isinstance(result_units, Quantity) -# assert_equal(result_units, result_array) - - -# -# def test_arange_linspace(): -# # For dimensionless values, the unit-safe functions should give the same results -# assert_equal(brian2.arange(5), np.arange(5)) -# assert_equal(brian2.arange(1, 5), np.arange(1, 5)) -# assert_equal(brian2.arange(10, step=2), np.arange(10, step=2)) -# assert_equal(brian2.arange(0, 5, 0.5), np.arange(0, 5, 0.5)) -# assert_equal(brian2.linspace(0, 1), np.linspace(0, 1)) -# assert_equal(brian2.linspace(0, 1, 10), np.linspace(0, 1, 10)) -# -# # Make sure units are checked -# with pytest.raises(DimensionMismatchError): -# brian2.arange(1 * mV, 5) -# with pytest.raises(DimensionMismatchError): -# brian2.arange(1 * mV, 5 * mV) -# with pytest.raises(DimensionMismatchError): -# brian2.arange(1, 5 * mV) -# with pytest.raises(DimensionMismatchError): -# brian2.arange(1 * mV, 5 * ms) -# with pytest.raises(DimensionMismatchError): -# brian2.arange(1 * mV, 5 * mV, step=1 * ms) -# with pytest.raises(DimensionMismatchError): -# brian2.arange(1 * ms, 5 * mV) -# -# # Check correct functioning with units -# assert_quantity( -# brian2.arange(5 * mV, step=1 * mV), float(mV) * np.arange(5, step=1), mV -# ) -# assert_quantity( -# brian2.arange(1 * mV, 5 * mV, 1 * mV), float(mV) * np.arange(1, 5, 1), mV -# ) -# assert_quantity(brian2.linspace(1 * mV, 2 * mV), float(mV) * np.linspace(1, 2), mV) -# -# # Check errors for arange with incorrect numbers of arguments/duplicate arguments -# with pytest.raises(TypeError): -# brian2.arange() -# with pytest.raises(TypeError): -# brian2.arange(0, 5, 1, 0) -# with pytest.raises(TypeError): -# brian2.arange(0, stop=1) -# with pytest.raises(TypeError): -# brian2.arange(0, 5, stop=1) -# with pytest.raises(TypeError): -# brian2.arange(0, 5, start=1) -# with pytest.raises(TypeError): -# brian2.arange(0, 5, 1, start=1) -# with pytest.raises(TypeError): -# brian2.arange(0, 5, 1, stop=2) -# with pytest.raises(TypeError): -# brian2.arange(0, 5, 1, step=2) - - def test_list(): """ Test converting to and from a list. @@ -1462,7 +1267,7 @@ def test_get_basic_unit(): unit = get_basic_unit(unit) assert isinstance(unit, Unit) assert unit == expected_unit - assert float(unit) == 1.0 + assert float(unit.value) == 1.0 def test_get_best_unit(): @@ -1495,7 +1300,7 @@ def test_switching_off_unit_checks(): with turn_off_unit_checking(): # Now it should work - assert np.asarray(x + y) == np.array(8) + assert (x + y).value == np.array(8) assert have_same_unit(x, y) assert x.has_same_unit(y) @@ -1536,39 +1341,6 @@ def test_deepcopy(): assert d["x"] == 1 * second -# Doesn't support copy -# -# def test_inplace_on_scalars(): -# # We want "copy semantics" for in-place operations on scalar quantities -# # in the same way as for Python scalars -# for scalar in [3 * mV, 3 * mV / mV]: -# scalar_reference = scalar -# scalar_copy = Quantity(scalar, copy=True) -# scalar += scalar_copy -# assert_equal(scalar_copy, scalar_reference) -# scalar *= 1.5 -# assert_equal(scalar_copy, scalar_reference) -# scalar /= 2 -# assert_equal(scalar_copy, scalar_reference) -# -# # also check that it worked correctly for the scalar itself -# assert_allclose(scalar, (scalar_copy + scalar_copy) * 1.5 / 2) -# -# # For arrays, it should use reference semantics -# for vector in [[3] * mV, [3] * mV / mV]: -# vector_reference = vector -# vector_copy = Quantity(vector, copy=True) -# vector += vector_copy -# assert_equal(vector, vector_reference) -# vector *= 1.5 -# assert_equal(vector, vector_reference) -# vector /= 2 -# assert_equal(vector, vector_reference) -# -# # also check that it worked correctly for the vector itself -# assert_allclose(vector, (vector_copy + vector_copy) * 1.5 / 2) - - def test_units_vs_quantities(): # Unit objects should stay Unit objects under certain operations # (important e.g. in the unit definition of Equations, where only units but diff --git a/brainunit/math/_compat_numpy_funcs_keep_unit.py b/brainunit/math/_compat_numpy_funcs_keep_unit.py index 62217f0..6a79970 100644 --- a/brainunit/math/_compat_numpy_funcs_keep_unit.py +++ b/brainunit/math/_compat_numpy_funcs_keep_unit.py @@ -967,19 +967,22 @@ def modf( def funcs_keep_unit_binary( func, - x1, x2, + x1, + x2, *args, - check_same_dim=True, **kwargs ): if isinstance(x1, Quantity) and isinstance(x2, Quantity): - if check_same_dim: - fail_for_dimension_mismatch(x1, x2, func.__name__) + fail_for_dimension_mismatch(x1, x2, func.__name__) return Quantity(func(x1.value, x2.value, *args, **kwargs), dim=x1.dim) - elif isinstance(x1, (jax.Array, np.ndarray)) and isinstance(x2, (jax.Array, np.ndarray)): - return func(x1, x2, *args, **kwargs) + elif isinstance(x1, Quantity): + assert x1.is_unitless, f'Expected unitless array when x2 is not Quantity, while got {x1}' + return func(x1.value, x2, *args, **kwargs) + elif isinstance(x2, Quantity): + assert x2.is_unitless, f'Expected unitless array when x1 is not Quantity, while got {x2}' + return func(x1, x2.value, *args, **kwargs) else: - raise ValueError(f'Unsupported type: {type(x1)} and {type(x2)} for {func.__name__}') + return func(x1, x2, *args, **kwargs) @set_module_as('brainunit.math') diff --git a/brainunit/math/_compat_numpy_funcs_logic.py b/brainunit/math/_compat_numpy_funcs_logic.py index a3cd2cb..af282f5 100644 --- a/brainunit/math/_compat_numpy_funcs_logic.py +++ b/brainunit/math/_compat_numpy_funcs_logic.py @@ -175,10 +175,8 @@ def logic_func_binary(func, x, y, *args, **kwargs): elif isinstance(y, Quantity): assert y.is_unitless, f'Expected unitless array when x is not Quantity, while got {y}' return func(x, y.value, *args, **kwargs) - elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): - return func(x, y, *args, **kwargs) else: - raise ValueError(f'Unsupported types {type(x)} and {type(y)} for {func.__name__}') + return func(x, y, *args, **kwargs) @set_module_as('brainunit.math')