From 1b0d380c6513e9c6f4a3e88abf3130ab3022ad94 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Tue, 11 Jun 2024 22:33:15 +0800 Subject: [PATCH] Update array creation funcs --- .../math/_compat_numpy_array_creation.py | 301 +++++++++++------- brainunit/math/_compat_numpy_test.py | 2 +- 2 files changed, 187 insertions(+), 116 deletions(-) diff --git a/brainunit/math/_compat_numpy_array_creation.py b/brainunit/math/_compat_numpy_array_creation.py index f2d7527..32219a8 100644 --- a/brainunit/math/_compat_numpy_array_creation.py +++ b/brainunit/math/_compat_numpy_array_creation.py @@ -39,76 +39,14 @@ ] -def wrap_array_creation_function(func: Callable) -> Callable: - @wraps(func) - def f(*args, unit: Unit = None, **kwargs): - if unit is not None: - assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' - return func(*args, **kwargs) * unit - else: - return func(*args, **kwargs) - - f.__module__ = 'brainunit.math' - return f - - -@wrap_array_creation_function +@set_module_as('brainunit.math') def full( shape: Sequence[int], fill_value: Any, dtype: Optional[Any] = None, unit: Optional[Unit] = None ) -> Union[Array, Quantity]: - return jnp.full(shape, fill_value, dtype=dtype) - - -@wrap_array_creation_function -def eye(N: int, - M: Optional[int] = None, - k: int = 0, - dtype: Optional[Any] = None, - unit: Optional[Unit] = None) -> Union[Array, Quantity]: - return jnp.eye(N, M, k, dtype=dtype) - - -@wrap_array_creation_function -def identity(n: int, - dtype: Optional[Any] = None, - unit: Optional[Unit] = None) -> Union[Array, Quantity]: - return jnp.identity(n, dtype=dtype) - - -@wrap_array_creation_function -def tri(N: int, - M: Optional[int] = None, - k: int = 0, - dtype: Optional[Any] = None, - unit: Optional[Unit] = None) -> Union[Array, Quantity]: - return jnp.tri(N, M, k, dtype=dtype) - - -@wrap_array_creation_function -def empty(shape: Sequence[int], - dtype: Optional[Any] = None, - unit: Optional[Unit] = None) -> Union[Array, Quantity]: - return jnp.empty(shape, dtype=dtype) - - -@wrap_array_creation_function -def ones(shape: Sequence[int], - dtype: Optional[Any] = None, - unit: Optional[Unit] = None) -> Union[Array, Quantity]: - return jnp.ones(shape, dtype=dtype) - - -@wrap_array_creation_function -def zeros(shape: Sequence[int], - dtype: Optional[Any] = None, - unit: Optional[Unit] = None) -> Union[Array, Quantity]: - return jnp.zeros(shape, dtype=dtype) - - -full.__doc__ = ''' + ''' Returns a Quantity of `shape` and `unit`, filled with `fill_value` if `unit` is provided. else return an array of `shape` filled with `fill_value`. @@ -125,8 +63,22 @@ def zeros(shape: Sequence[int], Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. ''' + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' + return jnp.full(shape, fill_value, dtype=dtype) * unit + else: + return jnp.full(shape, fill_value, dtype=dtype) -eye.__doc__ = """ + +@set_module_as('brainunit.math') +def eye( + N: int, + M: Optional[int] = None, + k: int = 0, + dtype: Optional[Any] = None, + unit: Optional[Unit] = None +) -> Union[Array, Quantity]: + ''' Returns a Quantity of `shape` and `unit`, representing an identity matrix if `unit` is provided. else return an identity matrix of `shape`. @@ -144,9 +96,21 @@ def zeros(shape: Sequence[int], Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. -""" + ''' + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' + return jnp.eye(N, M, k, dtype=dtype) * unit + else: + return jnp.eye(N, M, k, dtype=dtype) + -identity.__doc__ = """ +@set_module_as('brainunit.math') +def identity( + n: int, + dtype: Optional[Any] = None, + unit: Optional[Unit] = None +) -> Union[Array, Quantity]: + ''' Returns a Quantity of `shape` and `unit`, representing an identity matrix if `unit` is provided. else return an identity matrix of `shape`. @@ -161,9 +125,23 @@ def zeros(shape: Sequence[int], Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. -""" + ''' + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' + return jnp.identity(n, dtype=dtype) * unit + else: + return jnp.identity(n, dtype=dtype) + -tri.__doc__ = """ +@set_module_as('brainunit.math') +def tri( + N: int, + M: Optional[int] = None, + k: int = 0, + dtype: Optional[Any] = None, + unit: Optional[Unit] = None +) -> Union[Array, Quantity]: + ''' Returns a Quantity of `shape` and `unit`, representing a triangular matrix if `unit` is provided. else return a triangular matrix of `shape`. @@ -182,10 +160,21 @@ def zeros(shape: Sequence[int], Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. -""" + ''' + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' + return jnp.tri(N, M, k, dtype=dtype) * unit + else: + return jnp.tri(N, M, k, dtype=dtype) -# empty -empty.__doc__ = """ + +@set_module_as('brainunit.math') +def empty( + shape: Sequence[int], + dtype: Optional[Any] = None, + unit: Optional[Unit] = None +) -> Union[Array, Quantity]: + ''' Returns a Quantity of `shape` and `unit`, with uninitialized values if `unit` is provided. else return an array of `shape` with uninitialized values. @@ -200,10 +189,21 @@ def zeros(shape: Sequence[int], Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. -""" + ''' + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' + return jnp.empty(shape, dtype=dtype) * unit + else: + return jnp.empty(shape, dtype=dtype) + -# ones -ones.__doc__ = """ +@set_module_as('brainunit.math') +def ones( + shape: Sequence[int], + dtype: Optional[Any] = None, + unit: Optional[Unit] = None +) -> Union[Array, Quantity]: + ''' Returns a Quantity of `shape` and `unit`, filled with 1 if `unit` is provided. else return an array of `shape` filled with 1. @@ -218,10 +218,21 @@ def zeros(shape: Sequence[int], Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. -""" + ''' + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' + return jnp.ones(shape, dtype=dtype) * unit + else: + return jnp.ones(shape, dtype=dtype) -# zeros -zeros.__doc__ = """ + +@set_module_as('brainunit.math') +def zeros( + shape: Sequence[int], + dtype: Optional[Any] = None, + unit: Optional[Unit] = None +) -> Union[Array, Quantity]: + ''' Returns a Quantity of `shape` and `unit`, filled with 0 if `unit` is provided. else return an array of `shape` filled with 0. @@ -236,35 +247,40 @@ def zeros(shape: Sequence[int], Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. -""" + ''' + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' + return jnp.zeros(shape, dtype=dtype) * unit + else: + return jnp.zeros(shape, dtype=dtype) @set_module_as('brainunit.math') def full_like(a: Union[Quantity, bst.typing.ArrayLike], - fill_value: Union[bst.typing.ArrayLike], - unit: Unit = None, + fill_value: Union[Quantity, bst.typing.ArrayLike], dtype: Optional[bst.typing.DTypeLike] = None, shape: Any = None) -> Union[Quantity, jax.Array]: ''' - Return a Quantity of `a` and `unit`, filled with `fill_value` if `unit` is provided. + Return a Quantity if `a` and `fill_value` are Quantities that have the same unit or only `fill_value` is a Quantity. else return an array of `a` filled with `fill_value`. Args: a: array_like, Quantity, shape, or dtype fill_value: scalar or array_like - unit: Unit, optional dtype: data-type, optional shape: sequence of ints, optional Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. ''' - if unit is not None: - assert isinstance(unit, Unit) - if isinstance(a, Quantity): - return jnp.full_like(a.value, fill_value, dtype=dtype, shape=shape) * unit + if isinstance(a, Quantity): + fail_for_dimension_mismatch(a, fill_value, error_message="a and fill_value have to have the same units.") + return Quantity(jnp.full_like(a.value, fill_value.value, dtype=dtype, shape=shape), dim=a.dim) + elif isinstance(a, (jax.Array, np.ndarray)): + if isinstance(fill_value, Quantity): + return Quantity(jnp.full_like(a, fill_value.value, dtype=dtype, shape=shape), dim=fill_value.dim) else: - return jnp.full_like(a, fill_value, dtype=dtype, shape=shape) * unit + return jnp.full_like(a, fill_value, dtype=dtype, shape=shape) else: return jnp.full_like(a, fill_value, dtype=dtype, shape=shape) @@ -284,12 +300,19 @@ def diag(a: Union[Quantity, bst.typing.ArrayLike], Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. ''' - if unit is not None: - assert isinstance(unit, Unit) - if isinstance(a, Quantity): - return jnp.diag(a.value, k=k) * unit + if isinstance(a, Quantity): + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' + fail_for_dimension_mismatch(a, unit, error_message="a and unit have to have the same units.") + return Quantity(jnp.diag(a.value, k=k), dim=a.dim) else: + return Quantity(jnp.diag(a.value, k=k), dim=a.dim) + elif isinstance(a, (jax.Array, np.ndarray)): + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' return jnp.diag(a, k=k) * unit + else: + return jnp.diag(a, k=k) else: return jnp.diag(a, k=k) @@ -309,12 +332,19 @@ def tril(a: Union[Quantity, bst.typing.ArrayLike], Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. ''' - if unit is not None: - assert isinstance(unit, Unit) - if isinstance(a, Quantity): - return jnp.tril(a.value, k=k) * unit + if isinstance(a, Quantity): + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' + fail_for_dimension_mismatch(a, unit, error_message="a and unit have to have the same units.") + return Quantity(jnp.tril(a.value, k=k), dim=a.dim) else: + return Quantity(jnp.tril(a.value, k=k), dim=a.dim) + elif isinstance(a, (jax.Array, np.ndarray)): + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' return jnp.tril(a, k=k) * unit + else: + return jnp.tril(a, k=k) else: return jnp.tril(a, k=k) @@ -334,12 +364,19 @@ def triu(a: Union[Quantity, bst.typing.ArrayLike], Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. ''' - if unit is not None: - assert isinstance(unit, Unit) - if isinstance(a, Quantity): - return jnp.triu(a.value, k=k) * unit + if isinstance(a, Quantity): + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' + fail_for_dimension_mismatch(a, unit, error_message="a and unit have to have the same units.") + return Quantity(jnp.triu(a.value, k=k), dim=a.dim) else: + return Quantity(jnp.triu(a.value, k=k), dim=a.dim) + elif isinstance(a, (jax.Array, np.ndarray)): + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' return jnp.triu(a, k=k) * unit + else: + return jnp.triu(a, k=k) else: return jnp.triu(a, k=k) @@ -362,16 +399,24 @@ def empty_like(a: Union[Quantity, bst.typing.ArrayLike], Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. ''' - if unit is not None: - assert isinstance(unit, Unit) - if isinstance(a, Quantity): - return jnp.empty_like(a.value, dtype=dtype, shape=shape) * unit + if isinstance(a, Quantity): + if unit is not None: + assert isinstance(unit, Unit) + fail_for_dimension_mismatch(a, unit, error_message="a and unit have to have the same units.") + return Quantity(jnp.empty_like(a.value, dtype=dtype, shape=shape), dim=a.dim) else: + return Quantity(jnp.empty_like(a.value, dtype=dtype, shape=shape), dim=a.dim) + elif isinstance(a, (jax.Array, np.ndarray)): + if unit is not None: + assert isinstance(unit, Unit) return jnp.empty_like(a, dtype=dtype, shape=shape) * unit + else: + return jnp.empty_like(a, dtype=dtype, shape=shape) else: return jnp.empty_like(a, dtype=dtype, shape=shape) + @set_module_as('brainunit.math') def ones_like(a: Union[Quantity, bst.typing.ArrayLike], dtype: Optional[bst.typing.DTypeLike] = None, @@ -390,12 +435,19 @@ def ones_like(a: Union[Quantity, bst.typing.ArrayLike], Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. ''' - if unit is not None: - assert isinstance(unit, Unit) - if isinstance(a, Quantity): - return jnp.ones_like(a.value, dtype=dtype, shape=shape) * unit + if isinstance(a, Quantity): + if unit is not None: + assert isinstance(unit, Unit) + fail_for_dimension_mismatch(a, unit, error_message="a and unit have to have the same units.") + return Quantity(jnp.ones_like(a.value, dtype=dtype, shape=shape), dim=a.dim) else: + return Quantity(jnp.ones_like(a.value, dtype=dtype, shape=shape), dim=a.dim) + elif isinstance(a, (jax.Array, np.ndarray)): + if unit is not None: + assert isinstance(unit, Unit) return jnp.ones_like(a, dtype=dtype, shape=shape) * unit + else: + return jnp.ones_like(a, dtype=dtype, shape=shape) else: return jnp.ones_like(a, dtype=dtype, shape=shape) @@ -418,12 +470,19 @@ def zeros_like(a: Union[Quantity, bst.typing.ArrayLike], Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. ''' - if unit is not None: - assert isinstance(unit, Unit) - if isinstance(a, Quantity): - return jnp.zeros_like(a.value, dtype=dtype, shape=shape) * unit + if isinstance(a, Quantity): + if unit is not None: + assert isinstance(unit, Unit) + fail_for_dimension_mismatch(a, unit, error_message="a and unit have to have the same units.") + return Quantity(jnp.zeros_like(a.value, dtype=dtype, shape=shape), dim=a.dim) else: + return Quantity(jnp.zeros_like(a.value, dtype=dtype, shape=shape), dim=a.dim) + elif isinstance(a, (jax.Array, np.ndarray)): + if unit is not None: + assert isinstance(unit, Unit) return jnp.zeros_like(a, dtype=dtype, shape=shape) * unit + else: + return jnp.zeros_like(a, dtype=dtype, shape=shape) else: return jnp.zeros_like(a, dtype=dtype, shape=shape) @@ -438,7 +497,8 @@ def asarray( ''' Convert the input to a quantity or array. - If unit is provided, the input is converted to a Quantity object with the given unit. + If unit is provided, the input will be checked whether it has the same unit as the provided unit. + If unit is not provided, the input will be converted to an array. Args: a: array_like, Quantity, or Sequence[Quantity] @@ -452,19 +512,30 @@ def asarray( if isinstance(a, Quantity): if unit is not None: assert isinstance(unit, Unit) - return jnp.asarray(a.value, dtype=dtype, order=order) * unit + fail_for_dimension_mismatch(a, unit, error_message="a and unit have to have the same units.") + return Quantity(jnp.asarray(a.value, dtype=dtype, order=order), dim=a.dim) else: - return jnp.asarray(a.value, dtype=dtype, order=order) + return Quantity(jnp.asarray(a.value, dtype=dtype, order=order), dim=a.dim) elif isinstance(a, (jax.Array, np.ndarray)): if unit is not None: assert isinstance(unit, Unit) return jnp.asarray(a, dtype=dtype, order=order) * unit else: return jnp.asarray(a, dtype=dtype, order=order) + # list[Quantity] + elif isinstance(a, Sequence) and all(isinstance(x, Quantity) for x in a): + # check all elements have the same unit + if any(x.dim != a[0].dim for x in a): + raise ValueError('Units do not match for asarray operation.') + values = [x.value for x in a] + unit = a[0].dim + # Convert the values to a jnp.ndarray and create a Quantity object + return Quantity(jnp.asarray(values, dtype=dtype, order=order), dim=unit) else: return jnp.asarray(a, dtype=dtype, order=order) + array = asarray diff --git a/brainunit/math/_compat_numpy_test.py b/brainunit/math/_compat_numpy_test.py index 18cdd4a..2e73403 100644 --- a/brainunit/math/_compat_numpy_test.py +++ b/brainunit/math/_compat_numpy_test.py @@ -92,7 +92,7 @@ def test_full_like(self): self.assertTrue(jnp.all(result == 4)) q = [1, 2, 3] * bu.second - result_q = bu.math.full_like(q, 4, unit=bu.second) + result_q = bu.math.full_like(q, 4 * bu.second) assert_quantity(result_q, jnp.full_like(jnp.array([1, 2, 3]), 4), bu.second) def test_diag(self):