diff --git a/brainunit/math/_compat_numpy.py b/brainunit/math/_compat_numpy.py index 5291ef9..16aa807 100644 --- a/brainunit/math/_compat_numpy.py +++ b/brainunit/math/_compat_numpy.py @@ -1168,7 +1168,7 @@ def intersect1d( ar2: Union[jax.Array, np.ndarray], assume_unique: bool = False, return_indices: bool = False -) -> Union[jax.Array, Quantity, tuple[jax.Array | Quantity, jax.Array, jax.Array]]: +) -> Union[jax.Array, Quantity, tuple[Union[jax.Array, Quantity], jax.Array, jax.Array]]: fail_for_dimension_mismatch(ar1, ar2, 'intersect1d') unit = None if isinstance(ar1, Quantity):