From 06b1d5f396e9e4c53dcd707cb4f0f1b10ddf18ee Mon Sep 17 00:00:00 2001 From: Sichao He <1310722434@qq.com> Date: Wed, 12 Jun 2024 13:55:37 +0800 Subject: [PATCH] [math] Update the logic of array creation fuctions (#7) * Update _compat_numpy.py * Update _compat_numpy.py * Update * Update _compat_numpy.py * Fix * Update brainunit.math.rst * Update _compat_numpy.py * Update _unit_test.py * Restruct * Update * Fix bugs * Fix bugs in Python 3.9 * Update _compat_numpy_funcs_bit_operation.py * Update _compat_numpy_funcs_bit_operation.py * Fix logic of `asarray` * update * Update array creation funcs * Update _compat_numpy_test.py * Add magnitude conversion for `asarray` * Update _compat_numpy_array_creation.py * Update _compat_numpy_test.py * Fix bugs --------- Co-authored-by: Chaoming Wang --- .../math/_compat_numpy_array_creation.py | 421 ++++++++++-------- brainunit/math/_compat_numpy_test.py | 15 +- 2 files changed, 249 insertions(+), 187 deletions(-) diff --git a/brainunit/math/_compat_numpy_array_creation.py b/brainunit/math/_compat_numpy_array_creation.py index 156a553..1e31f4d 100644 --- a/brainunit/math/_compat_numpy_array_creation.py +++ b/brainunit/math/_compat_numpy_array_creation.py @@ -39,75 +39,14 @@ ] -def wrap_array_creation_function(func: Callable) -> Callable: - @wraps(func) - def f(*args, unit: Unit = None, **kwargs): - if unit is not None: - assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' - return func(*args, **kwargs) * unit - else: - return func(*args, **kwargs) - - f.__module__ = 'brainunit.math' - return f - - -@wrap_array_creation_function -def full(shape: Sequence[int], - fill_value: Any, - dtype: Optional[Any] = None, - unit: Optional[Unit] = None) -> Union[Array, Quantity]: - return jnp.full(shape, fill_value, dtype=dtype) - - -@wrap_array_creation_function -def eye(N: int, - M: Optional[int] = None, - k: int = 0, - dtype: Optional[Any] = None, - unit: Optional[Unit] = None) -> Union[Array, Quantity]: - return jnp.eye(N, M, k, dtype=dtype) - - -@wrap_array_creation_function -def identity(n: int, - dtype: Optional[Any] = None, - unit: Optional[Unit] = None) -> Union[Array, Quantity]: - return jnp.identity(n, dtype=dtype) - - -@wrap_array_creation_function -def tri(N: int, - M: Optional[int] = None, - k: int = 0, - dtype: Optional[Any] = None, - unit: Optional[Unit] = None) -> Union[Array, Quantity]: - return jnp.tri(N, M, k, dtype=dtype) - - -@wrap_array_creation_function -def empty(shape: Sequence[int], - dtype: Optional[Any] = None, - unit: Optional[Unit] = None) -> Union[Array, Quantity]: - return jnp.empty(shape, dtype=dtype) - - -@wrap_array_creation_function -def ones(shape: Sequence[int], - dtype: Optional[Any] = None, - unit: Optional[Unit] = None) -> Union[Array, Quantity]: - return jnp.ones(shape, dtype=dtype) - - -@wrap_array_creation_function -def zeros(shape: Sequence[int], - dtype: Optional[Any] = None, - unit: Optional[Unit] = None) -> Union[Array, Quantity]: - return jnp.zeros(shape, dtype=dtype) - - -full.__doc__ = ''' - Returns a Quantity of `shape` and `unit`, filled with `fill_value` if `unit` is provided. +@set_module_as('brainunit.math') +def full( + shape: Sequence[int], + fill_value: Union[Quantity, int, float], + dtype: Optional[Any] = None, +) -> Union[Array, Quantity]: + ''' + Returns a Quantity of `shape`, filled with `fill_value` if `fill_value` is a Quantity. else return an array of `shape` filled with `fill_value`. Args: @@ -115,16 +54,24 @@ def zeros(shape: Sequence[int], fill_value: the value to fill the new array with. dtype: the type of the output array, or `None`. If not `None`, `fill_value` will be cast to `dtype`. - sharding: an optional sharding specification for the resulting array, - note, sharding will currently be ignored in jitted mode, this might change - in the future. - unit: the unit of the output array, or `None`. Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. ''' + if isinstance(fill_value, Quantity): + return Quantity(jnp.full(shape, fill_value.value, dtype=dtype), dim=fill_value.dim) + return jnp.full(shape, fill_value, dtype=dtype) -eye.__doc__ = """ + +@set_module_as('brainunit.math') +def eye( + N: int, + M: Optional[int] = None, + k: int = 0, + dtype: Optional[Any] = None, + unit: Optional[Unit] = None +) -> Union[Array, Quantity]: + ''' Returns a Quantity of `shape` and `unit`, representing an identity matrix if `unit` is provided. else return an identity matrix of `shape`. @@ -135,16 +82,25 @@ def zeros(shape: Sequence[int], lower diagonal. dtype: the type of the output array, or `None`. If not `None`, elements will be cast to `dtype`. - sharding: an optional sharding specification for the resulting array, - note, sharding will currently be ignored in jitted mode, this might change - in the future. unit: the unit of the output array, or `None`. Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. -""" + ''' + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' + return jnp.eye(N, M, k, dtype=dtype) * unit + else: + return jnp.eye(N, M, k, dtype=dtype) + -identity.__doc__ = """ +@set_module_as('brainunit.math') +def identity( + n: int, + dtype: Optional[Any] = None, + unit: Optional[Unit] = None +) -> Union[Array, Quantity]: + ''' Returns a Quantity of `shape` and `unit`, representing an identity matrix if `unit` is provided. else return an identity matrix of `shape`. @@ -152,16 +108,27 @@ def zeros(shape: Sequence[int], n: the number of rows (and columns) in the output array. dtype: the type of the output array, or `None`. If not `None`, elements will be cast to `dtype`. - sharding: an optional sharding specification for the resulting array, - note, sharding will currently be ignored in jitted mode, this might change - in the future. unit: the unit of the output array, or `None`. Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. -""" + ''' + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' + return jnp.identity(n, dtype=dtype) * unit + else: + return jnp.identity(n, dtype=dtype) + -tri.__doc__ = """ +@set_module_as('brainunit.math') +def tri( + N: int, + M: Optional[int] = None, + k: int = 0, + dtype: Optional[Any] = None, + unit: Optional[Unit] = None +) -> Union[Array, Quantity]: + ''' Returns a Quantity of `shape` and `unit`, representing a triangular matrix if `unit` is provided. else return a triangular matrix of `shape`. @@ -173,17 +140,25 @@ def zeros(shape: Sequence[int], lower diagonal. dtype: the type of the output array, or `None`. If not `None`, elements will be cast to `dtype`. - sharding: an optional sharding specification for the resulting array, - note, sharding will currently be ignored in jitted mode, this might change - in the future. unit: the unit of the output array, or `None`. Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. -""" + ''' + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' + return jnp.tri(N, M, k, dtype=dtype) * unit + else: + return jnp.tri(N, M, k, dtype=dtype) -# empty -empty.__doc__ = """ + +@set_module_as('brainunit.math') +def empty( + shape: Sequence[int], + dtype: Optional[Any] = None, + unit: Optional[Unit] = None +) -> Union[Array, Quantity]: + ''' Returns a Quantity of `shape` and `unit`, with uninitialized values if `unit` is provided. else return an array of `shape` with uninitialized values. @@ -191,17 +166,25 @@ def zeros(shape: Sequence[int], shape: sequence of integers, describing the shape of the output array. dtype: the type of the output array, or `None`. If not `None`, elements will be of type `dtype`. - sharding: an optional sharding specification for the resulting array, - note, sharding will currently be ignored in jitted mode, this might change - in the future. unit: the unit of the output array, or `None`. Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. -""" + ''' + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' + return jnp.empty(shape, dtype=dtype) * unit + else: + return jnp.empty(shape, dtype=dtype) + -# ones -ones.__doc__ = """ +@set_module_as('brainunit.math') +def ones( + shape: Sequence[int], + dtype: Optional[Any] = None, + unit: Optional[Unit] = None +) -> Union[Array, Quantity]: + ''' Returns a Quantity of `shape` and `unit`, filled with 1 if `unit` is provided. else return an array of `shape` filled with 1. @@ -209,17 +192,25 @@ def zeros(shape: Sequence[int], shape: sequence of integers, describing the shape of the output array. dtype: the type of the output array, or `None`. If not `None`, elements will be cast to `dtype`. - sharding: an optional sharding specification for the resulting array, - note, sharding will currently be ignored in jitted mode, this might change - in the future. unit: the unit of the output array, or `None`. Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. -""" + ''' + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' + return jnp.ones(shape, dtype=dtype) * unit + else: + return jnp.ones(shape, dtype=dtype) + -# zeros -zeros.__doc__ = """ +@set_module_as('brainunit.math') +def zeros( + shape: Sequence[int], + dtype: Optional[Any] = None, + unit: Optional[Unit] = None +) -> Union[Array, Quantity]: + ''' Returns a Quantity of `shape` and `unit`, filled with 0 if `unit` is provided. else return an array of `shape` filled with 0. @@ -227,50 +218,53 @@ def zeros(shape: Sequence[int], shape: sequence of integers, describing the shape of the output array. dtype: the type of the output array, or `None`. If not `None`, elements will be cast to `dtype`. - sharding: an optional sharding specification for the resulting array, - note, sharding will currently be ignored in jitted mode, this might change - in the future. unit: the unit of the output array, or `None`. Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. -""" + ''' + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' + return jnp.zeros(shape, dtype=dtype) * unit + else: + return jnp.zeros(shape, dtype=dtype) @set_module_as('brainunit.math') def full_like(a: Union[Quantity, bst.typing.ArrayLike], - fill_value: Union[bst.typing.ArrayLike], - unit: Unit = None, + fill_value: Union[Quantity, bst.typing.ArrayLike], dtype: Optional[bst.typing.DTypeLike] = None, shape: Any = None) -> Union[Quantity, jax.Array]: ''' - Return a Quantity of `a` and `unit`, filled with `fill_value` if `unit` is provided. + Return a Quantity if `a` and `fill_value` are Quantities that have the same unit or only `fill_value` is a Quantity. else return an array of `a` filled with `fill_value`. Args: a: array_like, Quantity, shape, or dtype fill_value: scalar or array_like - unit: Unit, optional dtype: data-type, optional shape: sequence of ints, optional Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. ''' - if unit is not None: - assert isinstance(unit, Unit) + if isinstance(fill_value, Quantity): if isinstance(a, Quantity): - return jnp.full_like(a.value, fill_value, dtype=dtype, shape=shape) * unit + fail_for_dimension_mismatch(a, fill_value, error_message="a and fill_value have to have the same units.") + return Quantity(jnp.full_like(a.value, fill_value.value, dtype=dtype, shape=shape), dim=a.dim) else: - return jnp.full_like(a, fill_value, dtype=dtype, shape=shape) * unit + return Quantity(jnp.full_like(a, fill_value.value, dtype=dtype, shape=shape), dim=fill_value.dim) else: - return jnp.full_like(a, fill_value, dtype=dtype, shape=shape) + if isinstance(a, Quantity): + return jnp.full_like(a.value, fill_value, dtype=dtype, shape=shape) + else: + return jnp.full_like(a, fill_value, dtype=dtype, shape=shape) @set_module_as('brainunit.math') def diag(a: Union[Quantity, bst.typing.ArrayLike], k: int = 0, - unit: Unit = None) -> Union[Quantity, jax.Array]: + unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]: ''' Extract a diagonal or construct a diagonal array. @@ -282,12 +276,17 @@ def diag(a: Union[Quantity, bst.typing.ArrayLike], Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. ''' - if unit is not None: - assert isinstance(unit, Unit) - if isinstance(a, Quantity): - return jnp.diag(a.value, k=k) * unit - else: + if isinstance(a, Quantity): + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' + fail_for_dimension_mismatch(a, unit, error_message="a and unit have to have the same units.") + return Quantity(jnp.diag(a.value, k=k), dim=a.dim) + elif isinstance(a, (jax.Array, np.ndarray)): + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' return jnp.diag(a, k=k) * unit + else: + return jnp.diag(a, k=k) else: return jnp.diag(a, k=k) @@ -295,7 +294,7 @@ def diag(a: Union[Quantity, bst.typing.ArrayLike], @set_module_as('brainunit.math') def tril(a: Union[Quantity, bst.typing.ArrayLike], k: int = 0, - unit: Unit = None) -> Union[Quantity, jax.Array]: + unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]: ''' Lower triangle of an array. @@ -307,12 +306,17 @@ def tril(a: Union[Quantity, bst.typing.ArrayLike], Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. ''' - if unit is not None: - assert isinstance(unit, Unit) - if isinstance(a, Quantity): - return jnp.tril(a.value, k=k) * unit - else: + if isinstance(a, Quantity): + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' + fail_for_dimension_mismatch(a, unit, error_message="a and unit have to have the same units.") + return Quantity(jnp.tril(a.value, k=k), dim=a.dim) + elif isinstance(a, (jax.Array, np.ndarray)): + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' return jnp.tril(a, k=k) * unit + else: + return jnp.tril(a, k=k) else: return jnp.tril(a, k=k) @@ -320,7 +324,7 @@ def tril(a: Union[Quantity, bst.typing.ArrayLike], @set_module_as('brainunit.math') def triu(a: Union[Quantity, bst.typing.ArrayLike], k: int = 0, - unit: Unit = None) -> Union[Quantity, jax.Array]: + unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]: ''' Upper triangle of an array. @@ -332,12 +336,17 @@ def triu(a: Union[Quantity, bst.typing.ArrayLike], Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. ''' - if unit is not None: - assert isinstance(unit, Unit) - if isinstance(a, Quantity): - return jnp.triu(a.value, k=k) * unit - else: + if isinstance(a, Quantity): + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' + fail_for_dimension_mismatch(a, unit, error_message="a and unit have to have the same units.") + return Quantity(jnp.triu(a.value, k=k), dim=a.dim) + elif isinstance(a, (jax.Array, np.ndarray)): + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' return jnp.triu(a, k=k) * unit + else: + return jnp.triu(a, k=k) else: return jnp.triu(a, k=k) @@ -346,7 +355,7 @@ def triu(a: Union[Quantity, bst.typing.ArrayLike], def empty_like(a: Union[Quantity, bst.typing.ArrayLike], dtype: Optional[bst.typing.DTypeLike] = None, shape: Any = None, - unit: Unit = None) -> Union[Quantity, jax.Array]: + unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]: ''' Return a Quantity of `a` and `unit`, with uninitialized values if `unit` is provided. else return an array of `a` with uninitialized values. @@ -360,12 +369,17 @@ def empty_like(a: Union[Quantity, bst.typing.ArrayLike], Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. ''' - if unit is not None: - assert isinstance(unit, Unit) - if isinstance(a, Quantity): - return jnp.empty_like(a.value, dtype=dtype, shape=shape) * unit - else: + if isinstance(a, Quantity): + if unit is not None: + assert isinstance(unit, Unit) + fail_for_dimension_mismatch(a, unit, error_message="a and unit have to have the same units.") + return Quantity(jnp.empty_like(a.value, dtype=dtype, shape=shape), dim=a.dim) + elif isinstance(a, (jax.Array, np.ndarray)): + if unit is not None: + assert isinstance(unit, Unit) return jnp.empty_like(a, dtype=dtype, shape=shape) * unit + else: + return jnp.empty_like(a, dtype=dtype, shape=shape) else: return jnp.empty_like(a, dtype=dtype, shape=shape) @@ -374,7 +388,7 @@ def empty_like(a: Union[Quantity, bst.typing.ArrayLike], def ones_like(a: Union[Quantity, bst.typing.ArrayLike], dtype: Optional[bst.typing.DTypeLike] = None, shape: Any = None, - unit: Unit = None) -> Union[Quantity, jax.Array]: + unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]: ''' Return a Quantity of `a` and `unit`, filled with 1 if `unit` is provided. else return an array of `a` filled with 1. @@ -388,12 +402,17 @@ def ones_like(a: Union[Quantity, bst.typing.ArrayLike], Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. ''' - if unit is not None: - assert isinstance(unit, Unit) - if isinstance(a, Quantity): - return jnp.ones_like(a.value, dtype=dtype, shape=shape) * unit - else: + if isinstance(a, Quantity): + if unit is not None: + assert isinstance(unit, Unit) + fail_for_dimension_mismatch(a, unit, error_message="a and unit have to have the same units.") + return Quantity(jnp.ones_like(a.value, dtype=dtype, shape=shape), dim=a.dim) + elif isinstance(a, (jax.Array, np.ndarray)): + if unit is not None: + assert isinstance(unit, Unit) return jnp.ones_like(a, dtype=dtype, shape=shape) * unit + else: + return jnp.ones_like(a, dtype=dtype, shape=shape) else: return jnp.ones_like(a, dtype=dtype, shape=shape) @@ -402,7 +421,7 @@ def ones_like(a: Union[Quantity, bst.typing.ArrayLike], def zeros_like(a: Union[Quantity, bst.typing.ArrayLike], dtype: Optional[bst.typing.DTypeLike] = None, shape: Any = None, - unit: Unit = None) -> Union[Quantity, jax.Array]: + unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]: ''' Return a Quantity of `a` and `unit`, filled with 0 if `unit` is provided. else return an array of `a` filled with 0. @@ -416,12 +435,17 @@ def zeros_like(a: Union[Quantity, bst.typing.ArrayLike], Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. ''' - if unit is not None: - assert isinstance(unit, Unit) - if isinstance(a, Quantity): - return jnp.zeros_like(a.value, dtype=dtype, shape=shape) * unit - else: + if isinstance(a, Quantity): + if unit is not None: + assert isinstance(unit, Unit) + fail_for_dimension_mismatch(a, unit, error_message="a and unit have to have the same units.") + return Quantity(jnp.zeros_like(a.value, dtype=dtype, shape=shape), dim=a.dim) + elif isinstance(a, (jax.Array, np.ndarray)): + if unit is not None: + assert isinstance(unit, Unit) return jnp.zeros_like(a, dtype=dtype, shape=shape) * unit + else: + return jnp.zeros_like(a, dtype=dtype, shape=shape) else: return jnp.zeros_like(a, dtype=dtype, shape=shape) @@ -433,23 +457,56 @@ def asarray( order: Optional[str] = None, unit: Optional[Unit] = None, ) -> Union[Quantity, jax.Array]: - from builtins import all as origin_all - from builtins import any as origin_any + ''' + Convert the input to a quantity or array. + + If unit is provided, the input will be checked whether it has the same unit as the provided unit. + (If they have same dimension but different magnitude, the input will be converted to the provided unit.) + If unit is not provided, the input will be converted to an array. + + Args: + a: array_like, Quantity, or Sequence[Quantity] + dtype: data-type, optional + order: {'C', 'F', 'A', 'K'}, optional + unit: Unit, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. + ''' if isinstance(a, Quantity): + if unit is not None: + assert isinstance(unit, Unit) + fail_for_dimension_mismatch(a, unit, error_message="a and unit have to have the same units.") return Quantity(jnp.asarray(a.value, dtype=dtype, order=order), dim=a.dim) elif isinstance(a, (jax.Array, np.ndarray)): - return jnp.asarray(a, dtype=dtype, order=order) - # list[Quantity] - elif isinstance(a, Sequence) and origin_all(isinstance(x, Quantity) for x in a): - # check all elements have the same unit - if origin_any(x.dim != a[0].dim for x in a): - raise ValueError('Units do not match for asarray operation.') - values = [x.value for x in a] - unit = a[0].dim - # Convert the values to a jnp.ndarray and create a Quantity object - return Quantity(jnp.asarray(values, dtype=dtype, order=order), dim=unit) + if unit is not None: + assert isinstance(unit, Unit) + return jnp.asarray(a, dtype=dtype, order=order) * unit + else: + return jnp.asarray(a, dtype=dtype, order=order) + # list[Quantity] + elif isinstance(a, Sequence): + leaves, tree = jax.tree.flatten(a, is_leaf=lambda x: isinstance(x, Quantity)) + if all([isinstance(leaf, Quantity) for leaf in leaves]): + # check all elements have the same unit + if any(x.dim != leaves[0].dim for x in leaves): + raise ValueError('Units do not match for asarray operation.') + values = jax.tree.unflatten(tree, [x.value for x in a]) + + fail_for_dimension_mismatch(a[0], unit, error_message="a and unit have to have the same units.") + unit = a[0].dim + # Convert the values to a jnp.ndarray and create a Quantity object + return Quantity(jnp.asarray(values, dtype=dtype, order=order), dim=unit) + else: + values = jax.tree.unflatten(tree, leaves) + val = jnp.asarray(values, dtype=dtype, order=order) + if unit is not None: + assert isinstance(unit, Unit) + return val * unit + else: + return val else: - return jnp.asarray(a, dtype=dtype, order=order) + raise TypeError('Invalid input type for asarray.') array = asarray @@ -548,12 +605,14 @@ def arange(*args, **kwargs): @set_module_as('brainunit.math') -def linspace(start: Union[Quantity, bst.typing.ArrayLike], - stop: Union[Quantity, bst.typing.ArrayLike], - num: int = 50, - endpoint: Optional[bool] = True, - retstep: Optional[bool] = False, - dtype: Optional[bst.typing.DTypeLike] = None) -> Union[Quantity, jax.Array]: +def linspace( + start: Union[Quantity, bst.typing.ArrayLike], + stop: Union[Quantity, bst.typing.ArrayLike], + num: int = 50, + endpoint: Optional[bool] = True, + retstep: Optional[bool] = False, + dtype: Optional[bst.typing.DTypeLike] = None +) -> Union[Quantity, jax.Array]: ''' Return a Quantity of `linspace` and `unit`, with uninitialized values if `unit` is provided. @@ -623,7 +682,7 @@ def logspace(start: Union[Quantity, bst.typing.ArrayLike], def fill_diagonal(a: Union[Quantity, bst.typing.ArrayLike], val: Union[Quantity, bst.typing.ArrayLike], wrap: Optional[bool] = False, - inplace: Optional[bool] = True) -> Union[Quantity, jax.Array]: + inplace: Optional[bool] = False) -> Union[Quantity, jax.Array]: ''' Fill the main diagonal of the given array of `a` with `val`. @@ -631,20 +690,22 @@ def fill_diagonal(a: Union[Quantity, bst.typing.ArrayLike], a: array_like, Quantity val: scalar, Quantity wrap: bool, optional - inplace: bool, optional + unit: Unit, optional Returns: Union[jax.Array, Quantity]: Quantity if `a` and `val` are Quantities that have the same unit, else an array. ''' - if isinstance(a, Quantity) and isinstance(val, Quantity): - fail_for_dimension_mismatch(a, val) - return Quantity(jnp.fill_diagonal(a.value, val.value, wrap=wrap, inplace=inplace), dim=a.dim) - elif isinstance(a, (jax.Array, np.ndarray)) and isinstance(val, (jax.Array, np.ndarray)): - return jnp.fill_diagonal(a, val, wrap=wrap, inplace=inplace) - elif is_unitless(a) or is_unitless(val): - return jnp.fill_diagonal(a, val, wrap=wrap, inplace=inplace) + if isinstance(val, Quantity): + if isinstance(a, Quantity): + fail_for_dimension_mismatch(a, val, error_message="Array and value have to have the same units.") + return Quantity(jnp.fill_diagonal(a.value, val.value, wrap, inplace=inplace), dim=a.dim) + else: + return Quantity(jnp.fill_diagonal(a, val.value, wrap, inplace=inplace), dim=val.dim) else: - raise ValueError(f'Unsupported types : {type(a)} abd {type(val)} for fill_diagonal') + if isinstance(a, Quantity): + return jnp.fill_diagonal(a.value, val, wrap, inplace=inplace) + else: + return jnp.fill_diagonal(a, val, wrap, inplace=inplace) @set_module_as('brainunit.math') diff --git a/brainunit/math/_compat_numpy_test.py b/brainunit/math/_compat_numpy_test.py index 8e39796..615d720 100644 --- a/brainunit/math/_compat_numpy_test.py +++ b/brainunit/math/_compat_numpy_test.py @@ -32,7 +32,7 @@ def assert_quantity(q, values, unit): values = jnp.asarray(values) if isinstance(q, Quantity): - assert q.dim == unit.dim, f"Unit mismatch: {q.dim} != {unit}" + assert q.dim == unit.dim or q.dim == unit, f"Unit mismatch: {q.dim} != {unit}" assert jnp.allclose(q.value, values), f"Values do not match: {q.value} != {values}" else: assert jnp.allclose(q, values), f"Values do not match: {q} != {values}" @@ -45,7 +45,7 @@ def test_full(self): self.assertEqual(result.shape, (3,)) self.assertTrue(jnp.all(result == 4)) - q = bu.math.full(3, 4, unit=second) + q = bu.math.full(3, 4 * second) self.assertEqual(q.shape, (3,)) assert_quantity(q, result, second) @@ -92,7 +92,7 @@ def test_full_like(self): self.assertTrue(jnp.all(result == 4)) q = [1, 2, 3] * bu.second - result_q = bu.math.full_like(q, 4, unit=bu.second) + result_q = bu.math.full_like(q, 4 * bu.second) assert_quantity(result_q, jnp.full_like(jnp.array([1, 2, 3]), 4), bu.second) def test_diag(self): @@ -159,9 +159,10 @@ def test_asarray(self): self.assertEqual(result.shape, (3,)) self.assertTrue(jnp.all(result == jnp.asarray([1, 2, 3]))) - result_q = bu.math.asarray([1 * bu.second, 2 * bu.second, 3 * bu.second]) + result_q = bu.math.asarray([1, 2, 3], unit=bu.second) assert_quantity(result_q, jnp.asarray([1, 2, 3]), bu.second) + def test_arange(self): result = bu.math.arange(5) self.assertEqual(result.shape, (5,)) @@ -171,7 +172,7 @@ def test_arange(self): assert_quantity(result_q, jnp.arange(5, step=1), bu.second) result_q = bu.math.arange(3 * bu.second, 9 * bu.second, 1 * bu.second) - assert_quantity(result_q, jnp.arange(3, 9, 1), bu.second) + assert_quantity(result_q, jnp.arange(3, 9, 1), bu.ms) def test_linspace(self): result = bu.math.linspace(0, 10, 5) @@ -191,11 +192,11 @@ def test_logspace(self): def test_fill_diagonal(self): array = jnp.zeros((3, 3)) - result = bu.math.fill_diagonal(array, 5, inplace=False) + result = bu.math.fill_diagonal(array, 5) self.assertTrue(jnp.all(result == jnp.array([[5, 0, 0], [0, 5, 0], [0, 0, 5]]))) q = jnp.zeros((3, 3)) * bu.second - result_q = bu.math.fill_diagonal(q, 5 * bu.second, inplace=False) + result_q = bu.math.fill_diagonal(q, 5 * bu.second) assert_quantity(result_q, jnp.array([[5, 0, 0], [0, 5, 0], [0, 0, 5]]), bu.second) def test_array_split(self):