Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update
Browse files Browse the repository at this point in the history
Routhleck committed Jun 11, 2024
1 parent 229ad90 commit 337b365
Showing 8 changed files with 70 additions and 58 deletions.
18 changes: 15 additions & 3 deletions brainunit/math/__init__.py
Original file line number Diff line number Diff line change
@@ -46,7 +46,6 @@
from ._compat_numpy_misc import *
from ._compat_numpy_misc import __all__ as _compat_misc_all


__all__ = _compat_array_creation_all + \
_compat_array_manipulation_all + \
_compat_funcs_change_unit_all + \
@@ -63,5 +62,18 @@
_compat_misc_all + _other_all + \
_other_all

del _compat_array_creation_all, _compat_array_manipulation_all, _compat_funcs_change_unit_all, _compat_funcs_keep_unit_all, _compat_funcs_accept_unitless_all, _compat_funcs_match_unit_all, _compat_funcs_remove_unit_all, _compat_get_attribute_all, _compat_funcs_bit_operation_all, _compat_funcs_logic_all, _compat_funcs_indexing_all, _compat_funcs_window_all, _compat_linear_algebra_all, _compat_misc_all, _other_all

del _compat_array_creation_all, \
_compat_array_manipulation_all, \
_compat_funcs_change_unit_all, \
_compat_funcs_keep_unit_all, \
_compat_funcs_accept_unitless_all, \
_compat_funcs_match_unit_all, \
_compat_funcs_remove_unit_all, \
_compat_get_attribute_all, \
_compat_funcs_bit_operation_all, \
_compat_funcs_logic_all, \
_compat_funcs_indexing_all, \
_compat_funcs_window_all, \
_compat_linear_algebra_all, \
_compat_misc_all, \
_other_all
30 changes: 15 additions & 15 deletions brainunit/math/_compat_numpy_array_creation.py
Original file line number Diff line number Diff line change
@@ -436,18 +436,18 @@ def asarray(
from builtins import all as origin_all
from builtins import any as origin_any
if isinstance(a, Quantity):
return Quantity(jnp.asarray(a.value, dtype=dtype, order=order), unit=a.unit)
return Quantity(jnp.asarray(a.value, dtype=dtype, order=order), dim=a.dim)
elif isinstance(a, (jax.Array, np.ndarray)):
return jnp.asarray(a, dtype=dtype, order=order)
# list[Quantity]
elif isinstance(a, Sequence) and origin_all(isinstance(x, Quantity) for x in a):
# check all elements have the same unit
if origin_any(x.unit != a[0].unit for x in a):
if origin_any(x.dim != a[0].dim for x in a):
raise ValueError('Units do not match for asarray operation.')
values = [x.value for x in a]
unit = a[0].unit
unit = a[0].dim
# Convert the values to a jnp.ndarray and create a Quantity object
return Quantity(jnp.asarray(values, dtype=dtype, order=order), unit=unit)
return Quantity(jnp.asarray(values, dtype=dtype, order=order), dim=unit)
else:
return jnp.asarray(a, dtype=dtype, order=order)

@@ -501,7 +501,7 @@ def arange(*args, **kwargs):
if stop is None:
raise TypeError("Missing stop argument.")
if stop is not None and not is_unitless(stop):
start = Quantity(start, unit=stop.unit)
start = Quantity(start, dim=stop.dim)

fail_for_dimension_mismatch(
start,
@@ -533,7 +533,7 @@ def arange(*args, **kwargs):
step=step.value if isinstance(step, Quantity) else jnp.asarray(step),
**kwargs,
),
unit=unit,
dim=unit,
)
else:
return Quantity(
@@ -543,7 +543,7 @@ def arange(*args, **kwargs):
step=step.value if isinstance(step, Quantity) else jnp.asarray(step),
**kwargs,
),
unit=unit,
dim=unit,
)


