Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Evaluate Relationals when casting to bool #354

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
26 changes: 26 additions & 0 deletions symengine/lib/symengine_wrapper.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1502,6 +1502,32 @@ class Relational(Boolean):
def is_Relational(self):
return True

def __bool__(self):
# We will narrow down the boolean value of our relational with some simple checks
# Get the Left- and Right-hand-sides of the relation, since two expressions are equal if their difference
# is equal to 0.
# If the expand method will not cancel out free symbols in the given expression, then this
# will throw a TypeError.
lhs, rhs = self.args
difference = (lhs - rhs).expand()

if len(difference.free_symbols):
# If there are any free symbols, then boolean evaluation is ambiguous in most cases. Throw a Type Error
raise TypeError(f'Relational with free symbols cannot be cast as bool: {self}')
else:
# Instantiating relationals that are obviously True or False (according to symengine) will automatically
# simplify to BooleanTrue or BooleanFalse
relational_type = type(self)
simplified = relational_type(difference, S.Zero)
if isinstance(simplified, BooleanAtom):
return bool(simplified)
# If we still cannot determine whether or not the relational is true, then we can either outsource the
# evaluation to sympy (if available) or raise a ValueError expressing that the evaluation is unclear.
try:
return bool(self.simplify())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will return True for 2*(x + 1) - 2 which is different from what we get for 2*x which is a TypeError.

except ImportError:
raise ValueError(f'Boolean evaluation is unclear for relational: {self}')

Rel = Relational


Expand Down
1 change: 1 addition & 0 deletions symengine/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ install(FILES __init__.py
test_matrices.py
test_ntheory.py
test_printing.py
test_relationals.py
test_sage.py
test_series_expansion.py
test_sets.py
Expand Down
2 changes: 1 addition & 1 deletion symengine/tests/test_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_eval_double2():
x = Symbol("x")
e = sin(x)**2 + sqrt(2)
raises(RuntimeError, lambda: e.n(real=True))
assert abs(e.n() - x**2 - 1.414) < 1e-3
assert abs(e.n() - sin(x)**2.0 - 1.414) < 1e-3

def test_n():
x = Symbol("x")
Expand Down
138 changes: 138 additions & 0 deletions symengine/tests/test_relationals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
from symengine.utilities import raises
from symengine import (Symbol, sympify, Eq, Ne, Lt, Le, Ge, Gt, sqrt, pi)

from unittest.case import SkipTest

try:
import sympy
HAVE_SYMPY = True
except ImportError:
HAVE_SYMPY = False


def assert_equal(x, y):
"""Asserts that x and y are equal. This will test Equality, Unequality, LE, and GE classes."""
assert bool(Eq(x, y))
assert not bool(Ne(x, y))
assert bool(Ge(x, y))
assert bool(Le(x, y))


def assert_not_equal(x, y):
"""Asserts that x and y are not equal. This will test Equality and Unequality"""
assert not bool(Eq(x, y))
assert bool(Ne(x, y))


def assert_less_than(x, y):
"""Asserts that x is less than y. This will test Le, Lt, Ge, Gt classes."""
assert bool(Le(x, y))
assert bool(Lt(x, y))
assert not bool(Ge(x, y))
assert not bool(Gt(x, y))


def assert_greater_than(x, y):
"""Asserts that x is greater than y. This will test Le, Lt, Ge, Gt classes."""
assert not bool(Le(x, y))
assert not bool(Lt(x, y))
assert bool(Ge(x, y))
assert bool(Gt(x, y))


def test_equals_constants_easy():
assert_equal(3, 3)
assert_equal(4, 2 ** 2)


def test_equals_constants_hard():
# Short and long are symbolically equivalent, but sufficiently different in form that expand() does not
# catch it. Ideally, our equality should still catch these, but until symengine supports as robust simplification as
# sympy, we can forgive failing, as long as it raises a ValueError
short = sympify('(3/2)*sqrt(11 + sqrt(21))')
long = sympify('sqrt((33/8 + (1/24)*sqrt(27)*sqrt(63))**2 + ((3/8)*sqrt(27) + (-1/8)*sqrt(63))**2)')
assert_equal(short, short)
assert_equal(long, long)
if HAVE_SYMPY:
assert_equal(short, long)
else:
raises(ValueError, lambda: bool(Eq(short, long)))


def test_not_equals_constants():
assert_not_equal(3, 4)
assert_not_equal(4, 4 - .000000001)


def test_equals_symbols():
x = Symbol("x")
y = Symbol("y")
assert_equal(x, x)
assert_equal(x ** 2, x * x)
assert_equal(x * y, y * x)


def test_not_equals_symbols():
x = Symbol("x")
y = Symbol("y")
assert_not_equal(x, x + 1)
assert_not_equal(x ** 2, x ** 2 + 1)
assert_not_equal(x * y, y * x + 1)


def test_not_equals_symbols_raise_typeerror():
x = Symbol("x")
y = Symbol("y")
raises(TypeError, lambda: bool(Eq(x, 1)))
raises(TypeError, lambda: bool(Eq(x, y)))
raises(TypeError, lambda: bool(Eq(x ** 2, x)))


def test_less_than_constants_easy():
assert_less_than(1, 2)
assert_less_than(-1, 1)


def test_less_than_constants_hard():
# Each of the below pairs are distinct numbers, with the one on the left less than the one on the right.
# Ideally, Less-than will catch this when evaluated, but until symengine has a more robust simplification,
# we can forgive a failure to evaluate as long as it raises a ValueError.
if HAVE_SYMPY:
assert_less_than(sqrt(2), 2)
assert_less_than(3.14, pi)
else:
raises(ValueError, lambda: bool(Lt(sqrt(2), 2)))
raises(ValueError, lambda: bool(Lt(3.14, pi)))


def test_greater_than_constants():
assert_greater_than(2, 1)
assert_greater_than(1, -1)


def test_greater_than_constants_hard():
# Each of the below pairs are distinct numbers, with the one on the left less than the one on the right.
# Ideally, Greater-than will catch this when evaluated, but until symengine has a more robust simplification,
# we can forgive a failure to evaluate as long as it raises a ValueError.
if HAVE_SYMPY:
assert_greater_than(2, sqrt(2))
assert_greater_than(pi, 3.14)
else:
raises(ValueError, lambda: bool(Gt(2, sqrt(2))))
raises(ValueError, lambda: bool(Gt(pi, 3.14)))


def test_less_than_raises_typeerror():
x = Symbol("x")
y = Symbol("y")
raises(TypeError, lambda: bool(Lt(x, 1)))
raises(TypeError, lambda: bool(Lt(x, y)))
raises(TypeError, lambda: bool(Lt(x ** 2, x)))


def test_greater_than_raises_typeerror():
x = Symbol("x")
y = Symbol("y")
raises(TypeError, lambda: bool(Gt(x, 1)))
raises(TypeError, lambda: bool(Gt(x, y)))
raises(TypeError, lambda: bool(Gt(x ** 2, x)))