Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor arithmetic power #886

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 60 additions & 21 deletions mathics/builtin/arithfns/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

"""


import sympy

from mathics.builtin.arithmetic import create_infix
from mathics.builtin.base import (
BinaryOperator,
Expand Down Expand Up @@ -45,7 +48,6 @@
Symbol,
SymbolDivide,
SymbolHoldForm,
SymbolNull,
SymbolPower,
SymbolTimes,
)
Expand All @@ -56,10 +58,17 @@
SymbolInfix,
SymbolLeft,
SymbolMinus,
SymbolOverflow,
SymbolPattern,
SymbolSequence,
)
from mathics.eval.arithmetic import eval_Plus, eval_Times
from mathics.eval.arithmetic import (
associate_powers,
eval_Exponential,
eval_Plus,
eval_Power_inexact,
eval_Power_number,
eval_Times,
)
from mathics.eval.nevaluator import eval_N
from mathics.eval.numerify import numerify

Expand Down Expand Up @@ -535,15 +544,15 @@ class Power(BinaryOperator, MPMathFunction):
# Remember to up sympy doc link when this is corrected
sympy_name = "Pow"

def eval_exp(self, x, evaluation):
"Power[E, x]"
return eval_Exponential(x)

def eval_check(self, x, y, evaluation):
"Power[x_, y_]"

# Power uses MPMathFunction but does some error checking first
if isinstance(x, Number) and x.is_zero:
if isinstance(y, Number):
y_err = y
else:
y_err = eval_N(y, evaluation)
# if x is zero
if x.is_zero:
y_err = y if isinstance(y, Number) else eval_N(y, evaluation)
if isinstance(y_err, Number):
py_y = y_err.round_to_float(permit_complex=True).real
if py_y > 0:
Expand All @@ -557,17 +566,47 @@ def eval_check(self, x, y, evaluation):
evaluation.message(
"Power", "infy", Expression(SymbolPower, x, y_err)
)
return SymbolComplexInfinity
if isinstance(x, Complex) and x.real.is_zero:
yhalf = Expression(SymbolTimes, y, RationalOneHalf)
factor = self.eval(Expression(SymbolSequence, x.imag, y), evaluation)
return Expression(
SymbolTimes, factor, Expression(SymbolPower, IntegerM1, yhalf)
)

result = self.eval(Expression(SymbolSequence, x, y), evaluation)
if result is None or result != SymbolNull:
return result
return SymbolComplexInfinity

# If x and y are inexact numbers, use the numerical function

if x.is_inexact() and y.is_inexact():
try:
return eval_Power_inexact(x, y)
except OverflowError:
evaluation.message("General", "ovfl")
return Expression(SymbolOverflow)

# Tries to associate powers a^b^c-> a^(b*c)
assoc = associate_powers(x, y)
if not assoc.has_form("Power", 2):
return assoc

assoc = numerify(assoc, evaluation)
x, y = assoc.elements
# If x and y are numbers
if isinstance(x, Number) and isinstance(y, Number):
try:
return eval_Power_number(x, y)
except OverflowError:
evaluation.message("General", "ovfl")
return Expression(SymbolOverflow)

# if x or y are inexact, leave the expression
# as it is:
if x.is_inexact() or y.is_inexact():
return assoc

# Finally, try to convert to sympy
base_sp, exp_sp = x.to_sympy(), y.to_sympy()
if base_sp is None or exp_sp is None:
# If base or exp can not be converted to sympy,
# returns the result of applying the associative
# rule.
return assoc

result = from_sympy(sympy.Pow(base_sp, exp_sp))
return result.evaluate_elements(evaluation)


class Sqrt(SympyFunction):
Expand Down
105 changes: 102 additions & 3 deletions mathics/eval/arithmetic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# -*- coding: utf-8 -*-

"""
arithmetic-related evaluation functions.
helper functions for arithmetic evaluation, which do not
depends on the evaluation context. Conversions to Sympy are
used just as a last resource.

Many of these do do depend on the evaluation context. Conversions to Sympy are
used just as a last resource.
Expand Down Expand Up @@ -320,6 +322,28 @@ def eval_complex_sign(n: BaseElement) -> Optional[BaseElement]:
return sign or eval_complex_sign(expr)


def eval_Sign_number(n: Number) -> Number:
"""
Evals the absolute value of a number.
"""
if n.is_zero:
return Integer0
if isinstance(n, (Integer, Rational, Real)):
return Integer1 if n.value > 0 else IntegerM1
if isinstance(n, Complex):
abs_sq = eval_add_numbers(
*(eval_multiply_numbers(x, x) for x in (n.real, n.imag))
)
criteria = eval_add_numbers(abs_sq, IntegerM1)
if test_zero_arithmetic_expr(criteria):
return n
if n.is_inexact():
return eval_multiply_numbers(n, eval_Power_number(abs_sq, RealM0p5))
if test_zero_arithmetic_expr(criteria, numeric=True):
return n
return eval_multiply_numbers(n, eval_Power_number(abs_sq, RationalMOneHalf))


