From d538812a2c69d4d49deaf58e4831874fa3ba25d7 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Sun, 16 Jun 2024 17:01:05 +0800 Subject: [PATCH] fix `Quantity.take()` function --- brainunit/_base.py | 41 ++++++++++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/brainunit/_base.py b/brainunit/_base.py index d2d6a97..24d0ffa 100644 --- a/brainunit/_base.py +++ b/brainunit/_base.py @@ -1988,10 +1988,19 @@ def take( ) -> 'Quantity': """Return an array formed from the elements of a at the given indices.""" if isinstance(fill_value, Quantity): + fail_for_dimension_mismatch(self, fill_value, "take") fill_value = fill_value.value - return Quantity(jnp.take(self.value, indices=indices, axis=axis, mode=mode, - unique_indices=unique_indices, indices_are_sorted=indices_are_sorted, - fill_value=fill_value), dim=self.dim) + elif fill_value is not None: + if not self.is_unitless: + raise TypeError(f"fill_value must be a Quantity when the unit {self.unit}. But got {fill_value}") + return Quantity( + jnp.take(self.value, + indices=indices, axis=axis, mode=mode, + unique_indices=unique_indices, + indices_are_sorted=indices_are_sorted, + fill_value=fill_value), + dim=self.dim + ) def tolist(self): """Return the array as an ``a.ndim``-levels deep nested list of Python scalars. @@ -2251,9 +2260,11 @@ def view(self, *args, dtype=None) -> 'Quantity': # NumPy support # ------------------ - def to_numpy(self, - dtype: Optional[jax.typing.DTypeLike] = None, - unit: Optional['Unit'] = None) -> np.ndarray: + def to_numpy( + self, + unit: Optional['Unit'] = None, + dtype: Optional[jax.typing.DTypeLike] = None, + ) -> np.ndarray: """ Remove the unit and convert to ``numpy.ndarray``. @@ -2265,14 +2276,19 @@ def to_numpy(self, The numpy.ndarray. """ if unit is None: + assert self.dim == DIMENSIONLESS, (f"only dimensionless quantities can be converted to " + f"NumPy arrays when 'unit' is not provided. But got {self}") return np.asarray(self.value, dtype=dtype) else: + fail_for_dimension_mismatch(self, unit, "to_numpy") assert isinstance(unit, Unit), f"unit must be a Unit object, but got {type(unit)}" return np.asarray(self / unit, dtype=dtype) - def to_jax(self, - dtype: Optional[jax.typing.DTypeLike] = None, - unit: Optional['Unit'] = None) -> jax.Array: + def to_jax( + self, + unit: Optional['Unit'] = None, + dtype: Optional[jax.typing.DTypeLike] = None, + ) -> jax.Array: """ Remove the unit and convert to ``jax.Array``. @@ -2284,8 +2300,11 @@ def to_jax(self, The jax.Array. """ if unit is None: + assert self.dim == DIMENSIONLESS, (f"only dimensionless quantities can be converted to " + f"JAX arrays when 'unit' is not provided. But got {self}") return jnp.asarray(self.value, dtype=dtype) else: + fail_for_dimension_mismatch(self, unit, "to_jax") assert isinstance(unit, Unit), f"unit must be a Unit object, but got {type(unit)}" return jnp.asarray(self / unit, dtype=dtype) @@ -2301,7 +2320,7 @@ def __array__(self, dtype: Optional[jax.typing.DTypeLike] = None) -> np.ndarray: def __float__(self): if self.dim == DIMENSIONLESS and self.ndim == 0: - return self.value.__float__() + return float(self.value) else: raise TypeError( "only dimensionless scalar quantities can be " @@ -2310,7 +2329,7 @@ def __float__(self): def __int__(self): if self.dim == DIMENSIONLESS and self.ndim == 0: - return self.value.__int__() + return int(self.value) else: raise TypeError( "only dimensionless scalar quantities can be "