Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

boolean operators #844

Merged
merged 5 commits into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 46 additions & 21 deletions nutils/evaluable.py
Original file line number Diff line number Diff line change
Expand Up @@ -895,8 +895,11 @@
__pow__ = power
__abs__ = lambda self: abs(self)
__mod__ = lambda self, other: mod(self, other)
__and__ = __rand__ = multiply
__or__ = __ror__ = add
__int__ = __index__
__str__ = __repr__ = lambda self: '{}.{}<{}>'.format(type(self).__module__, type(self).__name__, self._shape_str(form=str))
__inv__ = lambda self: LogicalNot(self) if self.dtype == bool else NotImplemented

def _shape_str(self, form):
dtype = self.dtype.__name__[0] if hasattr(self, 'dtype') else '?'
Expand Down Expand Up @@ -1173,7 +1176,7 @@
return constant(self.value.transpose(axes))

def _sum(self, axis):
return constant(numpy.sum(self.value, axis))
return constant((numpy.any if self.dtype == bool else numpy.sum)(self.value, axis))

def _add(self, other):
if isinstance(other, Constant):
Expand All @@ -1186,7 +1189,7 @@
return constant(numpy.transpose(numpy.linalg.inv(value), numpy.argsort(axes)))

def _product(self):
return constant(self.value.prod(-1))
return constant((numpy.all if self.dtype == bool else numpy.prod)(self.value, -1))

def _multiply(self, other):
if self._isunit:
Expand Down Expand Up @@ -1278,11 +1281,11 @@

def _sum(self, i):
if i == self.ndim - 1:
return self.func * astype(self.length, self.func.dtype)
return self.func if self.dtype == bool else self.func * astype(self.length, self.func.dtype)
return InsertAxis(sum(self.func, i), self.length)

def _product(self):
return self.func**astype(self.length, self.func.dtype)
return self.func if self.dtype == bool else self.func**astype(self.length, self.func.dtype)

def _power(self, n):
unaligned1, unaligned2, where = unalign(self, n)
Expand Down Expand Up @@ -1555,19 +1558,16 @@
class Product(Array):

def __init__(self, func: Array):
assert isinstance(func, Array) and func.dtype != bool, f'func={func!r}'
assert isinstance(func, Array), f'func={func!r}'
self.func = func
self.evalf = functools.partial(numpy.all if func.dtype == bool else numpy.prod, axis=-1)
super().__init__(args=(func,), shape=func.shape[:-1], dtype=func.dtype)

def _simplified(self):
if _equals_scalar_constant(self.func.shape[-1], 1):
return get(self.func, self.ndim, constant(0))
return self.func._product()

@staticmethod
def evalf(arr):
return numpy.product(arr, axis=-1)

def _derivative(self, var, seen):
grad = derivative(self.func, var, seen)
funcs = Product(insertaxis(self.func, -2, self.func.shape[-1]) + Diagonalize(astype(1, self.func.dtype) - self.func)) # replace diagonal entries by 1
Expand Down Expand Up @@ -1657,7 +1657,7 @@
assert isinstance(funcs, types.frozenmultiset), f'funcs={funcs!r}'
self.funcs = funcs
func1, func2 = funcs
assert equalshape(func1.shape, func2.shape) and func1.dtype == func2.dtype != bool, 'Multiply({}, {})'.format(func1, func2)
assert equalshape(func1.shape, func2.shape) and func1.dtype == func2.dtype, 'Multiply({}, {})'.format(func1, func2)
super().__init__(args=tuple(self.funcs), shape=func1.shape, dtype=func1.dtype)

@property
Expand All @@ -1678,16 +1678,20 @@
for axis1, axis2, *other in map(sorted, fj._diagonals):
return diagonalize(multiply(*(takediag(f, axis1, axis2) for f in factors)), axis1, axis2)
for i, fi in enumerate(factors[:j]):
if self.dtype == bool and fi == fj:
return multiply(*factors[:j], *factors[j+1:])
unaligned1, unaligned2, where = unalign(fi, fj)
fij = align(unaligned1 * unaligned2, where, self.shape) if len(where) != self.ndim \
else fi._multiply(fj) or fj._multiply(fi)
if fij:
return multiply(*factors[:i], *factors[i+1:j], *factors[j+1:], fij)

