Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Jun 12, 2024
1 parent 06b1d5f commit 6b43099
Show file tree
Hide file tree
Showing 12 changed files with 1,582 additions and 1,500 deletions.
822 changes: 451 additions & 371 deletions brainunit/math/_compat_numpy_array_manipulation.py

Large diffs are not rendered by default.

503 changes: 236 additions & 267 deletions brainunit/math/_compat_numpy_funcs_accept_unitless.py

Large diffs are not rendered by default.

107 changes: 59 additions & 48 deletions brainunit/math/_compat_numpy_funcs_bit_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,45 +89,27 @@ def invert(x: Union[Quantity, bst.typing.ArrayLike]) -> Array:

def wrap_elementwise_bit_operation_binary(func):
@wraps(func)
def f(x, y, *args, **kwargs):
if isinstance(x, Quantity) or isinstance(y, Quantity):
raise ValueError(f'Expected integers, got {x} and {y}')
elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray, int, float)):
return func(x, y, *args, **kwargs)
else:
raise ValueError(f'Unsupported types {type(x)} and {type(y)} for {func.__name__}')

f.__module__ = 'brainunit.math'
return f


@wrap_elementwise_bit_operation_binary
def bitwise_and(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Array:
return jnp.bitwise_and(x, y)


@wrap_elementwise_bit_operation_binary
def bitwise_or(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Array:
return jnp.bitwise_or(x, y)


@wrap_elementwise_bit_operation_binary
def bitwise_xor(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Array:
return jnp.bitwise_xor(x, y)


@wrap_elementwise_bit_operation_binary
def left_shift(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Array:
return jnp.left_shift(x, y)


@wrap_elementwise_bit_operation_binary
def right_shift(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Array:
return jnp.right_shift(x, y)


# docs for functions above
bitwise_and.__doc__ = '''
def decorator(*args, **kwargs):
def f(x, y, *args, **kwargs):
if isinstance(x, Quantity) or isinstance(y, Quantity):
raise ValueError(f'Expected integers, got {x} and {y}')
elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray, int, float)):
return func(x, y, *args, **kwargs)
else:
raise ValueError(f'Unsupported types {type(x)} and {type(y)} for {func.__name__}')

f.__module__ = 'brainunit.math'
return f

return decorator


@wrap_elementwise_bit_operation_binary(jnp.bitwise_and)
def bitwise_and(
x: Union[Quantity, bst.typing.ArrayLike],
y: Union[Quantity, bst.typing.ArrayLike]
) -> Array:
'''
Compute the bit-wise AND of two arrays element-wise.
Args:
Expand All @@ -136,9 +118,16 @@ def right_shift(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst
Returns:
jax.Array: an array
'''
'''
...


bitwise_or.__doc__ = '''
@wrap_elementwise_bit_operation_binary(jnp.bitwise_or)
def bitwise_or(
x: Union[Quantity, bst.typing.ArrayLike],
y: Union[Quantity, bst.typing.ArrayLike]
) -> Array:
'''
Compute the bit-wise OR of two arrays element-wise.
Args:
Expand All @@ -147,9 +136,16 @@ def right_shift(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst
Returns:
jax.Array: an array
'''
'''
...

bitwise_xor.__doc__ = '''

@wrap_elementwise_bit_operation_binary(jnp.bitwise_xor)
def bitwise_xor(
x: Union[Quantity, bst.typing.ArrayLike],
y: Union[Quantity, bst.typing.ArrayLike]
) -> Array:
'''
Compute the bit-wise XOR of two arrays element-wise.
Args:
Expand All @@ -158,9 +154,16 @@ def right_shift(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst
Returns:
jax.Array: an array
'''
'''
...


left_shift.__doc__ = '''
@wrap_elementwise_bit_operation_binary(jnp.left_shift)
def left_shift(
x: Union[Quantity, bst.typing.ArrayLike],
y: Union[Quantity, bst.typing.ArrayLike]
) -> Array:
'''
Shift the bits of an integer to the left.
Args:
Expand All @@ -169,9 +172,16 @@ def right_shift(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst
Returns:
jax.Array: an array
'''
'''
...


right_shift.__doc__ = '''
@wrap_elementwise_bit_operation_binary(jnp.right_shift)
def right_shift(
x: Union[Quantity, bst.typing.ArrayLike],
y: Union[Quantity, bst.typing.ArrayLike]
) -> Array:
'''
Shift the bits of an integer to the right.
Args:
Expand All @@ -180,4 +190,5 @@ def right_shift(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst
Returns:
jax.Array: an array
'''
'''
...
Loading

0 comments on commit 6b43099

Please sign in to comment.