From 337b3658ecb921c77b870953f8b232857cb3637a Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Tue, 11 Jun 2024 17:25:58 +0800 Subject: [PATCH] Update --- brainunit/math/__init__.py | 18 +++++++++-- .../math/_compat_numpy_array_creation.py | 30 +++++++++---------- .../math/_compat_numpy_array_manipulation.py | 2 +- .../math/_compat_numpy_funcs_change_unit.py | 28 ++++++++--------- .../math/_compat_numpy_funcs_indexing.py | 6 ++-- .../math/_compat_numpy_funcs_keep_unit.py | 12 ++++---- .../math/_compat_numpy_funcs_match_unit.py | 6 ++-- brainunit/math/_compat_numpy_misc.py | 26 ++++++++-------- 8 files changed, 70 insertions(+), 58 deletions(-) diff --git a/brainunit/math/__init__.py b/brainunit/math/__init__.py index 03fd080..e574603 100644 --- a/brainunit/math/__init__.py +++ b/brainunit/math/__init__.py @@ -46,7 +46,6 @@ from ._compat_numpy_misc import * from ._compat_numpy_misc import __all__ as _compat_misc_all - __all__ = _compat_array_creation_all + \ _compat_array_manipulation_all + \ _compat_funcs_change_unit_all + \ @@ -63,5 +62,18 @@ _compat_misc_all + _other_all + \ _other_all -del _compat_array_creation_all, _compat_array_manipulation_all, _compat_funcs_change_unit_all, _compat_funcs_keep_unit_all, _compat_funcs_accept_unitless_all, _compat_funcs_match_unit_all, _compat_funcs_remove_unit_all, _compat_get_attribute_all, _compat_funcs_bit_operation_all, _compat_funcs_logic_all, _compat_funcs_indexing_all, _compat_funcs_window_all, _compat_linear_algebra_all, _compat_misc_all, _other_all - +del _compat_array_creation_all, \ + _compat_array_manipulation_all, \ + _compat_funcs_change_unit_all, \ + _compat_funcs_keep_unit_all, \ + _compat_funcs_accept_unitless_all, \ + _compat_funcs_match_unit_all, \ + _compat_funcs_remove_unit_all, \ + _compat_get_attribute_all, \ + _compat_funcs_bit_operation_all, \ + _compat_funcs_logic_all, \ + _compat_funcs_indexing_all, \ + _compat_funcs_window_all, \ + _compat_linear_algebra_all, \ + _compat_misc_all, \ + _other_all diff --git a/brainunit/math/_compat_numpy_array_creation.py b/brainunit/math/_compat_numpy_array_creation.py index 9080502..4feb08d 100644 --- a/brainunit/math/_compat_numpy_array_creation.py +++ b/brainunit/math/_compat_numpy_array_creation.py @@ -436,18 +436,18 @@ def asarray( from builtins import all as origin_all from builtins import any as origin_any if isinstance(a, Quantity): - return Quantity(jnp.asarray(a.value, dtype=dtype, order=order), unit=a.unit) + return Quantity(jnp.asarray(a.value, dtype=dtype, order=order), dim=a.dim) elif isinstance(a, (jax.Array, np.ndarray)): return jnp.asarray(a, dtype=dtype, order=order) # list[Quantity] elif isinstance(a, Sequence) and origin_all(isinstance(x, Quantity) for x in a): # check all elements have the same unit - if origin_any(x.unit != a[0].unit for x in a): + if origin_any(x.dim != a[0].dim for x in a): raise ValueError('Units do not match for asarray operation.') values = [x.value for x in a] - unit = a[0].unit + unit = a[0].dim # Convert the values to a jnp.ndarray and create a Quantity object - return Quantity(jnp.asarray(values, dtype=dtype, order=order), unit=unit) + return Quantity(jnp.asarray(values, dtype=dtype, order=order), dim=unit) else: return jnp.asarray(a, dtype=dtype, order=order) @@ -501,7 +501,7 @@ def arange(*args, **kwargs): if stop is None: raise TypeError("Missing stop argument.") if stop is not None and not is_unitless(stop): - start = Quantity(start, unit=stop.unit) + start = Quantity(start, dim=stop.dim) fail_for_dimension_mismatch( start, @@ -533,7 +533,7 @@ def arange(*args, **kwargs): step=step.value if isinstance(step, Quantity) else jnp.asarray(step), **kwargs, ), - unit=unit, + dim=unit, ) else: return Quantity( @@ -543,7 +543,7 @@ def arange(*args, **kwargs): step=step.value if isinstance(step, Quantity) else jnp.asarray(step), **kwargs, ), - unit=unit, + dim=unit, ) @@ -575,12 +575,12 @@ def linspace(start: Union[Quantity, bst.typing.ArrayLike], start=start, stop=stop, ) - unit = getattr(start, "unit", DIMENSIONLESS) + unit = getattr(start, "dim", DIMENSIONLESS) start = start.value if isinstance(start, Quantity) else start stop = stop.value if isinstance(stop, Quantity) else stop result = jnp.linspace(start, stop, num=num, endpoint=endpoint, retstep=retstep, dtype=dtype) - return Quantity(result, unit=unit) + return Quantity(result, dim=unit) @set_module_as('brainunit.math') @@ -611,12 +611,12 @@ def logspace(start: Union[Quantity, bst.typing.ArrayLike], start=start, stop=stop, ) - unit = getattr(start, "unit", DIMENSIONLESS) + unit = getattr(start, "dim", DIMENSIONLESS) start = start.value if isinstance(start, Quantity) else start stop = stop.value if isinstance(stop, Quantity) else stop result = jnp.logspace(start, stop, num=num, endpoint=endpoint, base=base, dtype=dtype) - return Quantity(result, unit=unit) + return Quantity(result, dim=unit) @set_module_as('brainunit.math') @@ -638,7 +638,7 @@ def fill_diagonal(a: Union[Quantity, bst.typing.ArrayLike], ''' if isinstance(a, Quantity) and isinstance(val, Quantity): fail_for_dimension_mismatch(a, val) - return Quantity(jnp.fill_diagonal(a.value, val.value, wrap=wrap, inplace=inplace), unit=a.unit) + return Quantity(jnp.fill_diagonal(a.value, val.value, wrap=wrap, inplace=inplace), dim=a.dim) elif isinstance(a, (jax.Array, np.ndarray)) and isinstance(val, (jax.Array, np.ndarray)): return jnp.fill_diagonal(a, val, wrap=wrap, inplace=inplace) elif is_unitless(a) or is_unitless(val): @@ -663,7 +663,7 @@ def array_split(ary: Union[Quantity, bst.typing.ArrayLike], Union[jax.Array, Quantity]: Quantity if `ary` is a Quantity, else an array. ''' if isinstance(ary, Quantity): - return [Quantity(x, unit=ary.unit) for x in jnp.array_split(ary.value, indices_or_sections, axis)] + return [Quantity(x, dim=ary.dim) for x in jnp.array_split(ary.value, indices_or_sections, axis)] elif isinstance(ary, bst.typing.ArrayLike): return jnp.array_split(ary, indices_or_sections, axis) else: @@ -690,7 +690,7 @@ def meshgrid(*xi: Union[Quantity, bst.typing.ArrayLike], from builtins import all as origin_all if origin_all(isinstance(x, Quantity) for x in xi): fail_for_dimension_mismatch(*xi) - return Quantity(jnp.meshgrid(*[x.value for x in xi], copy=copy, sparse=sparse, indexing=indexing), unit=xi[0].unit) + return Quantity(jnp.meshgrid(*[x.value for x in xi], copy=copy, sparse=sparse, indexing=indexing), dim=xi[0].dim) elif origin_all(isinstance(x, (jax.Array, np.ndarray)) for x in xi): return jnp.meshgrid(*xi, copy=copy, sparse=sparse, indexing=indexing) else: @@ -713,7 +713,7 @@ def vander(x: Union[Quantity, bst.typing.ArrayLike], Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' if isinstance(x, Quantity): - return Quantity(jnp.vander(x.value, N=N, increasing=increasing), unit=x.unit) + return Quantity(jnp.vander(x.value, N=N, increasing=increasing), dim=x.dim) elif isinstance(x, (jax.Array, np.ndarray)): return jnp.vander(x, N=N, increasing=increasing) else: diff --git a/brainunit/math/_compat_numpy_array_manipulation.py b/brainunit/math/_compat_numpy_array_manipulation.py index dbfca5e..c4a7c26 100644 --- a/brainunit/math/_compat_numpy_array_manipulation.py +++ b/brainunit/math/_compat_numpy_array_manipulation.py @@ -777,7 +777,7 @@ def wrap_function_to_method(func): @wraps(func) def f(x, *args, **kwargs): if isinstance(x, Quantity): - return Quantity(func(x.value, *args, **kwargs), unit=x.unit) + return Quantity(func(x.value, *args, **kwargs), dim=x.dim) else: return func(x, *args, **kwargs) diff --git a/brainunit/math/_compat_numpy_funcs_change_unit.py b/brainunit/math/_compat_numpy_funcs_change_unit.py index ced279e..e649b14 100644 --- a/brainunit/math/_compat_numpy_funcs_change_unit.py +++ b/brainunit/math/_compat_numpy_funcs_change_unit.py @@ -49,7 +49,7 @@ def decorator(func: Callable) -> Callable: @wraps(func) def f(x, *args, **kwargs): if isinstance(x, Quantity): - return _return_check_unitless(Quantity(func(x.value, *args, **kwargs), unit=change_unit_func(x.unit))) + return _return_check_unitless(Quantity(func(x.value, *args, **kwargs), dtype=change_unit_func(x.dim))) elif isinstance(x, (jnp.ndarray, np.ndarray)): return func(x, *args, **kwargs) else: @@ -298,16 +298,16 @@ def decorator(func: Callable) -> Callable: def f(x, y, *args, **kwargs): if isinstance(x, Quantity) and isinstance(y, Quantity): return _return_check_unitless( - Quantity(func(x.value, y.value, *args, **kwargs), unit=change_unit_func(x.unit, y.unit)) + Quantity(func(x.value, y.value, *args, **kwargs), dim=change_unit_func(x.dim, y.dim)) ) elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): return func(x, y, *args, **kwargs) elif isinstance(x, Quantity): return _return_check_unitless( - Quantity(func(x.value, y, *args, **kwargs), unit=change_unit_func(x.unit, DIMENSIONLESS))) + Quantity(func(x.value, y, *args, **kwargs), dim=change_unit_func(x.dim, DIMENSIONLESS))) elif isinstance(y, Quantity): return _return_check_unitless( - Quantity(func(x, y.value, *args, **kwargs), unit=change_unit_func(DIMENSIONLESS, y.unit))) + Quantity(func(x, y.value, *args, **kwargs), dim=change_unit_func(DIMENSIONLESS, y.dim))) else: raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') @@ -443,13 +443,13 @@ def power(x: Union[Quantity, bst.typing.ArrayLike], Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. ''' if isinstance(x, Quantity) and isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.power(x.value, y.value), unit=x.unit ** y.unit)) + return _return_check_unitless(Quantity(jnp.power(x.value, y.value), dim=x.dim ** y.dim)) elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): return jnp.power(x, y) elif isinstance(x, Quantity): - return _return_check_unitless(Quantity(jnp.power(x.value, y), unit=x.unit ** y)) + return _return_check_unitless(Quantity(jnp.power(x.value, y), dim=x.dim ** y)) elif isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.power(x, y.value), unit=x ** y.unit)) + return _return_check_unitless(Quantity(jnp.power(x, y.value), dim=x ** y.dim)) else: raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.power.__name__}') @@ -468,13 +468,13 @@ def floor_divide(x: Union[Quantity, bst.typing.ArrayLike], Union[jax.Array, Quantity]: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. ''' if isinstance(x, Quantity) and isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.floor_divide(x.value, y.value), unit=x.unit / y.unit)) + return _return_check_unitless(Quantity(jnp.floor_divide(x.value, y.value), dim=x.dim / y.dim)) elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): return jnp.floor_divide(x, y) elif isinstance(x, Quantity): - return _return_check_unitless(Quantity(jnp.floor_divide(x.value, y), unit=x.unit / y)) + return _return_check_unitless(Quantity(jnp.floor_divide(x.value, y), dim=x.dim / y)) elif isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.floor_divide(x, y.value), unit=x / y.unit)) + return _return_check_unitless(Quantity(jnp.floor_divide(x, y.value), dim=x / y.dim)) else: raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.floor_divide.__name__}') @@ -495,7 +495,7 @@ def float_power(x: Union[Quantity, bst.typing.ArrayLike], if isinstance(y, Quantity): assert isscalar(y), f'{jnp.float_power.__name__} only supports scalar exponent' if isinstance(x, Quantity): - return _return_check_unitless(Quantity(jnp.float_power(x.value, y), unit=x.unit ** y)) + return _return_check_unitless(Quantity(jnp.float_power(x.value, y), dim=x.dim ** y)) elif isinstance(x, (jax.Array, np.ndarray)): return jnp.float_power(x, y) else: @@ -516,12 +516,12 @@ def remainder(x: Union[Quantity, bst.typing.ArrayLike], Union[jax.Array, Quantity]: Quantity if the final unit is the remainder of the unit of `x` and the unit of `y`, else an array. ''' if isinstance(x, Quantity) and isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.remainder(x.value, y.value), unit=x.unit / y.unit)) + return _return_check_unitless(Quantity(jnp.remainder(x.value, y.value), dim=x.dim / y.dim)) elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): return jnp.remainder(x, y) elif isinstance(x, Quantity): - return _return_check_unitless(Quantity(jnp.remainder(x.value, y), unit=x.unit % y)) + return _return_check_unitless(Quantity(jnp.remainder(x.value, y), dim=x.dim % y)) elif isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.remainder(x, y.value), unit=x % y.unit)) + return _return_check_unitless(Quantity(jnp.remainder(x, y.value), dim=x % y.dim)) else: raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.remainder.__name__}') diff --git a/brainunit/math/_compat_numpy_funcs_indexing.py b/brainunit/math/_compat_numpy_funcs_indexing.py index bf21d75..7f8d8fc 100644 --- a/brainunit/math/_compat_numpy_funcs_indexing.py +++ b/brainunit/math/_compat_numpy_funcs_indexing.py @@ -61,7 +61,7 @@ def where(condition: Union[bool, bst.typing.ArrayLike], # as both arguments have the same unit, just use the first one's dimensionless_args = [jnp.asarray(arg.value) if isinstance(arg, Quantity) else jnp.asarray(arg) for arg in args] return Quantity.with_units( - jnp.where(condition, *dimensionless_args), args[0].unit + jnp.where(condition, *dimensionless_args), args[0].dim ) else: # illegal number of arguments @@ -155,11 +155,11 @@ def select(condlist: list[Union[bst.typing.ArrayLike]], from builtins import all as origin_all from builtins import any as origin_any if origin_all(isinstance(choice, Quantity) for choice in choicelist): - if origin_any(choice.unit != choicelist[0].unit for choice in choicelist): + if origin_any(choice.dim != choicelist[0].dim for choice in choicelist): raise ValueError("All choices must have the same unit") else: return Quantity(jnp.select(condlist, [choice.value for choice in choicelist], default=default), - unit=choicelist[0].unit) + dim=choicelist[0].dim) elif origin_all(isinstance(choice, (jax.Array, np.ndarray)) for choice in choicelist): return jnp.select(condlist, choicelist, default=default) else: diff --git a/brainunit/math/_compat_numpy_funcs_keep_unit.py b/brainunit/math/_compat_numpy_funcs_keep_unit.py index b11f4c4..4a6616e 100644 --- a/brainunit/math/_compat_numpy_funcs_keep_unit.py +++ b/brainunit/math/_compat_numpy_funcs_keep_unit.py @@ -49,7 +49,7 @@ def wrap_math_funcs_keep_unit_unary(func): @wraps(func) def f(x, *args, **kwargs): if isinstance(x, Quantity): - return Quantity(func(x.value, *args, **kwargs), unit=x.unit) + return Quantity(func(x.value, *args, **kwargs), dim=x.dim) elif isinstance(x, (jax.Array, np.ndarray)): return func(x, *args, **kwargs) else: @@ -578,7 +578,7 @@ def wrap_math_funcs_keep_unit_binary(func): @wraps(func) def f(x1, x2, *args, **kwargs): if isinstance(x1, Quantity) and isinstance(x2, Quantity): - return Quantity(func(x1.value, x2.value, *args, **kwargs), unit=x1.unit) + return Quantity(func(x1.value, x2.value, *args, **kwargs), dim=x1.dim) elif isinstance(x1, (jax.Array, np.ndarray)) and isinstance(x2, (jax.Array, np.ndarray)): return func(x1, x2, *args, **kwargs) else: @@ -775,7 +775,7 @@ def interp(x: Union[Quantity, bst.typing.ArrayLike], ''' unit = None if isinstance(x, Quantity) or isinstance(xp, Quantity) or isinstance(fp, Quantity): - unit = x.unit if isinstance(x, Quantity) else xp.unit if isinstance(xp, Quantity) else fp.unit + unit = x.dim if isinstance(x, Quantity) else xp.dim if isinstance(xp, Quantity) else fp.dim if isinstance(x, Quantity): x_value = x.value else: @@ -790,7 +790,7 @@ def interp(x: Union[Quantity, bst.typing.ArrayLike], fp_value = fp result = jnp.interp(x_value, xp_value, fp_value, left=left, right=right, period=period) if unit is not None: - return Quantity(result, unit=unit) + return Quantity(result, dim=unit) else: return result @@ -812,7 +812,7 @@ def clip(a: Union[Quantity, bst.typing.ArrayLike], ''' unit = None if isinstance(a, Quantity) or isinstance(a_min, Quantity) or isinstance(a_max, Quantity): - unit = a.unit if isinstance(a, Quantity) else a_min.unit if isinstance(a_min, Quantity) else a_max.unit + unit = a.dim if isinstance(a, Quantity) else a_min.dim if isinstance(a_min, Quantity) else a_max.dim if isinstance(a, Quantity): a_value = a.value else: @@ -827,6 +827,6 @@ def clip(a: Union[Quantity, bst.typing.ArrayLike], a_max_value = a_max result = jnp.clip(a_value, a_min_value, a_max_value) if unit is not None: - return Quantity(result, unit=unit) + return Quantity(result, dim=unit) else: return result diff --git a/brainunit/math/_compat_numpy_funcs_match_unit.py b/brainunit/math/_compat_numpy_funcs_match_unit.py index b863d87..d9926ad 100644 --- a/brainunit/math/_compat_numpy_funcs_match_unit.py +++ b/brainunit/math/_compat_numpy_funcs_match_unit.py @@ -38,17 +38,17 @@ def wrap_math_funcs_match_unit_binary(func): def f(x, y, *args, **kwargs): if isinstance(x, Quantity) and isinstance(y, Quantity): fail_for_dimension_mismatch(x, y) - return Quantity(func(x.value, y.value, *args, **kwargs), unit=x.unit) + return Quantity(func(x.value, y.value, *args, **kwargs), dim=x.dim) elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): return func(x, y, *args, **kwargs) elif isinstance(x, Quantity): if x.is_unitless: - return Quantity(func(x.value, y, *args, **kwargs), unit=x.unit) + return Quantity(func(x.value, y, *args, **kwargs), dim=x.dim) else: raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') elif isinstance(y, Quantity): if y.is_unitless: - return Quantity(func(x, y.value, *args, **kwargs), unit=y.unit) + return Quantity(func(x, y.value, *args, **kwargs), dim=y.dim) else: raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') else: diff --git a/brainunit/math/_compat_numpy_misc.py b/brainunit/math/_compat_numpy_misc.py index cebb5aa..4a26216 100644 --- a/brainunit/math/_compat_numpy_misc.py +++ b/brainunit/math/_compat_numpy_misc.py @@ -81,9 +81,9 @@ def broadcast_arrays(*args: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quan from builtins import all as origin_all from builtins import any as origin_any if origin_all(isinstance(arg, Quantity) for arg in args): - if origin_any(arg.unit != args[0].unit for arg in args): + if origin_any(arg.dim != args[0].dim for arg in args): raise ValueError("All arguments must have the same unit") - return Quantity(jnp.broadcast_arrays(*[arg.value for arg in args]), unit=args[0].unit) + return Quantity(jnp.broadcast_arrays(*[arg.value for arg in args]), dim=args[0].dim) elif origin_all(isinstance(arg, (jax.Array, np.ndarray)) for arg in args): return jnp.broadcast_arrays(*args) else: @@ -151,7 +151,7 @@ def einsum( if contractions[i][4] == 'False': fail_for_dimension_mismatch( - Quantity([], unit=unit), operands[i + 1], 'einsum' + Quantity([], dim=unit), operands[i + 1], 'einsum' ) elif contractions[i][4] == 'DOT' or \ contractions[i][4] == 'TDOT' or \ @@ -159,14 +159,14 @@ def einsum( contractions[i][4] == 'OUTER/EINSUM': if i == 0: if isinstance(operands[i], Quantity) and isinstance(operands[i + 1], Quantity): - unit = operands[i].unit * operands[i + 1].unit + unit = operands[i].dim * operands[i + 1].dim elif isinstance(operands[i], Quantity): - unit = operands[i].unit + unit = operands[i].dim elif isinstance(operands[i + 1], Quantity): - unit = operands[i + 1].unit + unit = operands[i + 1].dim else: if isinstance(operands[i + 1], Quantity): - unit = unit * operands[i + 1].unit + unit = unit * operands[i + 1].dim contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions) @@ -177,7 +177,7 @@ def einsum( r = einsum(operands, contractions, precision, # type: ignore[operator] preferred_element_type, _dot_general) if unit is not None: - return Quantity(r, unit=unit) + return Quantity(r, dim=unit) else: return r @@ -206,7 +206,7 @@ def gradient( if len(varargs) == 0: if isinstance(f, Quantity) and not is_unitless(f): - return Quantity(jnp.gradient(f.value, axis=axis), unit=f.unit) + return Quantity(jnp.gradient(f.value, axis=axis), dim=f.dim) else: return jnp.gradient(f) elif len(varargs) == 1: @@ -214,13 +214,13 @@ def gradient( if unit is None or unit == DIMENSIONLESS: return jnp.gradient(f, varargs[0], axis=axis) else: - return [Quantity(r, unit=unit) for r in jnp.gradient(f.value, varargs[0].value, axis=axis)] + return [Quantity(r, dim=unit) for r in jnp.gradient(f.value, varargs[0].value, axis=axis)] else: unit_list = [get_unit(f) / get_unit(v) for v in varargs] f = f.value if isinstance(f, Quantity) else f varargs = [v.value if isinstance(v, Quantity) else v for v in varargs] result_list = jnp.gradient(f, *varargs, axis=axis) - return [Quantity(r, unit=unit) if unit is not None else r for r, unit in zip(result_list, unit_list)] + return [Quantity(r, dim=unit) if unit is not None else r for r, unit in zip(result_list, unit_list)] @set_module_as('brainunit.math') @@ -251,12 +251,12 @@ def intersect1d( result = jnp.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices) if return_indices: if unit is not None: - return (Quantity(result[0], unit=unit), result[1], result[2]) + return (Quantity(result[0], dim=unit), result[1], result[2]) else: return result else: if unit is not None: - return Quantity(result, unit=unit) + return Quantity(result, dim=unit) else: return result