Skip to content

Commit

Permalink
implement boolean ops for function.Array
Browse files Browse the repository at this point in the history
This patch implements boolean not, and, or, all and any for `function.Array`.
  • Loading branch information
joostvanzwieten committed Dec 19, 2023
1 parent efb4698 commit f30a2c0
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 2 deletions.
52 changes: 50 additions & 2 deletions nutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -2846,10 +2846,11 @@ def Namespace(*args, **kwargs):

class __implementations__:

def implements(np_function):
def implements(*np_functions):
'Register an ``__array_function__`` or ``__array_ufunc__`` implementation for Array objects.'
def decorator(func):
HANDLED_FUNCTIONS[np_function] = func
for np_function in np_functions:
HANDLED_FUNCTIONS[np_function] = func
return func
return decorator

Expand Down Expand Up @@ -3045,6 +3046,53 @@ def maximum(a: IntoArray, b: IntoArray) -> Array:
raise ValueError('Complex numbers have no total order.')
return _Wrapper.broadcasted_arrays(evaluable.Maximum, a, b)

@implements(numpy.logical_and, numpy.bitwise_and)
def logical_and(a: IntoArray, b: IntoArray) -> Array:
a, b = map(Array.cast, (a, b))
if a.dtype != bool or b.dtype != bool:
return NotImplemented
return _Wrapper.broadcasted_arrays(evaluable.multiply, a, b)

@implements(numpy.logical_or, numpy.bitwise_or)
def logical_or(a: IntoArray, b: IntoArray) -> Array:
a, b = map(Array.cast, (a, b))
if a.dtype != bool or b.dtype != bool:
return NotImplemented
return _Wrapper.broadcasted_arrays(evaluable.add, a, b)

@implements(numpy.logical_not, numpy.invert)
def logical_not(a: IntoArray) -> Array:
a = Array.cast(a)
if a.dtype != bool:
return NotImplemented
return _Wrapper.broadcasted_arrays(evaluable.LogicalNot, a, force_dtype=bool)

@implements(numpy.all)
def all(a: IntoArray, axis = None) -> Array:
a = Array.cast(a)
if a.dtype != bool:
return NotImplemented
if axis is None:
a = numpy.ravel(a)
elif isinstance(axis, int):
a = _Transpose.to_end(a, axis)
else:
return NotImplemented
return _Wrapper(evaluable.Product, a, shape=a.shape[:-1], dtype=bool)

@implements(numpy.any)
def any(a: IntoArray, axis = None) -> Array:
a = Array.cast(a)
if a.dtype != bool:
return NotImplemented
if axis is None:
a = numpy.ravel(a)
elif isinstance(axis, int):
a = _Transpose.to_end(a, axis)
else:
return NotImplemented
return _Wrapper(evaluable.Sum, a, shape=a.shape[:-1], dtype=bool)

@implements(numpy.sum)
def sum(arg: IntoArray, axis: Optional[Union[int, Sequence[int]]] = None) -> Array:
arg = Array.cast(arg)
Expand Down
11 changes: 11 additions & 0 deletions tests/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,17 @@ def _check(name, op, n_op, *args):
_check('max', lambda a, b: numpy.maximum(a, function.Array.cast(b)), numpy.maximum, ANY(4, 1), ANY(1, 4))
_check('heaviside', function.heaviside, lambda u: numpy.heaviside(u, .5), ANY(4, 4))

_check('logical_and', lambda a, b: numpy.logical_and(a, function.Array.cast(b)), numpy.logical_and, numpy.array([False, True, True])[:,None], numpy.array([True, False])[None,:])
_check('bitwise_and-bool', lambda a, b: numpy.bitwise_and(a, function.Array.cast(b)), numpy.bitwise_and, numpy.array([False, True, True])[:,None], numpy.array([True, False])[None,:])
_check('logical_or', lambda a, b: numpy.logical_or(a, function.Array.cast(b)), numpy.logical_or, numpy.array([False, True, True])[:,None], numpy.array([True, False])[None,:])
_check('bitwise_or-bool', lambda a, b: numpy.bitwise_or(a, function.Array.cast(b)), numpy.bitwise_or, numpy.array([False, True, True])[:,None], numpy.array([True, False])[None,:])
_check('logical_not', lambda a: numpy.logical_not(function.Array.cast(a)), numpy.logical_not, numpy.array([[False, True], [True, False], [True, True]]))
_check('invert-bool', lambda a: numpy.invert(function.Array.cast(a)), numpy.invert, numpy.array([[False, True], [True, False], [True, True]]))
_check('all-bool-all-axes', lambda a: numpy.all(function.Array.cast(a)), numpy.all, numpy.array([[False, True], [True, True]]))
_check('all-bool-single-axis', lambda a: numpy.all(function.Array.cast(a), axis=0), lambda a: numpy.all(a, axis=0), numpy.array([[False, True], [True, True]]))
_check('any-bool-all-axes', lambda a: numpy.any(function.Array.cast(a)), numpy.any, numpy.array([[False, True], [True, True]]))
_check('any-bool-single-axis', lambda a: numpy.any(function.Array.cast(a), axis=0), lambda a: numpy.any(a, axis=0), numpy.array([[False, True], [False, False]]))

## TODO: opposite
## TODO: mean
## TODO: jump
Expand Down

0 comments on commit f30a2c0

Please sign in to comment.