def eval_mpmath_function(
mpmath_function: Callable, *args: Number, prec: Optional[int] = None
) -> Optional[Number]:
Expand Down Expand Up @@ -347,6 +371,31 @@ def eval_mpmath_function(
return call_mpmath(mpmath_function, tuple(mpmath_args), prec)


def eval_Exponential(exp: BaseElement) -> BaseElement:
"""
Eval E^exp
"""
# If both base and exponent are exact quantities,
# use sympy.

if not exp.is_inexact():
exp_sp = exp.to_sympy()
if exp_sp is None:
return None
return from_sympy(sympy.Exp(exp_sp))

prec = exp.get_precision()
if prec is not None:
if exp.is_machine_precision():
number = mpmath.exp(exp.to_mpmath())
result = from_mpmath(number)
return result
else:
with mpmath.workprec(prec):
number = mpmath.exp(exp.to_mpmath())
return from_mpmath(number, prec)


def eval_Plus(*items: BaseElement) -> BaseElement:
"evaluate Plus for general elements"
numbers, items_tuple = segregate_numbers_from_sorted_list(*items)
Expand Down Expand Up @@ -645,8 +694,58 @@ def eval_Times(*items: BaseElement) -> BaseElement:
)


def associate_powers(expr: BaseElement, power: BaseElement = Integer1) -> BaseElement:
Copy link
Member

@rocky rocky Jul 24, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we sure this kind of thing can't be done in Sympy. Have we tried this?

For another thing brought up and in a previous PR, with Sign, I see in that a problem there mentioned in StackOverflow was just about not tagging a variable as being in the right type (complex vs real).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The point is not if we can find a way to use Sympy, but if we can do it "natively", without the cost of back-and-forth conversions. Even if we can make these conversions more efficient, avoiding them at all would always be a win, mostly if the native implementation is not too complicated.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See comment above.

"""
base^a^b^c^...^power -> base^(a*b*c*...power)
provided one of the following cases
* `a`, `b`, ... `power` are all integer numbers
* `a`, `b`,... are Rational/Real number with absolute value <=1,
and the other powers are not integer numbers.
* `a` is not a Rational/Real number, and b, c, ... power are all
integer numbers.
"""
powers = []
base = expr
if power is not Integer1:
powers.append(power)

while base.has_form("Power", 2):
previous_base, outer_power = base, power
base, power = base.elements
if len(powers) == 0:
if power is not Integer1:
powers.append(power)
continue
if power is IntegerM1:
powers.append(power)
continue
if isinstance(power, (Rational, Real)):
if abs(power.value) < 1:
powers.append(power)
continue
# power is not rational/real and outer_power is integer,
elif isinstance(outer_power, Integer):
if power is not Integer1:
powers.append(power)
if isinstance(power, Integer):
continue
else:
break
# in any other case, use the previous base and
# exit the loop
base = previous_base
break

if len(powers) == 0:
return base
elif len(powers) == 1:
return Expression(SymbolPower, base, powers[0])
result = Expression(SymbolPower, base, Expression(SymbolTimes, *powers))
return result


