Skip to content

Commit

Permalink
implement boolean ops for function.Array
Browse files Browse the repository at this point in the history
This patch implements boolean not, and, or, all and any for `function.Array`.
  • Loading branch information
joostvanzwieten committed Dec 19, 2023
1 parent c51b794 commit 75eb670
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 0 deletions.
50 changes: 50 additions & 0 deletions nutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -3045,6 +3045,56 @@ def maximum(a: IntoArray, b: IntoArray) -> Array:
raise ValueError('Complex numbers have no total order.')
return _Wrapper.broadcasted_arrays(evaluable.Maximum, a, b)

@implements(numpy.logical_and)
@implements(numpy.bitwise_and)
def logical_and(a: IntoArray, b: IntoArray) -> Array:
a, b = map(Array.cast, (a, b))
if a.dtype != bool or b.dtype != bool:
return NotImplemented

Check warning on line 3053 in nutils/function.py

View check run for this annotation

Codecov / codecov/patch

nutils/function.py#L3053

Added line #L3053 was not covered by tests
return _Wrapper.broadcasted_arrays(evaluable.multiply, a, b)

@implements(numpy.logical_or)
@implements(numpy.bitwise_or)
def logical_or(a: IntoArray, b: IntoArray) -> Array:
a, b = map(Array.cast, (a, b))
if a.dtype != bool or b.dtype != bool:
return NotImplemented

Check warning on line 3061 in nutils/function.py

View check run for this annotation

Codecov / codecov/patch

nutils/function.py#L3061

Added line #L3061 was not covered by tests
return _Wrapper.broadcasted_arrays(evaluable.add, a, b)

@implements(numpy.logical_not)
@implements(numpy.invert)
def logical_not(a: IntoArray) -> Array:
a = Array.cast(a)
if a.dtype != bool:
return NotImplemented

Check warning on line 3069 in nutils/function.py

View check run for this annotation

Codecov / codecov/patch

nutils/function.py#L3069

Added line #L3069 was not covered by tests
return _Wrapper.broadcasted_arrays(evaluable.LogicalNot, a, force_dtype=bool)

@implements(numpy.all)
def all(a: IntoArray, axis = None) -> Array:
a = Array.cast(a)
if a.dtype != bool:
return NotImplemented

Check warning on line 3076 in nutils/function.py

View check run for this annotation

Codecov / codecov/patch

nutils/function.py#L3076

Added line #L3076 was not covered by tests
if axis is None:
a = numpy.ravel(a)
elif isinstance(axis, int):
a = _Transpose.to_end(a, axis)
else:
return NotImplemented

Check warning on line 3082 in nutils/function.py

View check run for this annotation

Codecov / codecov/patch

nutils/function.py#L3082

Added line #L3082 was not covered by tests
return _Wrapper(evaluable.Product, a, shape=a.shape[:-1], dtype=bool)

@implements(numpy.any)
def any(a: IntoArray, axis = None) -> Array:
a = Array.cast(a)
if a.dtype != bool:
return NotImplemented

Check warning on line 3089 in nutils/function.py

View check run for this annotation

Codecov / codecov/patch

nutils/function.py#L3089

Added line #L3089 was not covered by tests
if axis is None:
a = numpy.ravel(a)
elif isinstance(axis, int):
a = _Transpose.to_end(a, axis)
else:
return NotImplemented

Check warning on line 3095 in nutils/function.py

View check run for this annotation

Codecov / codecov/patch

nutils/function.py#L3095

Added line #L3095 was not covered by tests
return _Wrapper(evaluable.Sum, a, shape=a.shape[:-1], dtype=bool)

@implements(numpy.sum)
def sum(arg: IntoArray, axis: Optional[Union[int, Sequence[int]]] = None) -> Array:
arg = Array.cast(arg)
Expand Down
11 changes: 11 additions & 0 deletions tests/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,17 @@ def _check(name, op, n_op, *args):
_check('max', lambda a, b: numpy.maximum(a, function.Array.cast(b)), numpy.maximum, ANY(4, 1), ANY(1, 4))
_check('heaviside', function.heaviside, lambda u: numpy.heaviside(u, .5), ANY(4, 4))

_check('logical_and', lambda a, b: numpy.logical_and(a, function.Array.cast(b)), numpy.logical_and, numpy.array([False, True, True])[:,None], numpy.array([True, False])[None,:])
_check('bitwise_and-bool', lambda a, b: numpy.bitwise_and(a, function.Array.cast(b)), numpy.bitwise_and, numpy.array([False, True, True])[:,None], numpy.array([True, False])[None,:])
_check('logical_or', lambda a, b: numpy.logical_or(a, function.Array.cast(b)), numpy.logical_or, numpy.array([False, True, True])[:,None], numpy.array([True, False])[None,:])
_check('bitwise_or-bool', lambda a, b: numpy.bitwise_or(a, function.Array.cast(b)), numpy.bitwise_or, numpy.array([False, True, True])[:,None], numpy.array([True, False])[None,:])
_check('logical_not', lambda a: numpy.logical_not(function.Array.cast(a)), numpy.logical_not, numpy.array([[False, True], [True, False], [True, True]]))
_check('invert-bool', lambda a: numpy.invert(function.Array.cast(a)), numpy.invert, numpy.array([[False, True], [True, False], [True, True]]))
_check('all-bool-all-axes', lambda a: numpy.all(function.Array.cast(a)), numpy.all, numpy.array([[False, True], [True, True]]))
_check('all-bool-single-axis', lambda a: numpy.all(function.Array.cast(a), axis=0), lambda a: numpy.all(a, axis=0), numpy.array([[False, True], [True, True]]))
_check('any-bool-all-axes', lambda a: numpy.any(function.Array.cast(a)), numpy.any, numpy.array([[False, True], [True, True]]))
_check('any-bool-single-axis', lambda a: numpy.any(function.Array.cast(a), axis=0), lambda a: numpy.any(a, axis=0), numpy.array([[False, True], [False, False]]))

## TODO: opposite
## TODO: mean
## TODO: jump
Expand Down

0 comments on commit 75eb670

Please sign in to comment.