Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Jun 14, 2024
1 parent 8358bf8 commit 36d284d
Show file tree
Hide file tree
Showing 9 changed files with 73 additions and 212 deletions.
4 changes: 0 additions & 4 deletions brainunit/math/_compat_numpy_array_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ def eye(
k: int = 0,
dtype: Optional[Any] = None,
unit: Optional[Unit] = None,
order: str = 'C',
) -> Union[Array, Quantity]:
"""
Returns a 2-D quantity or array of `shape` and `unit` with ones on the diagonal and zeros elsewhere.
Expand All @@ -95,9 +94,6 @@ def eye(
Data-type of the returned array.
unit : Unit, optional
Unit of the returned Quantity.
order : {'C', 'F'}, optional
Whether the output should be stored in row-major (C-style) or
column-major (Fortran-style) order in memory.
Returns
-------
Expand Down
64 changes: 28 additions & 36 deletions brainunit/math/_compat_numpy_array_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import jax
import jax.numpy as jnp
from jax import Array
from jax.tree_util import tree_map

from brainunit._misc import set_module_as
from .._base import Quantity
Expand All @@ -35,9 +36,9 @@
'diagflat', 'diagonal', 'choose', 'ravel',
]


# array manipulation
# ------------------
from jax.tree_util import tree_map


def _as_jax_array_(obj):
Expand All @@ -48,7 +49,7 @@ def _is_leaf(a):
return isinstance(a, Quantity)


def func_array_manipulation(fun, *args, return_quantity=True, **kwargs) -> Union[list[Quantity], Quantity, jax.Array]:
def func_array_manipulation(fun, *args, return_quantity=True, **kwargs) -> Any:
unit = None
if isinstance(args[0], Quantity):
unit = args[0].dim
Expand Down Expand Up @@ -228,7 +229,7 @@ def swapaxes(
def concatenate(
arrays: Union[Sequence[Array], Sequence[Quantity]],
axis: Optional[int] = None,
dtype: Optional[Any] = None
dtype: Optional[jax.typing.DTypeLike] = None
) -> Union[Array, Quantity]:
"""
Join a sequence of quantities or arrays along an existing axis.
Expand Down Expand Up @@ -258,7 +259,7 @@ def stack(
arrays: Union[Sequence[Array], Sequence[Quantity]],
axis: int = 0,
out: Optional[Union[Quantity, jax.typing.ArrayLike]] = None,
dtype: Optional[Any] = None
dtype: Optional[jax.typing.DTypeLike] = None
) -> Union[Array, Quantity]:
"""
Join a sequence of quantities or arrays along a new axis.
Expand Down Expand Up @@ -288,7 +289,7 @@ def stack(
@set_module_as('brainunit.math')
def vstack(
tup: Union[Sequence[Array], Sequence[Quantity]],
dtype: Optional[Any] = None
dtype: Optional[jax.typing.DTypeLike] = None
) -> Union[Array, Quantity]:
"""
Stack quantities or arrays in sequence vertically (row wise).
Expand All @@ -315,7 +316,7 @@ def vstack(
@set_module_as('brainunit.math')
def hstack(
arrays: Union[Sequence[Array], Sequence[Quantity]],
dtype: Optional[Any] = None
dtype: Optional[jax.typing.DTypeLike] = None
) -> Union[Array, Quantity]:
"""
Stack quantities arrays in sequence horizontally (column wise).
Expand All @@ -339,7 +340,7 @@ def hstack(
@set_module_as('brainunit.math')
def dstack(
arrays: Union[Sequence[Array], Sequence[Quantity]],
dtype: Optional[Any] = None
dtype: Optional[jax.typing.DTypeLike] = None
) -> Union[Array, Quantity]:
"""
Stack quantities or arrays in sequence depth wise (along third axis).
Expand Down Expand Up @@ -587,7 +588,8 @@ def unique(
res : ndarray, Quantity
The sorted unique values.
"""
return func_array_manipulation(jnp.unique, a,
return func_array_manipulation(jnp.unique,
a,
return_index=return_index,
return_inverse=return_inverse,
return_counts=return_counts,
Expand Down Expand Up @@ -897,15 +899,19 @@ def argsort(
res : ndarray
Array of indices that sort the array.
"""
return func_array_manipulation(jnp.argsort, a, axis=axis, kind=kind, order=order, stable=stable,
return func_array_manipulation(jnp.argsort,
a,
axis=axis,
kind=kind,
order=order,
stable=stable,
descending=descending)


@set_module_as('brainunit.math')
def max(
a: Union[Array, Quantity],
axis: Optional[int] = None,
out: Optional[Union[Quantity, jax.typing.ArrayLike]] = None,
keepdims: bool = False,
initial: Optional[Union[int, float]] = None,
where: Optional[Array] = None,
Expand All @@ -919,9 +925,6 @@ def max(
Array or quantity containing numbers whose maximum is desired.
axis : int or None, optional
Axis or axes along which to operate. By default, flattened input is used.
out : ndarray, Quantity, or None, optional
A location into which the result is stored. If provided, it must have a shape that the inputs broadcast to. If not
provided or None, a freshly-allocated array is returned.
keepdims : bool, optional
If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this
option, the result will broadcast correctly against the input array.
Expand All @@ -938,14 +941,13 @@ def max(
Maximum of `a`. If `axis` is None, the result is a scalar value. If `axis` is given, the result is an array of
dimension `a.ndim - 1`.
"""
return func_array_manipulation(jnp.max, a, axis=axis, out=out, keepdims=keepdims, initial=initial, where=where)
return func_array_manipulation(jnp.max, a, axis=axis, keepdims=keepdims, initial=initial, where=where)


@set_module_as('brainunit.math')
def min(
a: Union[Array, Quantity],
axis: Optional[int] = None,
out: Optional[Union[Quantity, jax.typing.ArrayLike]] = None,
keepdims: bool = False,
initial: Optional[Union[int, float]] = None,
where: Optional[Array] = None,
Expand All @@ -959,9 +961,6 @@ def min(
Array or quantity containing numbers whose minimum is desired.
axis : int or None, optional
Axis or axes along which to operate. By default, flattened input is used.
out : ndarray, Quantity, or None, optional
A location into which the result is stored. If provided, it must have a shape that the inputs broadcast to. If not
provided or None, a freshly-allocated array is returned.
keepdims : bool, optional
If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this
option, the result will broadcast correctly against the input array.
Expand All @@ -978,14 +977,13 @@ def min(
Minimum of `a`. If `axis` is None, the result is a scalar value. If `axis` is given, the result is an array of
dimension `a.ndim - 1`.
"""
return func_array_manipulation(jnp.min, a, axis=axis, out=out, keepdims=keepdims, initial=initial, where=where)
return func_array_manipulation(jnp.min, a, axis=axis, keepdims=keepdims, initial=initial, where=where)


@set_module_as('brainunit.math')
def choose(
a: Union[Array, Quantity],
choices: Sequence[Union[Array, Quantity]],
out: Optional[Union[Quantity, jax.typing.ArrayLike]] = None,
mode: str = 'raise',
) -> Union[Array, Quantity]:
"""
Expand All @@ -998,8 +996,6 @@ def choose(
from `choices`.
choices : sequence of array_like, Quantity
Choice arrays. `a` and all `choices` must be broadcastable to the same shape.
out : ndarray, Quantity, or None, optional
If provided, the result will be inserted into this array. It should be of the appropriate shape and dtype.
mode : {'raise', 'wrap', 'clip'}, optional
Specifies how indices outside [0, n-1] will be treated:
- 'raise' : raise an error (default)
Expand All @@ -1011,7 +1007,7 @@ def choose(
res : ndarray, Quantity
The constructed array. The shape is identical to the shape of `a`, and the data type is the data type of `choices`.
"""
return func_array_manipulation(jnp.choose, a, choices, out=out, mode=mode)
return func_array_manipulation(jnp.choose, a, choices, mode=mode)


@set_module_as('brainunit.math')
Expand Down Expand Up @@ -1043,7 +1039,6 @@ def compress(
*,
size: Optional[int] = None,
fill_value: Optional[jax.typing.ArrayLike] = 0,
out: Optional[Union[Quantity, jax.typing.ArrayLike]] = None,
) -> Union[Array, Quantity]:
"""
Return selected slices of a quantity or an array along given axis.
Expand All @@ -1063,15 +1058,13 @@ def compress(
fill_value : scalar, optional
The value to use for elements in the output array that are not selected. If None, the output array has the same
type as `a` and is filled with zeros.
out : ndarray, Quantity, or None, optional
If provided, the result will be inserted into this array. It should be of the appropriate shape and dtype.
Returns
-------
res : ndarray, Quantity
A new array that has the same number of dimensions as `a`, and the same shape as `a` with axis `axis` removed.
"""
return func_array_manipulation(jnp.compress, condition, a, axis, size=size, fill_value=fill_value, out=out)
return func_array_manipulation(jnp.compress, condition, a, axis, size=size, fill_value=fill_value)


@set_module_as('brainunit.math')
Expand Down Expand Up @@ -1103,7 +1096,6 @@ def diagflat(
def argmax(
a: Union[Array, Quantity],
axis: Optional[int] = None,
out: Optional[Union[Quantity, jax.typing.ArrayLike]] = None,
keepdims: Optional[bool] = None
) -> Array:
"""
Expand All @@ -1115,8 +1107,6 @@ def argmax(
Input data.
axis : int, optional
By default, the index is into the flattened array, otherwise along the specified axis.
out : array_like, Quantity, or None, optional
If provided, the result will be inserted into this array. It should be of the appropriate shape and dtype.
keepdims : bool, optional
If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this
option, the result will broadcast correctly against the input array.
Expand All @@ -1126,14 +1116,13 @@ def argmax(
res : ndarray
Array of indices into the array. It has the same shape as `a.shape` with the dimension along `axis` removed.
"""
return func_array_manipulation(jnp.argmax, a, axis=axis, out=out, keepdim=keepdims, return_quantity=False)
return func_array_manipulation(jnp.argmax, a, axis=axis, keepdim=keepdims, return_quantity=False)


@set_module_as('brainunit.math')
def argmin(
a: Union[Array, Quantity],
axis: Optional[int] = None,
out: Optional[Union[Quantity, jax.typing.ArrayLike]] = None,
keepdims: Optional[bool] = None
) -> Array:
"""
Expand All @@ -1145,8 +1134,6 @@ def argmin(
Input data.
axis : int, optional
By default, the index is into the flattened array, otherwise along the specified axis.
out : array_like, Quantity, or None, optional
If provided, the result will be inserted into this array. It should be of the appropriate shape and dtype.
keepdims : bool, optional
If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this
option, the result will broadcast correctly against the input array.
Expand All @@ -1156,7 +1143,7 @@ def argmin(
res : ndarray
Array of indices into the array. It has the same shape as `a.shape` with the dimension along `axis` removed.
"""
return func_array_manipulation(jnp.argmin, a, axis=axis, out=out, keepdims=keepdims, return_quantity=False)
return func_array_manipulation(jnp.argmin, a, axis=axis, keepdims=keepdims, return_quantity=False)


@set_module_as('brainunit.math')
Expand Down Expand Up @@ -1293,7 +1280,7 @@ def extract(
arr: Union[Array, Quantity],
*,
size: Optional[int] = None,
fill_value: Optional[jax.typing.ArrayLike] = 1,
fill_value: Optional[jax.typing.ArrayLike] = 0,
) -> Array:
"""
Return the elements of an array that satisfy some condition.
Expand All @@ -1304,6 +1291,11 @@ def extract(
An array of boolean values that selects which elements to extract.
arr : array_like, Quantity
The array from which to extract elements.
size: int
optional static size for output. Must be specified in order for ``extract``
to be compatible with JAX transformations like :func:`~jax.jit` or :func:`~jax.vmap`.
fill_value: array_like
if ``size`` is specified, fill padded entries with this value (default: 0).
Returns
-------
Expand Down
5 changes: 2 additions & 3 deletions brainunit/math/_compat_numpy_funcs_bit_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from .._base import Quantity

__all__ = [

# Elementwise bit operations (unary)
'bitwise_not', 'invert',

Expand All @@ -37,7 +36,7 @@

def elementwise_bit_operation_unary(func, x, *args, **kwargs):
if isinstance(x, Quantity):
raise ValueError(f'Expected integers, got {x}')
raise ValueError(f'Expected arrays, got {x}')
elif isinstance(x, (jax.Array, np.ndarray)):
return func(x, *args, **kwargs)
else:
Expand Down Expand Up @@ -86,7 +85,7 @@ def invert(x: Union[Quantity, jax.typing.ArrayLike]) -> Array:

def elementwise_bit_operation_binary(func, x, y, *args, **kwargs):
if isinstance(x, Quantity) or isinstance(y, Quantity):
raise ValueError(f'Expected integers, got {x} and {y}')
raise ValueError(f'Expected array, got {x} and {y}')
elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray, int, float)):
return func(x, y, *args, **kwargs)
else:
Expand Down
Loading

0 comments on commit 36d284d

Please sign in to comment.