def _optimized_for_numpy(self):
if self.dtype == bool:
return None
factors = tuple(self._factors)
for i, fi in enumerate(factors):
if fi.dtype != bool and fi._const_uniform == -1:
if fi._const_uniform == -1:
return Negative(multiply(*factors[:i], *factors[i+1:]))
if fi.dtype != complex and Sign(fi) in factors:
i, j = sorted([i, factors.index(Sign(fi))])
Expand Down Expand Up @@ -1722,8 +1726,8 @@
return multiply(*common) * add(multiply(*factors), multiply(*other_factors))
nz = factors or other_factors
if not nz: # self equals other (up to factor ordering)
return self * astype(2, self.dtype)
if len(nz) == 1 and tuple(nz)[0]._const_uniform == -1:
return self if self.dtype == bool else self * astype(2, self.dtype)
if self.dtype != bool and len(nz) == 1 and tuple(nz)[0]._const_uniform == -1:
# Since the subtraction x - y is stored as x + -1 * y, this handles
# the simplification of x - x to 0. While we could alternatively
# simplify all x + a * x to (a + 1) * x, capturing a == -1 as a
Expand Down Expand Up @@ -1772,6 +1776,8 @@

@cached_property
def _assparse(self):
if self.dtype == bool:
return super()._assparse
# First we collect the clusters of factors that have no real (i.e. not
# inserted) axes in common with the other clusters, and store them in
# uninserted form.
Expand Down Expand Up @@ -1821,11 +1827,13 @@
assert isinstance(funcs, types.frozenmultiset) and len(funcs) == 2, f'funcs={funcs!r}'
self.funcs = funcs
func1, func2 = funcs
assert equalshape(func1.shape, func2.shape) and func1.dtype == func2.dtype != bool, 'Add({}, {})'.format(func1, func2)
assert equalshape(func1.shape, func2.shape) and func1.dtype == func2.dtype, 'Add({}, {})'.format(func1, func2)
super().__init__(args=tuple(self.funcs), shape=func1.shape, dtype=func1.dtype)

@cached_property
def _inflations(self):
if self.dtype == bool:
return ()
func1, func2 = self.funcs
func2_inflations = dict(func2._inflations)
inflations = []
Expand Down Expand Up @@ -1856,6 +1864,8 @@
terms = tuple(self._terms)
for j, fj in enumerate(terms):
for i, fi in enumerate(terms[:j]):
if self.dtype == bool and fi == fj:
return add(*terms[:j], *terms[j+1:])
diags = [sorted(axesi & axesj)[:2] for axesi in fi._diagonals for axesj in fj._diagonals if len(axesi & axesj) >= 2]
unaligned1, unaligned2, where = unalign(fi, fj)
fij = diagonalize(takediag(fi, *diags[0]) + takediag(fj, *diags[0]), *diags[0]) if diags \
Expand Down Expand Up @@ -1919,7 +1929,10 @@

@cached_property
def _assparse(self):
return _gathersparsechunks(itertools.chain(*[f._assparse for f in self._terms]))
if self.dtype == bool:
return super()._assparse
else:
return _gathersparsechunks(itertools.chain(*[f._assparse for f in self._terms]))

def _intbounds_impl(self):
lowers, uppers = zip(*[f._intbounds for f in self._terms])
Expand Down Expand Up @@ -1986,19 +1999,15 @@

def __init__(self, func: Array):
assert isinstance(func, Array), f'func={func!r}'
assert func.dtype != bool, 'Sum({})'.format(func)
self.func = func
self.evalf = functools.partial(numpy.any if func.dtype == bool else numpy.sum, axis=-1)
super().__init__(args=(func,), shape=func.shape[:-1], dtype=func.dtype)

def _simplified(self):
if _equals_scalar_constant(self.func.shape[-1], 1):
return Take(self.func, constant(0))
return self.func._sum(self.ndim)