@@ -575,12 +575,12 @@ def linspace(start: Union[Quantity, bst.typing.ArrayLike],
start=start,
stop=stop,
)
unit = getattr(start, "unit", DIMENSIONLESS)
unit = getattr(start, "dim", DIMENSIONLESS)
start = start.value if isinstance(start, Quantity) else start
stop = stop.value if isinstance(stop, Quantity) else stop

result = jnp.linspace(start, stop, num=num, endpoint=endpoint, retstep=retstep, dtype=dtype)
return Quantity(result, unit=unit)
return Quantity(result, dim=unit)


@set_module_as('brainunit.math')
@@ -611,12 +611,12 @@ def logspace(start: Union[Quantity, bst.typing.ArrayLike],
start=start,
stop=stop,
)
unit = getattr(start, "unit", DIMENSIONLESS)
unit = getattr(start, "dim", DIMENSIONLESS)
start = start.value if isinstance(start, Quantity) else start
stop = stop.value if isinstance(stop, Quantity) else stop

result = jnp.logspace(start, stop, num=num, endpoint=endpoint, base=base, dtype=dtype)
return Quantity(result, unit=unit)
return Quantity(result, dim=unit)


@set_module_as('brainunit.math')
@@ -638,7 +638,7 @@ def fill_diagonal(a: Union[Quantity, bst.typing.ArrayLike],
'''
if isinstance(a, Quantity) and isinstance(val, Quantity):
fail_for_dimension_mismatch(a, val)
return Quantity(jnp.fill_diagonal(a.value, val.value, wrap=wrap, inplace=inplace), unit=a.unit)
return Quantity(jnp.fill_diagonal(a.value, val.value, wrap=wrap, inplace=inplace), dim=a.dim)
elif isinstance(a, (jax.Array, np.ndarray)) and isinstance(val, (jax.Array, np.ndarray)):
return jnp.fill_diagonal(a, val, wrap=wrap, inplace=inplace)
elif is_unitless(a) or is_unitless(val):
@@ -663,7 +663,7 @@ def array_split(ary: Union[Quantity, bst.typing.ArrayLike],
Union[jax.Array, Quantity]: Quantity if `ary` is a Quantity, else an array.
'''
if isinstance(ary, Quantity):
return [Quantity(x, unit=ary.unit) for x in jnp.array_split(ary.value, indices_or_sections, axis)]
return [Quantity(x, dim=ary.dim) for x in jnp.array_split(ary.value, indices_or_sections, axis)]
elif isinstance(ary, bst.typing.ArrayLike):
return jnp.array_split(ary, indices_or_sections, axis)
else:
@@ -690,7 +690,7 @@ def meshgrid(*xi: Union[Quantity, bst.typing.ArrayLike],
from builtins import all as origin_all
if origin_all(isinstance(x, Quantity) for x in xi):
fail_for_dimension_mismatch(*xi)
return Quantity(jnp.meshgrid(*[x.value for x in xi], copy=copy, sparse=sparse, indexing=indexing), unit=xi[0].unit)
return Quantity(jnp.meshgrid(*[x.value for x in xi], copy=copy, sparse=sparse, indexing=indexing), dim=xi[0].dim)
elif origin_all(isinstance(x, (jax.Array, np.ndarray)) for x in xi):
return jnp.meshgrid(*xi, copy=copy, sparse=sparse, indexing=indexing)
else:
@@ -713,7 +713,7 @@ def vander(x: Union[Quantity, bst.typing.ArrayLike],
Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array.
'''
if isinstance(x, Quantity):
return Quantity(jnp.vander(x.value, N=N, increasing=increasing), unit=x.unit)
return Quantity(jnp.vander(x.value, N=N, increasing=increasing), dim=x.dim)
elif isinstance(x, (jax.Array, np.ndarray)):
return jnp.vander(x, N=N, increasing=increasing)
else:
2 changes: 1 addition & 1 deletion brainunit/math/_compat_numpy_array_manipulation.py
Original file line number Diff line number Diff line change
@@ -777,7 +777,7 @@ def wrap_function_to_method(func):
@wraps(func)
def f(x, *args, **kwargs):
if isinstance(x, Quantity):
return Quantity(func(x.value, *args, **kwargs), unit=x.unit)
return Quantity(func(x.value, *args, **kwargs), dim=x.dim)
else:
return func(x, *args, **kwargs)

28 changes: 14 additions & 14 deletions brainunit/math/_compat_numpy_funcs_change_unit.py
Original file line number Diff line number Diff line change
@@ -49,7 +49,7 @@ def decorator(func: Callable) -> Callable:
@wraps(func)
def f(x, *args, **kwargs):
if isinstance(x, Quantity):
return _return_check_unitless(Quantity(func(x.value, *args, **kwargs), unit=change_unit_func(x.unit)))
return _return_check_unitless(Quantity(func(x.value, *args, **kwargs), dtype=change_unit_func(x.dim)))
elif isinstance(x, (jnp.ndarray, np.ndarray)):
return func(x, *args, **kwargs)
else:
@@ -298,16 +298,16 @@ def decorator(func: Callable) -> Callable:
def f(x, y, *args, **kwargs):
if isinstance(x, Quantity) and isinstance(y, Quantity):
return _return_check_unitless(
Quantity(func(x.value, y.value, *args, **kwargs), unit=change_unit_func(x.unit, y.unit))
Quantity(func(x.value, y.value, *args, **kwargs), dim=change_unit_func(x.dim, y.dim))
)
elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)):
return func(x, y, *args, **kwargs)
elif isinstance(x, Quantity):
return _return_check_unitless(
Quantity(func(x.value, y, *args, **kwargs), unit=change_unit_func(x.unit, DIMENSIONLESS)))
Quantity(func(x.value, y, *args, **kwargs), dim=change_unit_func(x.dim, DIMENSIONLESS)))
elif isinstance(y, Quantity):
return _return_check_unitless(
Quantity(func(x, y.value, *args, **kwargs), unit=change_unit_func(DIMENSIONLESS, y.unit)))
Quantity(func(x, y.value, *args, **kwargs), dim=change_unit_func(DIMENSIONLESS, y.dim)))
else:
raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}')

@@ -443,13 +443,13 @@ def power(x: Union[Quantity, bst.typing.ArrayLike],
Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array.
'''
if isinstance(x, Quantity) and isinstance(y, Quantity):
return _return_check_unitless(Quantity(jnp.power(x.value, y.value), unit=x.unit ** y.unit))
return _return_check_unitless(Quantity(jnp.power(x.value, y.value), dim=x.dim ** y.dim))
elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)):
return jnp.power(x, y)
elif isinstance(x, Quantity):
return _return_check_unitless(Quantity(jnp.power(x.value, y), unit=x.unit ** y))
return _return_check_unitless(Quantity(jnp.power(x.value, y), dim=x.dim ** y))
elif isinstance(y, Quantity):
return _return_check_unitless(Quantity(jnp.power(x, y.value), unit=x ** y.unit))
return _return_check_unitless(Quantity(jnp.power(x, y.value), dim=x ** y.dim))
else:
raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.power.__name__}')

