From cf27d3687da3fd6f0fdd4fd1a75a03819eec2fb9 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Fri, 14 Jun 2024 15:37:16 +0800 Subject: [PATCH] updates --- brainunit/_base.py | 35 ++++++++-------- brainunit/_unit_test.py | 20 ++++----- .../math/_compat_numpy_array_manipulation.py | 2 +- .../math/_compat_numpy_funcs_keep_unit.py | 34 +++++++++------ brainunit/math/_compat_numpy_funcs_window.py | 1 - brainunit/math/_compat_numpy_misc.py | 42 +++++++------------ 6 files changed, 65 insertions(+), 69 deletions(-) diff --git a/brainunit/_base.py b/brainunit/_base.py index 153eade..52adc58 100644 --- a/brainunit/_base.py +++ b/brainunit/_base.py @@ -36,7 +36,7 @@ 'DIMENSIONLESS', 'DimensionMismatchError', 'get_or_create_dimension', - 'get_unit', + 'get_dim', 'get_basic_unit', 'is_unitless', 'have_same_unit', @@ -119,7 +119,7 @@ def get_unit_for_display(d): if (isinstance(d, int) and d == 1) or d is DIMENSIONLESS: return "1" else: - return str(get_unit(d)) + return str(get_dim(d)) # SI dimensions (see table at the top of the file) and various descriptions, @@ -497,7 +497,7 @@ def __str__(self): return s -def get_unit(obj) -> Dimension: +def get_dim(obj) -> Dimension: """ Return the unit of any object that has them. @@ -551,8 +551,8 @@ def have_same_unit(obj1, obj2) -> bool: # should only add a small amount of unnecessary computation for cases in # which this function returns False which very likely leads to a # DimensionMismatchError anyway. - dim1 = get_unit(obj1) - dim2 = get_unit(obj2) + dim1 = get_dim(obj1) + dim2 = get_dim(obj2) return (dim1 is dim2) or (dim1 == dim2) or dim1 is None or dim2 is None @@ -598,11 +598,11 @@ def fail_for_dimension_mismatch( if not _unit_checking: return None, None - dim1 = get_unit(obj1) + dim1 = get_dim(obj1) if obj2 is None: dim2 = DIMENSIONLESS else: - dim2 = get_unit(obj2) + dim2 = get_dim(obj2) if dim1 is not dim2 and not (dim1 is None or dim2 is None): # Special treatment for "0": @@ -779,7 +779,7 @@ def is_unitless(obj) -> bool: dimensionless : `bool` ``True`` if `obj` is dimensionless. """ - return get_unit(obj) is DIMENSIONLESS + return get_dim(obj) is DIMENSIONLESS def is_scalar_type(obj) -> bool: @@ -1105,8 +1105,8 @@ def has_same_unit(self, other): """ if not _unit_checking: return True - other_unit = get_unit(other.dim) - return (get_unit(self.dim) is other_unit) or (get_unit(self.dim) == other_unit) + other_unit = get_dim(other.dim) + return (get_dim(self.dim) is other_unit) or (get_dim(self.dim) == other_unit) def get_best_unit(self, *regs) -> 'Quantity': """ @@ -1475,7 +1475,7 @@ def _binary_operation( _, other_dim = fail_for_dimension_mismatch(self, other, message, value1=self, value2=other) if other_dim is None: - other_dim = get_unit(other) + other_dim = get_dim(other) new_dim = unit_operation(self.dim, other_dim) result = value_operation(self.value, other.value) @@ -1944,14 +1944,13 @@ def take( self, indices, axis=None, - out=None, mode=None, unique_indices=False, indices_are_sorted=False, fill_value=None, ) -> 'Quantity': """Return an array formed from the elements of a at the given indices.""" - return Quantity(jnp.take(self.value, indices=indices, axis=axis, out=out, mode=mode, + return Quantity(jnp.take(self.value, indices=indices, axis=axis, mode=mode, unique_indices=unique_indices, indices_are_sorted=indices_are_sorted, fill_value=fill_value), dim=self.dim) @@ -3032,8 +3031,8 @@ def new_f(*args, **kwds): ) raise TypeError(error_message) if not have_same_unit(newkeyset[k], newkeyset[au[k]]): - d1 = get_unit(newkeyset[k]) - d2 = get_unit(newkeyset[au[k]]) + d1 = get_dim(newkeyset[k]) + d2 = get_dim(newkeyset[au[k]]) error_message = ( f"Function '{f.__name__}' expected " f"the argument '{k}' to have the same " @@ -3054,13 +3053,13 @@ def new_f(*args, **kwds): f"'{value}'" ) raise DimensionMismatchError( - error_message, get_unit(newkeyset[k]) + error_message, get_dim(newkeyset[k]) ) result = f(*args, **kwds) if "result" in au: if isinstance(au["result"], Callable) and au["result"] != bool: - expected_result = au["result"](*[get_unit(a) for a in args]) + expected_result = au["result"](*[get_dim(a) for a in args]) else: expected_result = au["result"] if au["result"] == bool: @@ -3080,7 +3079,7 @@ def new_f(*args, **kwds): f"unit {unit} but was " f"'{result}'" ) - raise DimensionMismatchError(error_message, get_unit(result)) + raise DimensionMismatchError(error_message, get_dim(result)) return result new_f._orig_func = f diff --git a/brainunit/_unit_test.py b/brainunit/_unit_test.py index fe5c2d7..5380f68 100644 --- a/brainunit/_unit_test.py +++ b/brainunit/_unit_test.py @@ -36,7 +36,7 @@ check_units, fail_for_dimension_mismatch, get_or_create_dimension, - get_unit, + get_dim, get_basic_unit, have_same_unit, in_unit, @@ -74,7 +74,7 @@ def assert_allclose(actual, desired, rtol=4.5e8, atol=0, **kwds): def assert_quantity(q, values, unit): values = jnp.asarray(values) if isinstance(q, Quantity): - assert have_same_unit(q.dim, unit), f"Dimension mismatch: ({get_unit(q)}) ({get_unit(unit)})" + assert have_same_unit(q.dim, unit), f"Dimension mismatch: ({get_dim(q)}) ({get_dim(unit)})" if not jnp.allclose(q.value, values): raise AssertionError(f"Values do not match: {q.value} != {values}") elif isinstance(q, jnp.ndarray): @@ -145,19 +145,19 @@ def test_get_dimensions(): Test various ways of getting/comparing the dimensions of a Array. """ q = 500 * ms - assert get_unit(q) is get_or_create_dimension(q.dim._dims) - assert get_unit(q) is q.dim + assert get_dim(q) is get_or_create_dimension(q.dim._dims) + assert get_dim(q) is q.dim assert q.has_same_unit(3 * second) dims = q.dim assert_equal(dims.get_dimension("time"), 1.0) assert_equal(dims.get_dimension("length"), 0) - assert get_unit(5) is DIMENSIONLESS - assert get_unit(5.0) is DIMENSIONLESS - assert get_unit(np.array(5, dtype=np.int32)) is DIMENSIONLESS - assert get_unit(np.array(5.0)) is DIMENSIONLESS - assert get_unit(np.float32(5.0)) is DIMENSIONLESS - assert get_unit(np.float64(5.0)) is DIMENSIONLESS + assert get_dim(5) is DIMENSIONLESS + assert get_dim(5.0) is DIMENSIONLESS + assert get_dim(np.array(5, dtype=np.int32)) is DIMENSIONLESS + assert get_dim(np.array(5.0)) is DIMENSIONLESS + assert get_dim(np.float32(5.0)) is DIMENSIONLESS + assert get_dim(np.float64(5.0)) is DIMENSIONLESS assert is_scalar_type(5) assert is_scalar_type(5.0) assert is_scalar_type(np.array(5, dtype=np.int32)) diff --git a/brainunit/math/_compat_numpy_array_manipulation.py b/brainunit/math/_compat_numpy_array_manipulation.py index c47432c..fdf0fcb 100644 --- a/brainunit/math/_compat_numpy_array_manipulation.py +++ b/brainunit/math/_compat_numpy_array_manipulation.py @@ -1276,7 +1276,7 @@ def searchsorted( @set_module_as('brainunit.math') def extract( - condition: Union[Array, Quantity], + condition: Array, arr: Union[Array, Quantity], *, size: Optional[int] = None, diff --git a/brainunit/math/_compat_numpy_funcs_keep_unit.py b/brainunit/math/_compat_numpy_funcs_keep_unit.py index 7804702..07f9f2a 100644 --- a/brainunit/math/_compat_numpy_funcs_keep_unit.py +++ b/brainunit/math/_compat_numpy_funcs_keep_unit.py @@ -19,7 +19,7 @@ import numpy as np from brainunit._misc import set_module_as -from .._base import Quantity +from .._base import Quantity, fail_for_dimension_mismatch __all__ = [ # math funcs keep unit (unary) @@ -1090,6 +1090,7 @@ def modf( def funcs_keep_unit_binary(func, x1, x2, *args, **kwargs): if isinstance(x1, Quantity) and isinstance(x2, Quantity): + fail_for_dimension_mismatch(x1, x2, func.__name__) return Quantity(func(x1.value, x2.value, *args, **kwargs), dim=x1.dim) elif isinstance(x1, (jax.Array, np.ndarray)) and isinstance(x2, (jax.Array, np.ndarray)): return func(x1, x2, *args, **kwargs) @@ -1098,7 +1099,8 @@ def funcs_keep_unit_binary(func, x1, x2, *args, **kwargs): @set_module_as('brainunit.math') -def fmod(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: +def fmod(x1: Union[Quantity, jax.Array], + x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: """ Return the element-wise remainder of division. @@ -1158,7 +1160,8 @@ def copysign(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> @set_module_as('brainunit.math') -def heaviside(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: +def heaviside(x1: Union[Quantity, jax.Array], + x2: jax.typing.ArrayLike) -> Union[Quantity, jax.Array]: """ Compute the Heaviside step function. @@ -1174,7 +1177,8 @@ def heaviside(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> out : jax.Array, Quantity Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. """ - return funcs_keep_unit_binary(jnp.heaviside, x1, x2) + x1 = x1.value if isinstance(x1, Quantity) else x1 + return jnp.heaviside(x1, x2) @set_module_as('brainunit.math') @@ -1300,12 +1304,14 @@ def gcd(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union # math funcs keep unit (n-ary) # ---------------------------- @set_module_as('brainunit.math') -def interp(x: Union[Quantity, jax.typing.ArrayLike], - xp: Union[Quantity, jax.typing.ArrayLike], - fp: Union[Quantity, jax.typing.ArrayLike], - left: Union[Quantity, jax.typing.ArrayLike] = None, - right: Union[Quantity, jax.typing.ArrayLike] = None, - period: Union[Quantity, jax.typing.ArrayLike] = None) -> Union[Quantity, jax.Array]: +def interp( + x: Union[Quantity, jax.typing.ArrayLike], + xp: Union[Quantity, jax.typing.ArrayLike], + fp: Union[Quantity, jax.typing.ArrayLike], + left: Union[Quantity, jax.typing.ArrayLike] = None, + right: Union[Quantity, jax.typing.ArrayLike] = None, + period: Union[Quantity, jax.typing.ArrayLike] = None +) -> Union[Quantity, jax.Array]: """ One-dimensional linear interpolation. @@ -1343,9 +1349,11 @@ def interp(x: Union[Quantity, jax.typing.ArrayLike], @set_module_as('brainunit.math') -def clip(a: Union[Quantity, jax.typing.ArrayLike], - a_min: Union[Quantity, jax.typing.ArrayLike], - a_max: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]: +def clip( + a: Union[Quantity, jax.typing.ArrayLike], + a_min: Union[Quantity, jax.typing.ArrayLike], + a_max: Union[Quantity, jax.typing.ArrayLike] +) -> Union[Quantity, jax.Array]: """ Clip (limit) the values in an array. diff --git a/brainunit/math/_compat_numpy_funcs_window.py b/brainunit/math/_compat_numpy_funcs_window.py index ae4df30..0be3fcf 100644 --- a/brainunit/math/_compat_numpy_funcs_window.py +++ b/brainunit/math/_compat_numpy_funcs_window.py @@ -19,7 +19,6 @@ from brainunit._misc import set_module_as __all__ = [ - # window funcs 'bartlett', 'blackman', 'hamming', 'hanning', 'kaiser', ] diff --git a/brainunit/math/_compat_numpy_misc.py b/brainunit/math/_compat_numpy_misc.py index 5bfdd1d..dab1427 100644 --- a/brainunit/math/_compat_numpy_misc.py +++ b/brainunit/math/_compat_numpy_misc.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== + +from __future__ import annotations + from collections.abc import Sequence from typing import (Callable, Union, Tuple, Any, Optional) @@ -30,7 +33,7 @@ Quantity, fail_for_dimension_mismatch, is_unitless, - get_unit, ) + get_dim, ) __all__ = [ @@ -77,14 +80,9 @@ def iinfo(a: Union[Quantity, jax.typing.ArrayLike]) -> jnp.iinfo: # ---- @set_module_as('brainunit.math') def broadcast_arrays(*args: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, list[Array]]: - if all(isinstance(arg, Quantity) for arg in args): - if any(arg.dim != args[0].dim for arg in args): - raise ValueError("All arguments must have the same unit") - return Quantity(jnp.broadcast_arrays(*[arg.value for arg in args]), dim=args[0].dim) - elif all(isinstance(arg, (jax.Array, np.ndarray)) for arg in args): - return jnp.broadcast_arrays(*args) - else: - raise ValueError(f"Unsupported types : {type(args)} for broadcast_arrays") + leaves, tree = jax.tree.flatten(args) + leaves = jnp.broadcast_arrays(*leaves) + return jax.tree.unflatten(tree, leaves) broadcast_shapes = jnp.broadcast_shapes @@ -149,13 +147,11 @@ def einsum( from jax._src.numpy.lax_numpy import _default_poly_einsum_handler contract_path = _default_poly_einsum_handler - operands, contractions = contract_path( - *operands, einsum_call=True, use_blas=True, optimize=optimize) + operands, contractions = contract_path(*operands, einsum_call=True, use_blas=True, optimize=optimize) unit = None for i in range(len(contractions) - 1): if contractions[i][4] == 'False': - fail_for_dimension_mismatch( Quantity([], dim=unit), operands[i + 1], 'einsum' ) @@ -248,13 +244,13 @@ def gradient( else: return jnp.gradient(f) elif len(varargs) == 1: - unit = get_unit(f) / get_unit(varargs[0]) + unit = get_dim(f) / get_dim(varargs[0]) if unit is None or unit == DIMENSIONLESS: return jnp.gradient(f, varargs[0], axis=axis) else: return [Quantity(r, dim=unit) for r in jnp.gradient(f.value, varargs[0].value, axis=axis)] else: - unit_list = [get_unit(f) / get_unit(v) for v in varargs] + unit_list = [get_dim(f) / get_dim(v) for v in varargs] f = f.value if isinstance(f, Quantity) else f varargs = [v.value if isinstance(v, Quantity) else v for v in varargs] result_list = jnp.gradient(f, *varargs, axis=axis) @@ -307,7 +303,7 @@ def intersect1d( result = jnp.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices) if return_indices: if unit is not None: - return (Quantity(result[0], dim=unit), result[1], result[2]) + return Quantity(result[0], dim=unit), result[1], result[2] else: return result else: @@ -320,9 +316,9 @@ def intersect1d( @set_module_as('brainunit.math') def nan_to_num( x: Union[jax.typing.ArrayLike, Quantity], - nan: float = 0.0, - posinf: float = jnp.inf, - neginf: float = -jnp.inf + nan: float | Quantity = 0.0, + posinf: float | Quantity = jnp.inf, + neginf: float | Quantity = -jnp.inf ) -> Union[jax.Array, Quantity]: """ Replace NaN with zero and infinity with large finite numbers (default @@ -472,9 +468,6 @@ def nanargmax( Input data. axis : int, optional Axis along which to operate. By default flattened input is used. - out : array, optional - If provided, the result will be inserted into this array. It should - be of the appropriate shape and dtype. keepdims : bool, optional If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, @@ -509,9 +502,6 @@ def nanargmin( Input data. axis : int, optional Axis along which to operate. By default flattened input is used. - out : array, optional - If provided, the result will be inserted into this array. It should - be of the appropriate shape and dtype. keepdims : bool, optional If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, @@ -534,8 +524,6 @@ def frexp( x: Union[Quantity, jax.typing.ArrayLike] ) -> Tuple[jax.Array, jax.Array]: """ - frexp(x[, out1, out2], / [, out=(None, None)], *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj]) - Decompose the elements of x into mantissa and twos exponent. Returns (`mantissa`, `exponent`), where ``x = mantissa * 2**exponent``. @@ -556,4 +544,6 @@ def frexp( Integer exponents of 2. This is a scalar if `x` is a scalar. """ + assert not isinstance(x, Quantity) or is_unitless(x), "Input must be unitless" + x = x.value if isinstance(x, Quantity) else x return jnp.frexp(x)