@staticmethod
def evalf(arr):
return numpy.sum(arr, -1)

def _sum(self, axis):
trysum = self.func._sum(axis)
if trysum is not None:
Expand All @@ -2009,6 +2018,8 @@

@cached_property
def _assparse(self):
if self.dtype == bool:
return super()._assparse
chunks = []
for *indices, _rmidx, values in self.func._assparse:
if self.ndim == 0:
Expand Down Expand Up @@ -2490,6 +2501,20 @@
return bool


class LogicalNot(Pointwise):
evalf = staticmethod(numpy.logical_not)
def return_type(T):
if T != bool:
raise ValueError(f'Expected a boolean but got {T}.')

Check warning on line 2508 in nutils/evaluable.py

View check run for this annotation

Codecov / codecov/patch

nutils/evaluable.py#L2508

Added line #L2508 was not covered by tests
return bool

def _simplified(self):
arg, = self.args
if isinstance(arg, LogicalNot):
return arg.args[0]
return super()._simplified()


class Minimum(Pointwise):
evalf = staticmethod(numpy.minimum)
deriv = lambda x, y: .5 - .5 * Sign(x - y), lambda x, y: .5 + .5 * Sign(x - y)
Expand Down Expand Up @@ -2896,7 +2921,7 @@
return Zeros(self.shape+(self.shape[axis],), dtype=self.dtype)

def _sum(self, axis):
return Zeros(self.shape[:axis] + self.shape[axis+1:], dtype=int if self.dtype == bool else self.dtype)
return Zeros(self.shape[:axis] + self.shape[axis+1:], dtype=self.dtype)

def _transpose(self, axes):
shape = tuple(self.shape[n] for n in axes)
Expand Down
50 changes: 50 additions & 0 deletions nutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -3045,6 +3045,56 @@
raise ValueError('Complex numbers have no total order.')
return _Wrapper.broadcasted_arrays(evaluable.Maximum, a, b)

@implements(numpy.logical_and)
@implements(numpy.bitwise_and)
def logical_and(a: IntoArray, b: IntoArray) -> Array:
a, b = map(Array.cast, (a, b))
if a.dtype != bool or b.dtype != bool:
return NotImplemented

Check warning on line 3053 in nutils/function.py

View check run for this annotation

Codecov / codecov/patch

nutils/function.py#L3053

Added line #L3053 was not covered by tests
return _Wrapper.broadcasted_arrays(evaluable.multiply, a, b)

@implements(numpy.logical_or)
@implements(numpy.bitwise_or)
def logical_or(a: IntoArray, b: IntoArray) -> Array:
a, b = map(Array.cast, (a, b))
if a.dtype != bool or b.dtype != bool:
return NotImplemented

Check warning on line 3061 in nutils/function.py

View check run for this annotation

Codecov / codecov/patch

nutils/function.py#L3061

Added line #L3061 was not covered by tests
return _Wrapper.broadcasted_arrays(evaluable.add, a, b)

@implements(numpy.logical_not)
@implements(numpy.invert)
def logical_not(a: IntoArray) -> Array:
a = Array.cast(a)
if a.dtype != bool:
return NotImplemented

Check warning on line 3069 in nutils/function.py

View check run for this annotation

Codecov / codecov/patch

nutils/function.py#L3069

Added line #L3069 was not covered by tests
return _Wrapper.broadcasted_arrays(evaluable.LogicalNot, a, force_dtype=bool)

@implements(numpy.all)
def all(a: IntoArray, axis = None) -> Array:
a = Array.cast(a)
if a.dtype != bool:
return NotImplemented

Check warning on line 3076 in nutils/function.py

View check run for this annotation

Codecov / codecov/patch

nutils/function.py#L3076

Added line #L3076 was not covered by tests
if axis is None:
a = numpy.ravel(a)
elif isinstance(axis, int):
a = _Transpose.to_end(a, axis)
else:
return NotImplemented

Check warning on line 3082 in nutils/function.py

View check run for this annotation

Codecov / codecov/patch