@@ -468,13 +468,13 @@ def floor_divide(x: Union[Quantity, bst.typing.ArrayLike],
Union[jax.Array, Quantity]: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array.
'''
if isinstance(x, Quantity) and isinstance(y, Quantity):
return _return_check_unitless(Quantity(jnp.floor_divide(x.value, y.value), unit=x.unit / y.unit))
return _return_check_unitless(Quantity(jnp.floor_divide(x.value, y.value), dim=x.dim / y.dim))
elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)):
return jnp.floor_divide(x, y)
elif isinstance(x, Quantity):
return _return_check_unitless(Quantity(jnp.floor_divide(x.value, y), unit=x.unit / y))
return _return_check_unitless(Quantity(jnp.floor_divide(x.value, y), dim=x.dim / y))
elif isinstance(y, Quantity):
return _return_check_unitless(Quantity(jnp.floor_divide(x, y.value), unit=x / y.unit))
return _return_check_unitless(Quantity(jnp.floor_divide(x, y.value), dim=x / y.dim))
else:
raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.floor_divide.__name__}')

@@ -495,7 +495,7 @@ def float_power(x: Union[Quantity, bst.typing.ArrayLike],
if isinstance(y, Quantity):
assert isscalar(y), f'{jnp.float_power.__name__} only supports scalar exponent'
if isinstance(x, Quantity):
return _return_check_unitless(Quantity(jnp.float_power(x.value, y), unit=x.unit ** y))
return _return_check_unitless(Quantity(jnp.float_power(x.value, y), dim=x.dim ** y))
elif isinstance(x, (jax.Array, np.ndarray)):
return jnp.float_power(x, y)
else:
@@ -516,12 +516,12 @@ def remainder(x: Union[Quantity, bst.typing.ArrayLike],
Union[jax.Array, Quantity]: Quantity if the final unit is the remainder of the unit of `x` and the unit of `y`, else an array.
'''
if isinstance(x, Quantity) and isinstance(y, Quantity):
return _return_check_unitless(Quantity(jnp.remainder(x.value, y.value), unit=x.unit / y.unit))
return _return_check_unitless(Quantity(jnp.remainder(x.value, y.value), dim=x.dim / y.dim))
elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)):
return jnp.remainder(x, y)
elif isinstance(x, Quantity):
return _return_check_unitless(Quantity(jnp.remainder(x.value, y), unit=x.unit % y))
return _return_check_unitless(Quantity(jnp.remainder(x.value, y), dim=x.dim % y))
elif isinstance(y, Quantity):
return _return_check_unitless(Quantity(jnp.remainder(x, y.value), unit=x % y.unit))
return _return_check_unitless(Quantity(jnp.remainder(x, y.value), dim=x % y.dim))
else:
raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.remainder.__name__}')
6 changes: 3 additions & 3 deletions brainunit/math/_compat_numpy_funcs_indexing.py
Original file line number Diff line number Diff line change
@@ -61,7 +61,7 @@ def where(condition: Union[bool, bst.typing.ArrayLike],
# as both arguments have the same unit, just use the first one's
dimensionless_args = [jnp.asarray(arg.value) if isinstance(arg, Quantity) else jnp.asarray(arg) for arg in args]
return Quantity.with_units(
jnp.where(condition, *dimensionless_args), args[0].unit
jnp.where(condition, *dimensionless_args), args[0].dim
)
else:
# illegal number of arguments
@@ -155,11 +155,11 @@ def select(condlist: list[Union[bst.typing.ArrayLike]],
from builtins import all as origin_all
from builtins import any as origin_any
if origin_all(isinstance(choice, Quantity) for choice in choicelist):
if origin_any(choice.unit != choicelist[0].unit for choice in choicelist):
if origin_any(choice.dim != choicelist[0].dim for choice in choicelist):
raise ValueError("All choices must have the same unit")
else:
return Quantity(jnp.select(condlist, [choice.value for choice in choicelist], default=default),
unit=choicelist[0].unit)
dim=choicelist[0].dim)
elif origin_all(isinstance(choice, (jax.Array, np.ndarray)) for choice in choicelist):
return jnp.select(condlist, choicelist, default=default)
else:
12 changes: 6 additions & 6 deletions brainunit/math/_compat_numpy_funcs_keep_unit.py
Original file line number Diff line number Diff line change
@@ -49,7 +49,7 @@ def wrap_math_funcs_keep_unit_unary(func):
@wraps(func)
def f(x, *args, **kwargs):
if isinstance(x, Quantity):
return Quantity(func(x.value, *args, **kwargs), unit=x.unit)
return Quantity(func(x.value, *args, **kwargs), dim=x.dim)
elif isinstance(x, (jax.Array, np.ndarray)):
return func(x, *args, **kwargs)
else:
@@ -578,7 +578,7 @@ def wrap_math_funcs_keep_unit_binary(func):
@wraps(func)
def f(x1, x2, *args, **kwargs):
if isinstance(x1, Quantity) and isinstance(x2, Quantity):
return Quantity(func(x1.value, x2.value, *args, **kwargs), unit=x1.unit)
return Quantity(func(x1.value, x2.value, *args, **kwargs), dim=x1.dim)
elif isinstance(x1, (jax.Array, np.ndarray)) and isinstance(x2, (jax.Array, np.ndarray)):
return func(x1, x2, *args, **kwargs)
else:
@@ -775,7 +775,7 @@ def interp(x: Union[Quantity, bst.typing.ArrayLike],
'''
unit = None
if isinstance(x, Quantity) or isinstance(xp, Quantity) or isinstance(fp, Quantity):
unit = x.unit if isinstance(x, Quantity) else xp.unit if isinstance(xp, Quantity) else fp.unit
unit = x.dim if isinstance(x, Quantity) else xp.dim if isinstance(xp, Quantity) else fp.dim
if isinstance(x, Quantity):
x_value = x.value
else:
@@ -790,7 +790,7 @@ def interp(x: Union[Quantity, bst.typing.ArrayLike],
fp_value = fp
result = jnp.interp(x_value, xp_value, fp_value, left=left, right=right, period=period)
if unit is not None:
return Quantity(result, unit=unit)
return Quantity(result, dim=unit)
else:
return result

@@ -812,7 +812,7 @@ def clip(a: Union[Quantity, bst.typing.ArrayLike],
'''
unit = None
if isinstance(a, Quantity) or isinstance(a_min, Quantity) or isinstance(a_max, Quantity):
unit = a.unit if isinstance(a, Quantity) else a_min.unit if isinstance(a_min, Quantity) else a_max.unit
unit = a.dim if isinstance(a, Quantity) else a_min.dim if isinstance(a_min, Quantity) else a_max.dim
if isinstance(a, Quantity):
a_value = a.value
else:
@@ -827,6 +827,6 @@ def clip(a: Union[Quantity, bst.typing.ArrayLike],
a_max_value = a_max
result = jnp.clip(a_value, a_min_value, a_max_value)
if unit is not None:
return Quantity(result, unit=unit)
return Quantity(result, dim=unit)
else:
return result
6 changes: 3 additions & 3 deletions brainunit/math/_compat_numpy_funcs_match_unit.py
Original file line number Diff line number Diff line change
@@ -38,17 +38,17 @@ def wrap_math_funcs_match_unit_binary(func):
def f(x, y, *args, **kwargs):
if isinstance(x, Quantity) and isinstance(y, Quantity):
fail_for_dimension_mismatch(x, y)
return Quantity(func(x.value, y.value, *args, **kwargs), unit=x.unit)
return Quantity(func(x.value, y.value, *args, **kwargs), dim=x.dim)
elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)):
return func(x, y, *args, **kwargs)
elif isinstance(x, Quantity):
if x.is_unitless:
return Quantity(func(x.value, y, *args, **kwargs), unit=x.unit)
return Quantity(func(x.value, y, *args, **kwargs), dim=x.dim)
else:
raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}')
elif isinstance(y, Quantity):
if y.is_unitless:
return Quantity(func(x, y.value, *args, **kwargs), unit=y.unit)
return Quantity(func(x, y.value, *args, **kwargs), dim=y.dim)
else:
raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}')
else:
26 changes: 13 additions & 13 deletions brainunit/math/_compat_numpy_misc.py
Original file line number Diff line number Diff line change
@@ -81,9 +81,9 @@ def broadcast_arrays(*args: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quan
from builtins import all as origin_all
from builtins import any as origin_any
if origin_all(isinstance(arg, Quantity) for arg in args):
if origin_any(arg.unit != args[0].unit for arg in args):
if origin_any(arg.dim != args[0].dim for arg in args):
raise ValueError("All arguments must have the same unit")
return Quantity(jnp.broadcast_arrays(*[arg.value for arg in args]), unit=args[0].unit)
return Quantity(jnp.broadcast_arrays(*[arg.value for arg in args]), dim=args[0].dim)
elif origin_all(isinstance(arg, (jax.Array, np.ndarray)) for arg in args):
return jnp.broadcast_arrays(*args)
else:
@@ -151,22 +151,22 @@ def einsum(
if contractions[i][4] == 'False':

fail_for_dimension_mismatch(
Quantity([], unit=unit), operands[i + 1], 'einsum'
Quantity([], dim=unit), operands[i + 1], 'einsum'
)
elif contractions[i][4] == 'DOT' or \
contractions[i][4] == 'TDOT' or \
contractions[i][4] == 'GEMM' or \
contractions[i][4] == 'OUTER/EINSUM':
if i == 0:
if isinstance(operands[i], Quantity) and isinstance(operands[i + 1], Quantity):
unit = operands[i].unit * operands[i + 1].unit
unit = operands[i].dim * operands[i + 1].dim
elif isinstance(operands[i], Quantity):
unit = operands[i].unit
unit = operands[i].dim
elif isinstance(operands[i + 1], Quantity):
unit = operands[i + 1].unit
unit = operands[i + 1].dim
else:
if isinstance(operands[i + 1], Quantity):
unit = unit * operands[i + 1].unit
unit = unit * operands[i + 1].dim

contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions)

@@ -177,7 +177,7 @@ def einsum(
r = einsum(operands, contractions, precision, # type: ignore[operator]
preferred_element_type, _dot_general)
if unit is not None:
return Quantity(r, unit=unit)
return Quantity(r, dim=unit)
else:
return r

@@ -206,21 +206,21 @@ def gradient(

if len(varargs) == 0:
if isinstance(f, Quantity) and not is_unitless(f):
return Quantity(jnp.gradient(f.value, axis=axis), unit=f.unit)
return Quantity(jnp.gradient(f.value, axis=axis), dim=f.dim)
else:
return jnp.gradient(f)
elif len(varargs) == 1:
unit = get_unit(f) / get_unit(varargs[0])
if unit is None or unit == DIMENSIONLESS:
return jnp.gradient(f, varargs[0], axis=axis)
else:
return [Quantity(r, unit=unit) for r in jnp.gradient(f.value, varargs[0].value, axis=axis)]
return [Quantity(r, dim=unit) for r in jnp.gradient(f.value, varargs[0].value, axis=axis)]
else:
unit_list = [get_unit(f) / get_unit(v) for v in varargs]
f = f.value if isinstance(f, Quantity) else f
varargs = [v.value if isinstance(v, Quantity) else v for v in varargs]
result_list = jnp.gradient(f, *varargs, axis=axis)
return [Quantity(r, unit=unit) if unit is not None else r for r, unit in zip(result_list, unit_list)]
return [Quantity(r, dim=unit) if unit is not None else r for r, unit in zip(result_list, unit_list)]


@set_module_as('brainunit.math')
@@ -251,12 +251,12 @@ def intersect1d(
result = jnp.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices)
if return_indices:
if unit is not None:
return (Quantity(result[0], unit=unit), result[1], result[2])
return (Quantity(result[0], dim=unit), result[1], result[2])
else:
return result
else:
if unit is not None:
return Quantity(result, unit=unit)
return Quantity(result, dim=unit)
else:
return result

0 comments on commit 337b365

Please sign in to comment.