def eval_add_numbers(
*numbers: Number,
*numbers: List[Number],
) -> BaseElement:
"""
Add the elements in ``numbers``.
Expand Down Expand Up @@ -693,7 +792,7 @@ def eval_inverse_number(n: Number) -> Number:
return eval_Power_number(n, IntegerM1)


def eval_multiply_numbers(*numbers: Number) -> Number:
def eval_multiply_numbers(*numbers: Number) -> BaseElement:
"""
Multiply the elements in ``numbers``.
"""
Expand Down
8 changes: 4 additions & 4 deletions test/builtin/arithmetic/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def test_directed_infinity_precedence(str_expr, str_expected, msg):
("I^(2/3)", "(-1) ^ (1 / 3)", None),
# In WMA, the next test would return ``-(-I)^(2/3)``
# which is less compact and elegant...
# ("(-I)^(2/3)", "(-1) ^ (-1 / 3)", None),
("(-I)^(2/3)", "(-1) ^ (-1 / 3)", None),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm. master returns -1 right now, so what we have is wrong.

Do you know what part of the code fixes this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this is related to the new rules you mention before...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nop. The rule at the end seems to be useless. The computation is done by mathics.eval.arithmetic.eval_Power_number, which now is called from Power.eval_check ( mathics.builtin.arithfns.basic, line 591).

("(2+3I)^3", "-46 + 9 I", None),
("(1.+3. I)^.6", "1.46069 + 1.35921 I", None),
("3^(1+2 I)", "3 ^ (1 + 2 I)", None),
Expand All @@ -208,15 +208,15 @@ def test_directed_infinity_precedence(str_expr, str_expected, msg):
# sympy, which produces the result
("(3/Pi)^(-I)", "(3 / Pi) ^ (-I)", None),
# Association rules
# ('(a^"w")^2', 'a^(2 "w")', "Integer power of a power with string exponent"),
('(a^"w")^2', 'a^(2 "w")', "Integer power of a power with string exponent"),
('(a^2)^"w"', '(a ^ 2) ^ "w"', None),
('(a^2)^"w"', '(a ^ 2) ^ "w"', None),
("(a^2)^(1/2)", "Sqrt[a ^ 2]", None),
("(a^(1/2))^2", "a", None),
("(a^(1/2))^2", "a", None),
("(a^(3/2))^3.", "(a ^ (3 / 2)) ^ 3.", None),
# ("(a^(1/2))^3.", "a ^ 1.5", "Power associativity rational, real"),
# ("(a^(.3))^3.", "a ^ 0.9", "Power associativity for real powers"),
("(a^(1/2))^3.", "a ^ 1.5", "Power associativity rational, real"),
("(a^(.3))^3.", "a ^ 0.9", "Power associativity for real powers"),
("(a^(1.3))^3.", "(a ^ 1.3) ^ 3.", None),
# Exponentials involving expressions
("(a^(p-2 q))^3", "a ^ (3 p - 6 q)", None),
Expand Down
43 changes: 31 additions & 12 deletions test/format/test_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,34 +456,53 @@
"Sqrt[1/(1+1/(1+1/a))]": {
"msg": "SqrtBox",
"text": {
"System`StandardForm": "Sqrt[1 / (1+1 / (1+1 / a))]",
"System`TraditionalForm": "Sqrt[1 / (1+1 / (1+1 / a))]",
"System`InputForm": "Sqrt[1 / (1 + 1 / (1 + 1 / a))]",
"System`OutputForm": "Sqrt[1 / (1 + 1 / (1 + 1 / a))]",
"System`StandardForm": "1 / Sqrt[1+1 / (1+1 / a)]",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why the change in test?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the effect of implementing the associativity of powers. In master,

Sqrt[1 / (1 + 1 / (1 + 1 / a))]//FullForm

produces

Power[Power[Plus[1, Power[Plus[1, Power[a, -1]], -1]], -1], Rational[1, 2]]

while here,

Power[Plus[1, Power[Plus[1, Power[a, -1]], -1]], Rational[-1, 2]]

On the other hand, looking at WMA, it seems that it does not implement the associativity in this case either.

"System`TraditionalForm": "1 / Sqrt[1+1 / (1+1 / a)]",
"System`InputForm": "1 / Sqrt[1 + 1 / (1 + 1 / a)]",
"System`OutputForm": "1 / Sqrt[1 + 1 / (1 + 1 / a)]",
},
"mathml": {
"System`StandardForm": (
"<msqrt> <mfrac><mn>1</mn> <mrow><mn>1</mn> <mo>+</mo> <mfrac><mn>1</mn> <mrow><mn>1</mn> <mo>+</mo> <mfrac><mn>1</mn> <mi>a</mi></mfrac></mrow></mfrac></mrow></mfrac> </msqrt>",
(
r"<mfrac><mn>1</mn> <msqrt> <mrow><mn>1</mn> <mo>+</mo> <mfrac><mn>1</mn> "
r"<mrow><mn>1</mn> <mo>+</mo> <mfrac><mn>1</mn> <mi>a</mi></mfrac></mrow></mfrac></mrow> "
r"</msqrt></mfrac>"
),
"Fragile!",
),
"System`TraditionalForm": (
"<msqrt> <mfrac><mn>1</mn> <mrow><mn>1</mn> <mo>+</mo> <mfrac><mn>1</mn> <mrow><mn>1</mn> <mo>+</mo> <mfrac><mn>1</mn> <mi>a</mi></mfrac></mrow></mfrac></mrow></mfrac> </msqrt>",
(
r"<mfrac><mn>1</mn> <msqrt> <mrow><mn>1</mn> <mo>+</mo> <mfrac><mn>1</mn> "
r"<mrow><mn>1</mn> <mo>+</mo> <mfrac><mn>1</mn> <mi>a</mi></mfrac></mrow></mfrac></mrow> "
r"</msqrt></mfrac>"
),
"Fragile!",
),
"System`InputForm": (
"<mrow><mi>Sqrt</mi> <mo>[</mo> <mrow><mtext>1</mtext> <mtext>&nbsp;/&nbsp;</mtext> <mrow><mo>(</mo> <mrow><mtext>1</mtext> <mtext>&nbsp;+&nbsp;</mtext> <mrow><mtext>1</mtext> <mtext>&nbsp;/&nbsp;</mtext> <mrow><mo>(</mo> <mrow><mtext>1</mtext> <mtext>&nbsp;+&nbsp;</mtext> <mrow><mtext>1</mtext> <mtext>&nbsp;/&nbsp;</mtext> <mi>a</mi></mrow></mrow> <mo>)</mo></mrow></mrow></mrow> <mo>)</mo></mrow></mrow> <mo>]</mo></mrow>",
(
r"<mrow><mtext>1</mtext> <mtext>&nbsp;/&nbsp;</mtext> <mrow><mi>Sqrt</mi> <mo>[</mo> "
r"<mrow><mtext>1</mtext> <mtext>&nbsp;+&nbsp;</mtext> <mrow><mtext>1</mtext> <mtext>&nbsp;/&nbsp;</mtext> "
r"<mrow><mo>(</mo> <mrow><mtext>1</mtext> <mtext>&nbsp;+&nbsp;</mtext> <mrow><mtext>1</mtext> <mtext>"
r"&nbsp;/&nbsp;</mtext> <mi>a</mi></mrow></mrow> <mo>)</mo></mrow></mrow></mrow> <mo>]</mo></mrow></mrow>"
),
"Fragile!",
),
"System`OutputForm": (
"<mrow><mi>Sqrt</mi> <mo>[</mo> <mrow><mn>1</mn> <mtext>&nbsp;/&nbsp;</mtext> <mrow><mo>(</mo> <mrow><mn>1</mn> <mtext>&nbsp;+&nbsp;</mtext> <mrow><mn>1</mn> <mtext>&nbsp;/&nbsp;</mtext> <mrow><mo>(</mo> <mrow><mn>1</mn> <mtext>&nbsp;+&nbsp;</mtext> <mrow><mn>1</mn> <mtext>&nbsp;/&nbsp;</mtext> <mi>a</mi></mrow></mrow> <mo>)</mo></mrow></mrow></mrow> <mo>)</mo></mrow></mrow> <mo>]</mo></mrow>",
(
r"<mrow><mn>1</mn> <mtext>&nbsp;/&nbsp;</mtext> <mrow><mi>Sqrt</mi> <mo>["
r"</mo> <mrow><mn>1</mn> <mtext>&nbsp;+&nbsp;</mtext> <mrow><mn>1</mn> "
r"<mtext>&nbsp;/&nbsp;</mtext> <mrow><mo>(</mo> <mrow><mn>1</mn> <mtext>"
r"&nbsp;+&nbsp;</mtext> <mrow><mn>1</mn> <mtext>&nbsp;/&nbsp;</mtext> "
r"<mi>a</mi></mrow></mrow> <mo>)</mo></mrow></mrow></mrow> <mo>]</mo></mrow></mrow>"
),
"Fragile!",
),
},
"latex": {
"System`StandardForm": "\\sqrt{\\frac{1}{1+\\frac{1}{1+\\frac{1}{a}}}}",
"System`TraditionalForm": "\\sqrt{\\frac{1}{1+\\frac{1}{1+\\frac{1}{a}}}}",
"System`InputForm": "\\text{Sqrt}\\left[1\\text{ / }\\left(1\\text{ + }1\\text{ / }\\left(1\\text{ + }1\\text{ / }a\\right)\\right)\\right]",
"System`OutputForm": "\\text{Sqrt}\\left[1\\text{ / }\\left(1\\text{ + }1\\text{ / }\\left(1\\text{ + }1\\text{ / }a\\right)\\right)\\right]",
"System`StandardForm": "\\frac{1}{\\sqrt{1+\\frac{1}{1+\\frac{1}{a}}}}",
"System`TraditionalForm": "\\frac{1}{\\sqrt{1+\\frac{1}{1+\\frac{1}{a}}}}",
"System`InputForm": r"1\text{ / }\text{Sqrt}\left[1\text{ + }1\text{ / }\left(1\text{ + }1\text{ / }a\right)\right]",
"System`OutputForm": r"1\text{ / }\text{Sqrt}\left[1\text{ + }1\text{ / }\left(1\text{ + }1\text{ / }a\right)\right]",
},
},
# Grids, arrays and matrices
Expand Down