nutils/function.py#L3082

Added line #L3082 was not covered by tests
return _Wrapper(evaluable.Product, a, shape=a.shape[:-1], dtype=bool)

@implements(numpy.any)
def any(a: IntoArray, axis = None) -> Array:
a = Array.cast(a)
if a.dtype != bool:
return NotImplemented

Check warning on line 3089 in nutils/function.py

View check run for this annotation

Codecov / codecov/patch

nutils/function.py#L3089

Added line #L3089 was not covered by tests
if axis is None:
a = numpy.ravel(a)
elif isinstance(axis, int):
a = _Transpose.to_end(a, axis)
else:
return NotImplemented

Check warning on line 3095 in nutils/function.py

View check run for this annotation

Codecov / codecov/patch

nutils/function.py#L3095

Added line #L3095 was not covered by tests
return _Wrapper(evaluable.Sum, a, shape=a.shape[:-1], dtype=bool)

@implements(numpy.sum)
def sum(arg: IntoArray, axis: Optional[Union[int, Sequence[int]]] = None) -> Array:
arg = Array.cast(arg)
Expand Down
29 changes: 20 additions & 9 deletions tests/test_evaluable.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def test_product(self):
return
for iax in range(self.actual.ndim):
self.assertFunctionAlmostEqual(decimal=14,
desired=numpy.product(self.desired, axis=iax),
desired=(numpy.all if self.actual.dtype == bool else numpy.prod)(self.desired, axis=iax),
actual=evaluable.product(self.actual, axis=iax))

def test_getslice(self):
Expand All @@ -248,23 +248,17 @@ def test_getslice(self):
actual=self.actual[s])

def test_sumaxis(self):
if self.desired.dtype == bool:
return
for idim in range(self.actual.ndim):
self.assertFunctionAlmostEqual(decimal=14,
desired=self.desired.sum(idim),
desired=(numpy.any if self.actual.dtype == bool else numpy.sum)(self.desired, axis=idim),
actual=self.actual.sum(idim))

def test_add(self):
if self.actual.dtype == bool:
return
self.assertFunctionAlmostEqual(decimal=14,
desired=self.desired + self.other,
actual=(self.actual + self.other))

def test_multiply(self):
if self.actual.dtype == bool:
return
self.assertFunctionAlmostEqual(decimal=14,
desired=self.desired * self.other,
actual=(self.actual * self.other))
Expand Down Expand Up @@ -561,6 +555,11 @@ def _check(name, op, n_op, *arg_values, hasgrad=True, zerograd=False, ndim=2):
_check('equal', evaluable.Equal, numpy.equal, ANY(4, 4), ANY(4, 4), zerograd=True)
_check('greater', evaluable.Greater, numpy.greater, ANY(4, 4), ANY(4, 4), zerograd=True)
_check('less', evaluable.Less, numpy.less, ANY(4, 4), ANY(4, 4), zerograd=True)
_check('logical_and', evaluable.multiply, numpy.logical_and, numpy.array([[False, False], [True, True]], dtype=bool), numpy.array([[False, True], [False, True]], dtype=bool))
_check('logical_or', evaluable.add, numpy.logical_or, numpy.array([[False, False], [True, True], [False, True]], dtype=bool), numpy.array([[False, True], [False, True], [True, False]], dtype=bool))
_check('logical_not', evaluable.LogicalNot, numpy.logical_not, numpy.array([[False, False], [True, True], [False, True]], dtype=bool))
_check('logical_any', evaluable.Sum, lambda a: numpy.any(a, axis=-1), numpy.array([[False, False], [True, True], [False, True]], dtype=bool))
_check('logical_all', evaluable.Product, lambda a: numpy.all(a, axis=-1), numpy.array([[False, False], [True, True], [False, True]], dtype=bool))
_check('arctan2', evaluable.arctan2, numpy.arctan2, ANY(4, 4), ANY(4, 4))
_check('stack', lambda a, b: evaluable.stack([a, b], 0), lambda a, b: numpy.concatenate([a[numpy.newaxis, :], b[numpy.newaxis, :]], axis=0), ANY(4), ANY(4))
_check('eig', lambda a: evaluable.eig(a+a.swapaxes(0, 1), symmetric=True)[1], lambda a: numpy.linalg.eigh(a+a.swapaxes(0, 1))[1], ANY(4, 4), hasgrad=False)
Expand Down Expand Up @@ -1050,6 +1049,18 @@ def test_swap_take_inflate(self):
taken = evaluable.Take(inflated, indices=evaluable.constant([1]))
self.assertTrue(evaluable.iszero(taken))

