Skip to content

Commit

Permalink
fix Quantity.take() function
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Jun 16, 2024
1 parent 7651621 commit d538812
Showing 1 changed file with 30 additions and 11 deletions.
41 changes: 30 additions & 11 deletions brainunit/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1988,10 +1988,19 @@ def take(
) -> 'Quantity':
"""Return an array formed from the elements of a at the given indices."""
if isinstance(fill_value, Quantity):
fail_for_dimension_mismatch(self, fill_value, "take")
fill_value = fill_value.value
return Quantity(jnp.take(self.value, indices=indices, axis=axis, mode=mode,
unique_indices=unique_indices, indices_are_sorted=indices_are_sorted,
fill_value=fill_value), dim=self.dim)
elif fill_value is not None:
if not self.is_unitless:
raise TypeError(f"fill_value must be a Quantity when the unit {self.unit}. But got {fill_value}")
return Quantity(
jnp.take(self.value,
indices=indices, axis=axis, mode=mode,
unique_indices=unique_indices,
indices_are_sorted=indices_are_sorted,
fill_value=fill_value),
dim=self.dim
)

def tolist(self):
"""Return the array as an ``a.ndim``-levels deep nested list of Python scalars.
Expand Down Expand Up @@ -2251,9 +2260,11 @@ def view(self, *args, dtype=None) -> 'Quantity':
# NumPy support
# ------------------

def to_numpy(self,
dtype: Optional[jax.typing.DTypeLike] = None,
unit: Optional['Unit'] = None) -> np.ndarray:
def to_numpy(
self,
unit: Optional['Unit'] = None,
dtype: Optional[jax.typing.DTypeLike] = None,
) -> np.ndarray:
"""
Remove the unit and convert to ``numpy.ndarray``.
Expand All @@ -2265,14 +2276,19 @@ def to_numpy(self,
The numpy.ndarray.
"""
if unit is None:
assert self.dim == DIMENSIONLESS, (f"only dimensionless quantities can be converted to "
f"NumPy arrays when 'unit' is not provided. But got {self}")
return np.asarray(self.value, dtype=dtype)
else:
fail_for_dimension_mismatch(self, unit, "to_numpy")
assert isinstance(unit, Unit), f"unit must be a Unit object, but got {type(unit)}"
return np.asarray(self / unit, dtype=dtype)

def to_jax(self,
dtype: Optional[jax.typing.DTypeLike] = None,
unit: Optional['Unit'] = None) -> jax.Array:
def to_jax(
self,
unit: Optional['Unit'] = None,
dtype: Optional[jax.typing.DTypeLike] = None,
) -> jax.Array:
"""
Remove the unit and convert to ``jax.Array``.
Expand All @@ -2284,8 +2300,11 @@ def to_jax(self,
The jax.Array.
"""
if unit is None:
assert self.dim == DIMENSIONLESS, (f"only dimensionless quantities can be converted to "
f"JAX arrays when 'unit' is not provided. But got {self}")
return jnp.asarray(self.value, dtype=dtype)
else:
fail_for_dimension_mismatch(self, unit, "to_jax")
assert isinstance(unit, Unit), f"unit must be a Unit object, but got {type(unit)}"
return jnp.asarray(self / unit, dtype=dtype)

Expand All @@ -2301,7 +2320,7 @@ def __array__(self, dtype: Optional[jax.typing.DTypeLike] = None) -> np.ndarray:

def __float__(self):
if self.dim == DIMENSIONLESS and self.ndim == 0:
return self.value.__float__()
return float(self.value)
else:
raise TypeError(
"only dimensionless scalar quantities can be "
Expand All @@ -2310,7 +2329,7 @@ def __float__(self):

def __int__(self):
if self.dim == DIMENSIONLESS and self.ndim == 0:
return self.value.__int__()
return int(self.value)
else:
raise TypeError(
"only dimensionless scalar quantities can be "
Expand Down

0 comments on commit d538812

Please sign in to comment.