diff --git a/brainunit/_base.py b/brainunit/_base.py index 5b7b02d..ea57abd 100644 --- a/brainunit/_base.py +++ b/brainunit/_base.py @@ -98,7 +98,7 @@ def _short_str(arr): Return a short string representation of an array, suitable for use in error messages. """ - arr = arr._mantissa if isinstance(arr, Quantity) else arr + arr = arr.mantissa if isinstance(arr, Quantity) else arr if not isinstance(arr, (jax.core.Tracer, jax.core.ShapedArray, jax.ShapeDtypeStruct)): arr = np.asanyarray(arr) with change_printoption(edgeitems=2, threshold=5): @@ -994,8 +994,8 @@ def unit_scale_align_to_first(*args) -> List['Quantity']: @set_module_as('brainunit') def array_with_unit( - floatval, - dim: Dimension, + mantissa, + unit: 'Unit', dtype: Optional[jax.typing.DTypeLike] = None ) -> 'Quantity': """ @@ -1007,9 +1007,9 @@ def array_with_unit( Parameters ---------- - floatval : `float` + mantissa : `float` The floating point value of the array. - dim: Dimension + unit: Unit The dim dimensions of the array. dtype: `dtype`, optional The data type of the array. @@ -1022,11 +1022,11 @@ def array_with_unit( Examples -------- >>> from brainunit import * - >>> array_with_unit(0.001, volt.dim) + >>> array_with_unit(0.001, volt) 1. * mvolt """ - assert isinstance(dim, Dimension), f'Expected instance of Dimension, but got {dim}' - return Quantity(floatval, dim=get_or_create_dimension(dim._dims), dtype=dtype) + assert isinstance(unit, Unit), f'Expected instance of Unit, but got {unit}' + return Quantity(mantissa, unit=unit, dtype=dtype) @set_module_as('brainunit') @@ -1103,7 +1103,7 @@ def _wrap_function_keep_unit(func): """ def f(x: Quantity, *args, **kwds): # pylint: disable=C0111 - return Quantity(func(x._mantissa, *args, **kwds), unit=x.unit) + return Quantity(func(x.mantissa, *args, **kwds), unit=x.unit) f._arg_units = [None] f._return_unit = lambda u: u @@ -1878,7 +1878,7 @@ def _check_units_and_collect_values(lst) -> Tuple[jax.typing.ArrayLike, 'Unit']: if unit != UNITLESS: units.append(unit) elif isinstance(item, Quantity): - values.append(item._mantissa) + values.append(item.mantissa) units.append(item.unit) elif isinstance(item, Unit): values.append(1) @@ -1956,7 +1956,7 @@ def __init__( elif not unit.has_same_dim(mantissa.unit): raise ValueError("Cannot create a Quantity object with a different unit.") mantissa = mantissa.in_unit(unit) - mantissa = mantissa._mantissa + mantissa = mantissa.mantissa elif isinstance(mantissa, (np.ndarray, jax.Array)): if dtype is not None: @@ -2157,9 +2157,9 @@ def to_decimal(self, unit: Unit = UNITLESS) -> jax.typing.ArrayLike: f"dimensions. The quantity has the unit {self.unit}, but the given " f"unit is {unit}") if not unit.has_same_scale(self.unit): - return self._mantissa * (self.unit.value / unit.value) + return self.mantissa * (self.unit.value / unit.value) else: - return self._mantissa + return self.mantissa def in_unit(self, unit: Unit, err_msg: str = None) -> 'Quantity': """ @@ -2173,9 +2173,9 @@ def in_unit(self, unit: Unit, err_msg: str = None) -> 'Quantity': else: raise UnitMismatchError(err_msg) if unit.has_same_scale(self.unit): - u = Quantity(self._mantissa, unit=unit) + u = Quantity(self.mantissa, unit=unit) else: - u = Quantity(self._mantissa * (self.unit.value / unit.value), unit=unit) + u = Quantity(self.mantissa * (self.unit.value / unit.value), unit=unit) return u @staticmethod @@ -2267,27 +2267,34 @@ def repr_in_unit( >>> x.repr_in_unit(mV, 3) '25.123 mV' """ - value = jnp.asarray(self.mantissa) - if _is_tracer(value): - # in the JIT mode + # convert to the JAX array + try: + value = jnp.asarray(self.mantissa) + except TypeError: + value = self.mantissa + + if _is_tracer(value): # in the JIT mode s = str(value) - else: - if value.shape == (): - s = np.array_str(np.array([value]), precision=precision) - s = s.replace("[", "").replace("]", "").strip() - else: - if value.size > 100: - if python_code: - s = np.array_repr(value, precision=precision)[:100] - s += "..." - else: - s = np.array_str(value, precision=precision)[:100] - s += "..." + else: # in the normal mode + try: + if value.shape == (): + s = np.array_str(np.array([value]), precision=precision) + s = s.replace("[", "").replace("]", "").strip() else: - if python_code: - s = np.array_repr(value, precision=precision) + if value.size > 100: + if python_code: + s = np.array_repr(value, precision=precision)[:100] + s += "..." + else: + s = np.array_str(value, precision=precision)[:100] + s += "..." else: - s = np.array_str(value, precision=precision) + if python_code: + s = np.array_repr(value, precision=precision) + else: + s = np.array_str(value, precision=precision) + except TypeError: + s = str(value) if not self.unit.is_unitless: if python_code: @@ -2299,7 +2306,7 @@ def repr_in_unit( return s.strip() def _check_tracer(self): - self_value = self._mantissa + self_value = self.mantissa # if hasattr(self_value, '_trace') and hasattr(self_value._trace.main, 'jaxpr_stack'): # if len(self_value._trace.main.jaxpr_stack) == 0: # raise RuntimeError('This Array is modified during the transformation. ' @@ -2310,7 +2317,7 @@ def _check_tracer(self): @property def dtype(self): """Variable dtype.""" - a = self._mantissa + a = self.mantissa if hasattr(a, 'dtype'): return a.dtype else: @@ -2328,31 +2335,31 @@ def dtype(self): @property def shape(self) -> Tuple[int, ...]: """Variable shape.""" - return jnp.shape(self._mantissa) + return jnp.shape(self.mantissa) @property def ndim(self) -> int: - return jnp.ndim(self._mantissa) + return jnp.ndim(self.mantissa) @property def imag(self) -> 'Quantity': - return Quantity(jnp.imag(self._mantissa), unit=self.unit) + return Quantity(jnp.imag(self.mantissa), unit=self.unit) @property def real(self) -> 'Quantity': - return Quantity(jnp.real(self._mantissa), unit=self.unit) + return Quantity(jnp.real(self.mantissa), unit=self.unit) @property def size(self) -> int: - return jnp.size(self._mantissa) + return jnp.size(self.mantissa) @property def T(self) -> 'Quantity': - return Quantity(jnp.asarray(self._mantissa).T, unit=self.unit) + return Quantity(jnp.asarray(self.mantissa).T, unit=self.unit) @property def isreal(self) -> jax.Array: - return jnp.isreal(self._mantissa) + return jnp.isreal(self.mantissa) @property def isscalar(self) -> bool: @@ -2360,19 +2367,19 @@ def isscalar(self) -> bool: @property def isfinite(self) -> jax.Array: - return jnp.isfinite(self._mantissa) + return jnp.isfinite(self.mantissa) @property def isinfnite(self) -> jax.Array: - return jnp.isinf(self._mantissa) + return jnp.isinf(self.mantissa) @property def isinf(self) -> jax.Array: - return jnp.isinf(self._mantissa) + return jnp.isinf(self.mantissa) @property def isnan(self) -> jax.Array: - return jnp.isnan(self._mantissa) + return jnp.isnan(self.mantissa) # ----------------------- # # Python inherent methods # @@ -2399,17 +2406,17 @@ def __iter__(self): yield self else: for i in range(self.shape[0]): - yield Quantity(self._mantissa[i], unit=self.unit) + yield Quantity(self.mantissa[i], unit=self.unit) def __getitem__(self, index) -> 'Quantity': if isinstance(index, slice) and (index == _all_slice): - return Quantity(self._mantissa, unit=self.unit) + return Quantity(self.mantissa, unit=self.unit) elif isinstance(index, tuple): for x in index: assert not isinstance(x, Quantity), "Array indices must be integers or slices, not Array" elif isinstance(index, Quantity): raise TypeError("Array indices must be integers or slices, not Array") - return Quantity(self._mantissa[index], unit=self.unit) + return Quantity(self.mantissa[index], unit=self.unit) def __setitem__(self, index, value: 'Quantity' | jax.typing.ArrayLike): # check value @@ -2635,19 +2642,19 @@ def scatter_min( # ---------- # def __len__(self) -> int: - return len(self._mantissa) + return len(self.mantissa) def __neg__(self) -> 'Quantity': - return Quantity(self._mantissa.__neg__(), unit=self.unit) + return Quantity(self.mantissa.__neg__(), unit=self.unit) def __pos__(self) -> 'Quantity': - return Quantity(self._mantissa.__pos__(), unit=self.unit) + return Quantity(self.mantissa.__pos__(), unit=self.unit) def __abs__(self) -> 'Quantity': - return Quantity(self._mantissa.__abs__(), unit=self.unit) + return Quantity(self.mantissa.__abs__(), unit=self.unit) def __invert__(self) -> 'Quantity': - return Quantity(self._mantissa.__invert__(), unit=self.unit) + return Quantity(self.mantissa.__invert__(), unit=self.unit) def _comparison(self, other: Any, operator_str: str, operation: Callable): other = _to_quantity(other) @@ -2716,7 +2723,7 @@ def _binary_operation( if fail_for_mismatch: other = other.in_unit(self.unit, err_msg=f"Cannot calculate \n" - f"{self} {operator_str} {other}" + f"{self} {operator_str} {other}, " f"because units do not match: {self.unit} != {other.unit}") other_value = other.mantissa other_unit = other.unit @@ -2853,7 +2860,7 @@ def __pow__(self, oc): def __rpow__(self, oc): # oc ** self assert self.is_unitless, f"Cannot calculate {oc} ** {self}, the exponent has to be dimensionless" - return oc ** self._mantissa + return oc ** self.mantissa def __ipow__(self, oc): # a **= b @@ -2901,45 +2908,59 @@ def __lshift__(self, oc) -> 'Quantity': # self << oc if isinstance(oc, Quantity): assert oc.is_unitless, "The shift amount must be dimensionless" - oc = oc._mantissa - r = Quantity(self._mantissa << oc, unit=self.unit) + oc = oc.mantissa + r = Quantity(self.mantissa << oc, unit=self.unit) return maybe_decimal(r) def __rlshift__(self, oc) -> 'Quantity' | jax.typing.ArrayLike: # oc << self assert self.is_unitless, "The shift amount must be dimensionless" - return oc << self._mantissa + return oc << self.mantissa def __ilshift__(self, oc) -> 'Quantity': # self <<= oc r = self.__lshift__(oc) - self.update_value(r._mantissa) + self.update_value(r.mantissa) return self def __rshift__(self, oc) -> 'Quantity': # self >> oc if isinstance(oc, Quantity): assert oc.is_unitless, "The shift amount must be dimensionless" - oc = oc._mantissa - r = Quantity(self._mantissa >> oc, unit=self.unit) + oc = oc.mantissa + r = Quantity(self.mantissa >> oc, unit=self.unit) return maybe_decimal(r) def __rrshift__(self, oc) -> 'Quantity' | jax.typing.ArrayLike: # oc >> self assert self.is_unitless, "The shift amount must be dimensionless" - return oc >> self._mantissa + return oc >> self.mantissa def __irshift__(self, oc) -> 'Quantity': # self >>= oc r = self.__rshift__(oc) - self.update_value(r._mantissa) + self.update_value(r.mantissa) return self def __round__(self, ndigits: int = None) -> 'Quantity': - return Quantity(self._mantissa.__round__(ndigits), unit=self.unit) + """ + Round the mantissa to the given number of decimals. + + :param ndigits: The number of decimals to round to. + :return: The rounded Quantity. + """ + return Quantity(self.mantissa.__round__(ndigits), unit=self.unit) def __reduce__(self): - return array_with_unit, (self._mantissa, self.unit, None) + """ + Method used by Pickle object serialization. + + Returns + ------- + tuple + The tuple of the class and the arguments required to reconstruct the object. + """ + return array_with_unit, (self.mantissa, self.unit, None) # ----------------------- # # NumPy methods # @@ -2966,7 +2987,7 @@ def __reduce__(self): ravel = _wrap_function_keep_unit(jnp.ravel) def __deepcopy__(self, memodict: Dict): - return Quantity(deepcopy(self._mantissa), unit=self.unit.__deepcopy__(memodict)) + return Quantity(deepcopy(self.mantissa), unit=self.unit.__deepcopy__(memodict)) def round( self, @@ -2992,7 +3013,7 @@ def round( The real and imaginary parts of complex numbers are rounded separately. The result of rounding a float is a float. """ - return Quantity(jnp.round(self._mantissa, decimals), unit=self.unit) + return Quantity(jnp.round(self.mantissa, decimals), unit=self.unit) def astype( self, @@ -3006,9 +3027,9 @@ def astype( Typecode or data-type to which the array is cast. """ if dtype is None: - return Quantity(self._mantissa, unit=self.unit) + return Quantity(self.mantissa, unit=self.unit) else: - return Quantity(jnp.astype(self._mantissa, dtype), unit=self.unit) + return Quantity(jnp.astype(self.mantissa, dtype), unit=self.unit) def clip( self, @@ -3020,19 +3041,19 @@ def clip( """ _, min = unit_scale_align_to_first(self, min) _, max = unit_scale_align_to_first(self, max) - return Quantity(jnp.clip(self._mantissa, min._mantissa, max._mantissa), unit=self.unit) + return Quantity(jnp.clip(self.mantissa, min.mantissa, max.mantissa), unit=self.unit) def conj(self) -> 'Quantity': """Complex-conjugate all elements.""" - return Quantity(jnp.conj(self._mantissa), unit=self.unit) + return Quantity(jnp.conj(self.mantissa), unit=self.unit) def conjugate(self) -> 'Quantity': """Return the complex conjugate, element-wise.""" - return Quantity(jnp.conjugate(self._mantissa), unit=self.unit) + return Quantity(jnp.conjugate(self.mantissa), unit=self.unit) def copy(self) -> 'Quantity': """Return a copy of the quantity.""" - return type(self)(jnp.copy(self._mantissa), unit=self.unit) + return type(self)(jnp.copy(self.mantissa), unit=self.unit) def dot(self, b) -> 'Quantity': """Dot product of two arrays.""" @@ -3046,15 +3067,15 @@ def fill(self, value: Quantity) -> 'Quantity': return self def flatten(self) -> 'Quantity': - return Quantity(jnp.reshape(self._mantissa, -1), unit=self.unit) + return Quantity(jnp.reshape(self.mantissa, -1), unit=self.unit) def item(self, *args) -> 'Quantity': """Copy an element of an array to a standard Python scalar and return it.""" - return Quantity(self._mantissa.item(*args), unit=self.unit) + return Quantity(self.mantissa.item(*args), unit=self.unit) def prod(self, *args, **kwds) -> 'Quantity': # TODO: check error when axis is not None """Return the product of the array elements over the given axis.""" - prod_res = jnp.prod(self._mantissa, *args, **kwds) + prod_res = jnp.prod(self.mantissa, *args, **kwds) # Calculating the correct dimensions is not completly trivial (e.g. # like doing self.dim**self.size) because prod can be called on # multidimensional arrays along a certain axis. @@ -3063,7 +3084,7 @@ def prod(self, *args, **kwds) -> 'Quantity': # TODO: check error when axis is n # The result gives the exponent for the dimensions. # This relies on sum and prod having the same arguments, which is true # now and probably remains like this in the future - dim_exponent = jnp.ones_like(self._mantissa).sum(*args, **kwds) + dim_exponent = jnp.ones_like(self.mantissa).sum(*args, **kwds) # The result is possibly multidimensional but all entries should be # identical if dim_exponent.size > 1: @@ -3073,8 +3094,8 @@ def prod(self, *args, **kwds) -> 'Quantity': # TODO: check error when axis is n def nanprod(self, *args, **kwds) -> 'Quantity': # TODO: check error when axis is not None """Return the product of array elements over a given axis treating Not a Numbers (NaNs) as ones.""" - prod_res = jnp.nanprod(self._mantissa, *args, **kwds) - nan_mask = jnp.isnan(self._mantissa) + prod_res = jnp.nanprod(self.mantissa, *args, **kwds) + nan_mask = jnp.isnan(self.mantissa) dim_exponent = jnp.cumsum(jnp.where(nan_mask, 0, 1), *args) if dim_exponent.size > 1: dim_exponent = dim_exponent[-1] @@ -3082,16 +3103,16 @@ def nanprod(self, *args, **kwds) -> 'Quantity': # TODO: check error when axis i return maybe_decimal(r) def cumprod(self, *args, **kwds): # TODO: check error when axis is not None - prod_res = jnp.cumprod(self._mantissa, *args, **kwds) - dim_exponent = jnp.ones_like(self._mantissa).cumsum(*args, **kwds) + prod_res = jnp.cumprod(self.mantissa, *args, **kwds) + dim_exponent = jnp.ones_like(self.mantissa).cumsum(*args, **kwds) if dim_exponent.size > 1: dim_exponent = dim_exponent[-1] r = Quantity(jnp.array(prod_res), unit=self.unit ** dim_exponent) return maybe_decimal(r) def nancumprod(self, *args, **kwds): # TODO: check error when axis is not None - prod_res = jnp.nancumprod(self._mantissa, *args, **kwds) - nan_mask = jnp.isnan(self._mantissa) + prod_res = jnp.nancumprod(self.mantissa, *args, **kwds) + nan_mask = jnp.isnan(self.mantissa) dim_exponent = jnp.cumsum(jnp.where(nan_mask, 0, 1), *args) if dim_exponent.size > 1: dim_exponent = dim_exponent[-1] @@ -3114,16 +3135,16 @@ def put(self, indices, values) -> 'Quantity': def repeat(self, repeats, axis=None) -> 'Quantity': """Repeat elements of an array.""" - r = jnp.repeat(self._mantissa, repeats=repeats, axis=axis) + r = jnp.repeat(self.mantissa, repeats=repeats, axis=axis) return Quantity(r, unit=self.unit) def reshape(self, *shape, order='C') -> 'Quantity': """Returns an array containing the same data with a new shape.""" - return Quantity(jnp.reshape(self._mantissa, shape, order=order), unit=self.unit) + return Quantity(jnp.reshape(self.mantissa, shape, order=order), unit=self.unit) def resize(self, new_shape) -> 'Quantity': """Change shape and size of array in-place.""" - self.update_value(jnp.resize(self._mantissa, new_shape)) + self.update_value(jnp.resize(self.mantissa, new_shape)) return self def sort(self, axis=-1, stable=True, order=None) -> 'Quantity': @@ -3143,16 +3164,16 @@ def sort(self, axis=-1, stable=True, order=None) -> 'Quantity': but unspecified fields will still be used, in the order in which they come up in the dtype, to break ties. """ - self.update_value(jnp.sort(self._mantissa, axis=axis, stable=stable, order=order)) + self.update_value(jnp.sort(self.mantissa, axis=axis, stable=stable, order=order)) return self def squeeze(self, axis=None) -> 'Quantity': """Remove axes of length one from ``a``.""" - return Quantity(jnp.squeeze(self._mantissa, axis=axis), unit=self.unit) + return Quantity(jnp.squeeze(self.mantissa, axis=axis), unit=self.unit) def swapaxes(self, axis1, axis2) -> 'Quantity': """Return a view of the array with `axis1` and `axis2` interchanged.""" - return Quantity(jnp.swapaxes(self._mantissa, axis1, axis2), unit=self.unit) + return Quantity(jnp.swapaxes(self.mantissa, axis1, axis2), unit=self.unit) def split(self, indices_or_sections, axis=0) -> List['Quantity']: """Split an array into multiple sub-arrays as views into ``ary``. @@ -3182,7 +3203,7 @@ def split(self, indices_or_sections, axis=0) -> List['Quantity']: sub-arrays : list of ndarrays A list of sub-arrays as views into `ary`. """ - return [Quantity(a, unit=self.unit) for a in jnp.split(self._mantissa, indices_or_sections, axis=axis)] + return [Quantity(a, unit=self.unit) for a in jnp.split(self.mantissa, indices_or_sections, axis=axis)] def take( self, @@ -3196,13 +3217,13 @@ def take( """Return an array formed from the elements of a at the given indices.""" if isinstance(fill_value, Quantity): fail_for_dimension_mismatch(self, fill_value, "take") - fill_value = unit_scale_align_to_first(self, fill_value)[1]._mantissa + fill_value = unit_scale_align_to_first(self, fill_value)[1].mantissa elif fill_value is not None: if not self.is_unitless: raise TypeError(f"fill_value must be a Quantity when the unit {self.unit}. But got {fill_value}") return Quantity( jnp.take( - self._mantissa, + self.mantissa, indices=indices, axis=axis, mode=mode, @@ -3223,7 +3244,7 @@ def tolist(self): If ``a.ndim`` is 0, then since the depth of the nested list is 0, it will not be a list at all, but a simple Python scalar. """ - return _replace_with_array(self._mantissa.tolist(), self.unit) + return _replace_with_array(self.mantissa.tolist(), self.unit) def transpose(self, *axes) -> 'Quantity': """Returns a view of the array with axes transposed. @@ -3255,7 +3276,7 @@ def transpose(self, *axes) -> 'Quantity': out : ndarray View of `a`, with axes suitably permuted. """ - return Quantity(jnp.transpose(self._mantissa, *axes), unit=self.unit) + return Quantity(jnp.transpose(self.mantissa, *axes), unit=self.unit) def tile(self, reps) -> 'Quantity': """Construct an array by repeating A the number of times given by reps. @@ -3286,7 +3307,7 @@ def tile(self, reps) -> 'Quantity': c : ndarray The tiled output array. """ - return Quantity(jnp.tile(self._mantissa, reps), unit=self.unit) + return Quantity(jnp.tile(self.mantissa, reps), unit=self.unit) def view(self, *args, dtype=None) -> 'Quantity': r"""New view of array with the same data. @@ -3428,16 +3449,16 @@ def view(self, *args, dtype=None) -> 'Quantity': if dtype is None: raise ValueError('Provide dtype or shape.') else: - return Quantity(self._mantissa.view(dtype), unit=self.unit) + return Quantity(self.mantissa.view(dtype), unit=self.unit) else: if isinstance(args[0], int): # shape if dtype is not None: raise ValueError('Provide one of dtype or shape. Not both.') - return Quantity(self._mantissa.reshape(*args), unit=self.unit) + return Quantity(self.mantissa.reshape(*args), unit=self.unit) else: # dtype assert not isinstance(args[0], int) assert dtype is None - return Quantity(self._mantissa.view(args[0]), unit=self.unit) + return Quantity(self.mantissa.view(args[0]), unit=self.unit) # ------------------ # NumPy support @@ -3445,26 +3466,26 @@ def view(self, *args, dtype=None) -> 'Quantity': def __array__(self, dtype: Optional[jax.typing.DTypeLike] = None) -> np.ndarray: """Support ``numpy.array()`` and ``numpy.asarray()`` functions.""" - if self.unit.is_unitless: - return np.asarray(self._mantissa, dtype=dtype) + if self.dim.is_dimensionless: + return np.asarray(self.to_decimal(), dtype=dtype) else: raise TypeError( - f"only dimensionless quantities can be " + f"Only dimensionless quantities can be " f"converted to NumPy arrays. But got {self}" ) def __float__(self): - if self.unit.is_unitless and self.ndim == 0: - return float(self._mantissa) + if self.dim.is_dimensionless and self.ndim == 0: + return float(self.to_decimal()) else: raise TypeError( - "only dimensionless scalar quantities can be " + "Only dimensionless scalar quantities can be " f"converted to Python scalars. But got {self}" ) def __int__(self): - if self.unit.is_unitless and self.ndim == 0: - return int(self._mantissa) + if self.dim.is_dimensionless and self.ndim == 0: + return int(self.to_decimal()) else: raise TypeError( "only dimensionless scalar quantities can be " @@ -3472,8 +3493,8 @@ def __int__(self): ) def __index__(self): - if self.unit.is_unitless: - return operator.index(self._mantissa) + if self.dim.is_dimensionless: + return operator.index(self.to_decimal()) else: raise TypeError( "only dimensionless quantities can be " @@ -3492,7 +3513,7 @@ def unsqueeze(self, axis: int) -> 'Quantity': See :func:`brainstate.math.unsqueeze` """ - return Quantity(jnp.expand_dims(self._mantissa, axis), unit=self.unit) + return Quantity(jnp.expand_dims(self.mantissa, axis), unit=self.unit) def expand_dims(self, axis: Union[int, Sequence[int]]) -> 'Quantity': """ @@ -3508,7 +3529,7 @@ def expand_dims(self, axis: Union[int, Sequence[int]]) -> 'Quantity': expanded : Quantity A view with the new axis inserted. """ - return Quantity(jnp.expand_dims(self._mantissa, axis), unit=self.unit) + return Quantity(jnp.expand_dims(self.mantissa, axis), unit=self.unit) def expand_as(self, array: Union['Quantity', jax.typing.ArrayLike]) -> 'Quantity': """ @@ -3527,8 +3548,8 @@ def expand_as(self, array: Union['Quantity', jax.typing.ArrayLike]) -> 'Quantity """ if isinstance(array, Quantity): fail_for_dimension_mismatch(self, array, "expand_as (Quantity)") - array = array._mantissa - return Quantity(jnp.broadcast_to(self._mantissa, array), unit=self.unit) + array = array.mantissa + return Quantity(jnp.broadcast_to(self.mantissa, array), unit=self.unit) def pow(self, oc) -> 'Quantity': return self.__pow__(oc) @@ -3543,7 +3564,7 @@ def tree_flatten(self) -> Tuple[Tuple[jax.typing.ArrayLike], Unit]: Returns: The data and the dimension. """ - return (self._mantissa,), self.unit + return (self.mantissa,), self.unit @classmethod def tree_unflatten(cls, unit, values) -> 'Quantity': @@ -3561,24 +3582,24 @@ def tree_unflatten(cls, unit, values) -> 'Quantity': def cuda(self, deice=None) -> 'Quantity': deice = jax.devices('cuda')[0] if deice is None else deice - self.update_value(jax.device_put(self._mantissa, deice)) + self.update_value(jax.device_put(self.mantissa, deice)) return self def cpu(self, device=None) -> 'Quantity': device = jax.devices('cpu')[0] if device is None else device - self.update_value(jax.device_put(self._mantissa, device)) + self.update_value(jax.device_put(self.mantissa, device)) return self # dtype exchanging # # ---------------- # def half(self) -> 'Quantity': - return Quantity(jnp.asarray(self._mantissa, dtype=jnp.float16), unit=self.unit) + return Quantity(jnp.asarray(self.mantissa, dtype=jnp.float16), unit=self.unit) def float(self) -> 'Quantity': - return Quantity(jnp.asarray(self._mantissa, dtype=jnp.float32), unit=self.unit) + return Quantity(jnp.asarray(self.mantissa, dtype=jnp.float32), unit=self.unit) def double(self) -> 'Quantity': - return Quantity(jnp.asarray(self._mantissa, dtype=jnp.float64), unit=self.unit) + return Quantity(jnp.asarray(self.mantissa, dtype=jnp.float64), unit=self.unit) class _IndexUpdateHelper: diff --git a/brainunit/_base_test.py b/brainunit/_base_test.py index 70b166c..4887aa4 100644 --- a/brainunit/_base_test.py +++ b/brainunit/_base_test.py @@ -1201,7 +1201,7 @@ def a_function(v, x): a_function(5 * second, None) with pytest.raises(DimensionMismatchError): a_function(5, None) - with pytest.raises(TypeError): + with pytest.raises(AttributeError): a_function(object(), None) with pytest.raises(TypeError): a_function([1, 2 * volt, 3], None) @@ -1269,7 +1269,7 @@ def a_function(v, x): a_function(5 * second, None) with pytest.raises(bu.UnitMismatchError): a_function(5, None) - with pytest.raises(TypeError): + with pytest.raises(AttributeError): a_function(object(), None) with pytest.raises(TypeError): a_function([1, 2 * volt, 3], None)