From 9a27afefa4b5e97385333d41755da8d26cc1bee7 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Fri, 14 Jun 2024 09:33:24 +0800 Subject: [PATCH] Fix bugs in Python 3.9 --- .../math/_compat_numpy_funcs_keep_unit.py | 48 +++++++++---------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/brainunit/math/_compat_numpy_funcs_keep_unit.py b/brainunit/math/_compat_numpy_funcs_keep_unit.py index 19dbbb7..b648d9c 100644 --- a/brainunit/math/_compat_numpy_funcs_keep_unit.py +++ b/brainunit/math/_compat_numpy_funcs_keep_unit.py @@ -347,11 +347,11 @@ def fix( def sum( x: Union[Quantity, jax.typing.ArrayLike], axis: Union[int, Sequence[int], None] = None, - dtype: jax.typing.DTypeLike | None = None, + dtype: Union[jax.typing.DTypeLike, None] = None, out: None = None, keepdims: bool = False, - initial: jax.typing.ArrayLike | None = None, - where: jax.typing.ArrayLike | None = None, + initial: Union[jax.typing.ArrayLike, None] = None, + where: Union[jax.typing.ArrayLike, None] = None, promote_integers: bool = True ) -> Union[Quantity, jax.Array]: """ @@ -410,7 +410,7 @@ def sum( def nancumsum( x: Union[Quantity, jax.typing.ArrayLike], axis: Union[int, Sequence[int], None] = None, - dtype: jax.typing.DTypeLike | None = None, + dtype: Union[jax.typing.DTypeLike, None] = None, out: None = None, ) -> Union[Quantity, jax.Array]: """ @@ -447,11 +447,11 @@ def nancumsum( def nansum( x: Union[Quantity, jax.typing.ArrayLike], axis: Union[int, Sequence[int], None] = None, - dtype: jax.typing.DTypeLike | None = None, + dtype: Union[jax.typing.DTypeLike, None] = None, out: None = None, keepdims: bool = False, - initial: jax.typing.ArrayLike | None = None, - where: jax.typing.ArrayLike | None = None, + initial: Union[jax.typing.ArrayLike, None] = None, + where: Union[jax.typing.ArrayLike, None] = None, ) -> Union[Quantity, jax.Array]: """ Return the sum of the array elements, ignoring NaNs. @@ -503,7 +503,7 @@ def nansum( def cumsum( x: Union[Quantity, jax.typing.ArrayLike], axis: Union[int, Sequence[int], None] = None, - dtype: jax.typing.DTypeLike | None = None, + dtype: Union[jax.typing.DTypeLike, None] = None, out: None = None, ) -> Union[Quantity, jax.Array]: """ @@ -648,8 +648,8 @@ def nanmin( axis: Union[int, Sequence[int], None] = None, out: None = None, keepdims: bool = False, - initial: jax.typing.ArrayLike | None = None, - where: jax.typing.ArrayLike | None = None, + initial: Union[jax.typing.ArrayLike, None] = None, + where: Union[jax.typing.ArrayLike, None] = None, ) -> Union[Quantity, jax.Array]: """ Return the minimum of the array elements, ignoring NaNs. @@ -696,8 +696,8 @@ def nanmax( axis: Union[int, Sequence[int], None] = None, out: None = None, keepdims: bool = False, - initial: jax.typing.ArrayLike | None = None, - where: jax.typing.ArrayLike | None = None, + initial: Union[jax.typing.ArrayLike, None] = None, + where: Union[jax.typing.ArrayLike, None] = None, ) -> Union[Quantity, jax.Array]: """ Return the maximum of the array elements, ignoring NaNs. @@ -786,7 +786,7 @@ def ptp( def average( x: Union[Quantity, jax.typing.ArrayLike], axis: Union[int, Sequence[int], None] = None, - weights: jax.typing.ArrayLike | None = None, + weights: Union[jax.typing.ArrayLike, None] = None, returned: bool = False, keepdims: bool = False ) -> Union[Quantity, jax.Array]: @@ -840,10 +840,10 @@ def average( def mean( x: Union[Quantity, jax.typing.ArrayLike], axis: Union[int, Sequence[int], None] = None, - dtype: jax.typing.DTypeLike | None = None, + dtype: Union[jax.typing.DTypeLike, None] = None, out: None = None, keepdims: bool = False, *, - where: jax.typing.ArrayLike | None = None + where: Union[jax.typing.ArrayLike, None] = None ) -> Union[Quantity, jax.Array]: """ Return the mean of the array elements. @@ -892,11 +892,11 @@ def mean( def std( x: Union[Quantity, jax.typing.ArrayLike], axis: Union[int, Sequence[int], None] = None, - dtype: jax.typing.DTypeLike | None = None, + dtype: Union[jax.typing.DTypeLike, None] = None, out: None = None, ddof: int = 0, keepdims: bool = False, *, - where: jax.typing.ArrayLike | None = None + where: Union[jax.typing.ArrayLike, None] = None ) -> Union[Quantity, jax.Array]: """ Return the standard deviation of the array elements. @@ -948,7 +948,7 @@ def std( @set_module_as('brainunit.math') def nanmedian( x: Union[Quantity, jax.typing.ArrayLike], - axis: int | tuple[int, ...] | None = None, + axis: Union[int, tuple[int, ...], None] = None, out: None = None, overwrite_input: bool = False, keepdims: bool = False ) -> Union[Quantity, jax.Array]: @@ -998,10 +998,10 @@ def nanmedian( def nanmean( x: Union[Quantity, jax.typing.ArrayLike], axis: Union[int, Sequence[int], None] = None, - dtype: jax.typing.DTypeLike | None = None, + dtype: Union[jax.typing.DTypeLike, None] = None, out: None = None, keepdims: bool = False, *, - where: jax.typing.ArrayLike | None = None + where: Union[jax.typing.ArrayLike, None] = None ) -> Union[Quantity, jax.Array]: """ Return the mean of the array elements, ignoring NaNs. @@ -1050,11 +1050,11 @@ def nanmean( def nanstd( x: Union[Quantity, jax.typing.ArrayLike], axis: Union[int, Sequence[int], None] = None, - dtype: jax.typing.DTypeLike | None = None, + dtype: Union[jax.typing.DTypeLike, None] = None, out: None = None, ddof: int = 0, keepdims: bool = False, *, - where: jax.typing.ArrayLike | None = None + where: Union[jax.typing.ArrayLike, None] = None ) -> Union[Quantity, jax.Array]: """ Return the standard deviation of the array elements, ignoring NaNs. @@ -1108,8 +1108,8 @@ def nanstd( def diff( x: Union[Quantity, jax.typing.ArrayLike], n: int = 1, axis: int = -1, - prepend: jax.typing.ArrayLike | None = None, - append: jax.typing.ArrayLike | None = None + prepend: Union[jax.typing.ArrayLike, None] = None, + append: Union[jax.typing.ArrayLike, None] = None ) -> Union[Quantity, jax.Array]: """ Return the differences between consecutive elements of the array.