def test_double_logical_not(self):
a = evaluable.Argument('test', shape=(), dtype=bool)
self.assertEqual(evaluable.LogicalNot(evaluable.LogicalNot(a)).simplified, a)

def test_logical_or_same_args(self):
a = evaluable.Argument('test', shape=(), dtype=bool)
self.assertEqual((a | a).simplified, a)

def test_logical_and_same_args(self):
a = evaluable.Argument('test', shape=(), dtype=bool)
self.assertEqual((a & a).simplified, a)


class memory(TestCase):

Expand Down Expand Up @@ -1353,7 +1364,7 @@ def test(self):
%0 = EVALARGS --> dict
%1 = nutils.evaluable.Constant<f:> --> ndarray<f:>
%2 = nutils.evaluable.Constant<f:2> --> ndarray<f:2>
%3 = nutils.evaluable.Sum<f:> arr=%2 --> float64
%3 = nutils.evaluable.Sum<f:> a=%2 --> float64
%4 = nutils.evaluable.Add<f:> %1 %3 --> float64
%5 = tests.test_evaluable.Fail<i:> arg1=%4 arg2=%1 --> operation failed intentially.''')

Expand Down
11 changes: 11 additions & 0 deletions tests/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,17 @@ def _check(name, op, n_op, *args):
_check('max', lambda a, b: numpy.maximum(a, function.Array.cast(b)), numpy.maximum, ANY(4, 1), ANY(1, 4))
_check('heaviside', function.heaviside, lambda u: numpy.heaviside(u, .5), ANY(4, 4))

_check('logical_and', lambda a, b: numpy.logical_and(a, function.Array.cast(b)), numpy.logical_and, numpy.array([False, True, True])[:,None], numpy.array([True, False])[None,:])
_check('bitwise_and-bool', lambda a, b: numpy.bitwise_and(a, function.Array.cast(b)), numpy.bitwise_and, numpy.array([False, True, True])[:,None], numpy.array([True, False])[None,:])
_check('logical_or', lambda a, b: numpy.logical_or(a, function.Array.cast(b)), numpy.logical_or, numpy.array([False, True, True])[:,None], numpy.array([True, False])[None,:])
_check('bitwise_or-bool', lambda a, b: numpy.bitwise_or(a, function.Array.cast(b)), numpy.bitwise_or, numpy.array([False, True, True])[:,None], numpy.array([True, False])[None,:])
_check('logical_not', lambda a: numpy.logical_not(function.Array.cast(a)), numpy.logical_not, numpy.array([[False, True], [True, False], [True, True]]))
_check('invert-bool', lambda a: numpy.invert(function.Array.cast(a)), numpy.invert, numpy.array([[False, True], [True, False], [True, True]]))
_check('all-bool-all-axes', lambda a: numpy.all(function.Array.cast(a)), numpy.all, numpy.array([[False, True], [True, True]]))
_check('all-bool-single-axis', lambda a: numpy.all(function.Array.cast(a), axis=0), lambda a: numpy.all(a, axis=0), numpy.array([[False, True], [True, True]]))
_check('any-bool-all-axes', lambda a: numpy.any(function.Array.cast(a)), numpy.any, numpy.array([[False, True], [True, True]]))
_check('any-bool-single-axis', lambda a: numpy.any(function.Array.cast(a), axis=0), lambda a: numpy.any(a, axis=0), numpy.array([[False, True], [False, False]]))

## TODO: opposite
## TODO: mean
## TODO: jump
Expand Down
Loading