Skip to content

Commit

Permalink
Fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Jun 13, 2024
1 parent 549fed7 commit 825cd40
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 47 deletions.
70 changes: 26 additions & 44 deletions brainunit/math/_compat_numpy_array_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from jax import Array

from brainunit._misc import set_module_as
from .. import Quantity
from .._base import (
DIMENSIONLESS,
Quantity,
Expand All @@ -45,7 +44,6 @@ def full(
shape: Sequence[int],
fill_value: Union[Quantity, int, float],
dtype: Optional[Any] = None,
order: Optional[str] = 'C',
) -> Union[Array, Quantity]:
"""
Returns a quantity of `shape`, filled with `fill_value` if `fill_value` is a Quantity.
Expand All @@ -59,9 +57,6 @@ def full(
Fill value.
dtype : data-type, optional
The desired data-type for the array The default, None, means ``np.array(fill_value).dtype`
order : {'C', 'F'}, optional
Whether to store multidimensional data in C- or Fortran-contiguous
(row- or column-wise) order in memory.
Returns
-------
Expand All @@ -70,8 +65,8 @@ def full(
Array of `fill_value` with the given shape, dtype, and order.
"""
if isinstance(fill_value, Quantity):
return Quantity(jnp.full(shape, fill_value.value, dtype=dtype, order=order), dim=fill_value.dim)
return jnp.full(shape, fill_value, dtype=dtype, order=order)
return Quantity(jnp.full(shape, fill_value.value, dtype=dtype), dim=fill_value.dim)
return jnp.full(shape, fill_value, dtype=dtype)


@set_module_as('brainunit.math')
Expand Down Expand Up @@ -351,7 +346,7 @@ def diag(
if unit is not None:
assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}'
fail_for_dimension_mismatch(v, unit, error_message="a and unit have to have the same units.")
return Quantity(jnp.diag(a.value, k=k), dim=a.dim)
return Quantity(jnp.diag(v.value, k=k), dim=v.dim)
elif isinstance(v, (jax.Array, np.ndarray)):
if unit is not None:
assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}'
Expand Down Expand Up @@ -643,51 +638,43 @@ def arange(
Parameters
----------
start : Quantity or array, optional
Start of the interval. The interval includes this value. The default start value is 0.
Start of the interval. The interval includes this value. The default start value is 0.
stop : Quantity or array
End of the interval. The interval does not include this value, except in some cases where `step` is not an integer
and floating point round-off affects the length of `out`.
End of the interval. The interval does not include this value, except in some cases where `step` is not an integer
and floating point round-off affects the length of `out`.
step : Quantity or array, optional
Spacing between values. For any output `out`, this is the distance between two adjacent values, `out[i+1] - out[i]`.
The default step size is 1.
Spacing between values. For any output `out`, this is the distance between two adjacent values, `out[i+1] - out[i]`.
The default step size is 1.
dtype : data-type, optional
The type of the output array. If `dtype` is not given, infer the data type from the other input arguments.
The type of the output array. If `dtype` is not given, infer the data type from the other input arguments.
Returns
-------
out : quantity or array
Array of evenly spaced values.
Array of evenly spaced values.
"""
# arange has a bit of a complicated argument structure unfortunately
# we leave the actual checking of the number of arguments to numpy, though

# default values
arg_len = len([x for x in [start, stop, step] if x is not None])
if start is None:
start = 0
if step is None:
step = 1

if arg_len == 1:
if stop is not None:
raise TypeError("Duplicate definition of 'stop'")
stop = start
start = 0
elif arg_len == 2:
if start != 0:
raise TypeError("Duplicate definition of 'start'")
if stop is not None:
raise TypeError("Duplicate definition of 'stop'")
start, stop = start, stop
elif arg_len == 3:
if start != 0:
raise TypeError("Duplicate definition of 'start'")
if stop is not None:
raise TypeError("Duplicate definition of 'stop'")
if step != 1:
raise TypeError("Duplicate definition of 'step'")
start, stop, step = start, stop, step
if start is not None and stop is None:
stop = start
start = 0

elif arg_len > 3:
raise TypeError("Need between 1 and 3 non-keyword arguments")

# default values
if start is None:
start = 0
if step is None:
step = 1

if stop is None:
raise TypeError("Missing stop argument.")
if stop is not None and not is_unitless(stop):
Expand All @@ -696,25 +683,20 @@ def arange(
fail_for_dimension_mismatch(
start,
stop,
error_message=(
"Start value {start} and stop value {stop} have to have the same units."
),
error_message="Start value {start} and stop value {stop} have to have the same units.",
start=start,
stop=stop,
)
fail_for_dimension_mismatch(
stop,
step,
error_message=(
"Stop value {stop} and step value {step} have to have the same units."
),
error_message="Stop value {stop} and step value {step} have to have the same units.",
stop=stop,
step=step,
)

unit = getattr(stop, "dim", DIMENSIONLESS)
# start is a position-only argument in numpy 2.0
# https://numpy.org/devdocs/release/2.0.0-notes.html#arange-s-start-argument-is-positional-only
# TODO: check whether this is still the case in the final release

if start == 0:
return Quantity(
jnp.arange(
Expand Down
9 changes: 6 additions & 3 deletions brainunit/math/_compat_numpy_array_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# ==============================================================================

from collections.abc import Sequence
from typing import (Union, Optional, Tuple, List)
from typing import (Union, Optional, Tuple, List, Any)

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -167,7 +167,9 @@ def transpose(

@set_module_as('brainunit.math')
def swapaxes(
a: Union[Array, Quantity], axis1: int, axis2: int
a: Union[Array, Quantity],
axis1: int,
axis2: int
) -> Union[Array, Quantity]:
"""
Interchanges two axes of an array.
Expand All @@ -186,7 +188,8 @@ def swapaxes(
@set_module_as('brainunit.math')
def concatenate(
arrays: Union[Sequence[Array], Sequence[Quantity]],
axis: Optional[int] = None
axis: Optional[int] = None,
dtype: Optional[Any] = None
) -> Union[Array, Quantity]:
"""
Join a sequence of arrays along an existing axis.
Expand Down

0 comments on commit 825cd40

Please sign in to comment.