Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert RootSum to SymPy #1136

Merged
merged 5 commits into from
Oct 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 26 additions & 7 deletions mathics/builtin/functional/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,19 @@

from itertools import chain

import sympy

from mathics.core.atoms import Integer, Integer1
from mathics.core.attributes import A_HOLD_ALL, A_N_HOLD_ALL, A_PROTECTED
from mathics.core.builtin import Builtin, PostfixOperator
from mathics.core.builtin import Builtin, PostfixOperator, SympyFunction
from mathics.core.convert.sympy import SymbolFunction
from mathics.core.evaluation import Evaluation
from mathics.core.expression import Expression
from mathics.core.symbols import Symbol
from mathics.core.symbols import Symbol, sympy_slot_prefix
from mathics.core.systemsymbols import SymbolSlot


class Function(PostfixOperator):
class Function(PostfixOperator, SympyFunction):
"""
<dl>
<dt>'Function[$body$]'
Expand Down Expand Up @@ -119,9 +123,11 @@ def eval_named(self, vars, body, args, evaluation: Evaluation):
# this is not included in WL, and here does not have any impact, but it is needed for
# translating the function to a compiled version.
var_names = (
var.get_name()
if isinstance(var, Symbol)
else var.elements[0].get_name()
(
var.get_name()
if isinstance(var, Symbol)
else var.elements[0].get_name()
)
for var in vars
)
vars = dict(list(zip(var_names, args[: len(vars)])))
Expand All @@ -148,8 +154,17 @@ def eval_named_attr(self, vars, body, attr, args, evaluation: Evaluation):
except Exception:
return

def to_sympy(self, expr: Expression, **kwargs):
if len(expr.elements) == 1:
body = expr.elements[0]
slot = Expression(SymbolSlot, Integer1)
return sympy.Lambda(slot.to_sympy(), body.to_sympy())
else:
# TODO: Handle multiple and/or named arguments
raise NotImplementedError


class Slot(Builtin):
class Slot(SympyFunction):
"""
<dl>
<dt>'#$n$'
Expand Down Expand Up @@ -184,6 +199,10 @@ class Slot(Builtin):
}
summary_text = "one argument of a pure function"

def to_sympy(self, expr: Expression, **kwargs):
index: Integer = expr.elements[0]
return sympy.Symbol(f"{sympy_slot_prefix}{index.get_int_value()}")


class SlotSequence(Builtin):
"""
Expand Down
4 changes: 3 additions & 1 deletion mathics/builtin/list/constructing.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,10 +186,12 @@ class Normal(Builtin):

summary_text = "convert objects to normal expressions"

def eval_general(self, expr, evaluation: Evaluation):
def eval_general(self, expr: Expression, evaluation: Evaluation):
"Normal[expr_]"
if isinstance(expr, Atom):
return
if expr.has_form("RootSum", 2):
return from_sympy(expr.to_sympy().doit(roots=True))
return Expression(
expr.get_head(),
*[Expression(SymbolNormal, element) for element in expr.elements],
Expand Down
57 changes: 55 additions & 2 deletions mathics/builtin/numbers/calculus.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,12 @@
from mathics.core.convert.expression import to_expression, to_mathics_list
from mathics.core.convert.function import expression_to_callable_and_args
from mathics.core.convert.python import from_python
from mathics.core.convert.sympy import SympyExpression, from_sympy, sympy_symbol_prefix
from mathics.core.convert.sympy import (
SymbolRootSum,
SympyExpression,
from_sympy,
sympy_symbol_prefix,
)
from mathics.core.evaluation import Evaluation
from mathics.core.expression import Expression
from mathics.core.list import ListExpression
Expand All @@ -63,6 +68,7 @@
SymbolConditionalExpression,
SymbolD,
SymbolDerivative,
SymbolFunction,
SymbolIndeterminate,
SymbolInfinity,
SymbolInfix,
Expand All @@ -76,6 +82,7 @@
SymbolSeries,
SymbolSeriesData,
SymbolSimplify,
SymbolSlot,
SymbolUndefined,
)
from mathics.eval.makeboxes import format_element
Expand Down Expand Up @@ -1607,7 +1614,7 @@ class Root(SympyFunction):

Roots that can't be represented by radicals:
>> Root[#1 ^ 5 + 2 #1 + 1&, 2]
= Root[#1 ^ 5 + 2 #1 + 1&, 2]
= Root[1 + #1 ^ 5 + 2 #1&, 2]
"""

messages = {
Expand Down Expand Up @@ -1671,6 +1678,52 @@ def to_sympy(self, expr, **kwargs):
return None


class RootSum(SympyFunction):
"""
<url>:WMA link: https://reference.wolfram.com/language/ref/RootSum.html</url>

<dl>
<dt>'RootSum[$f$, $form$]'
<dd>sums $form[x]$ for all roots of the polynomial $f[x]$.
</dl>

>> Integrate[1/(x^5 + 11 x + 1), {x, 1, 3}]
= RootSum[-1 - 212960 #1 ^ 3 - 9680 #1 ^ 2 - 165 #1 + 41232181 #1 ^ 5&, (Log[3749971 - 3512322106304 #1 ^ 4 + 453522741 #1 + 16326568676 #1 ^ 2 + 79825502416 #1 ^ 3] - 4 Log[5]) #1&] - RootSum[-1 - 212960 #1 ^ 3 - 9680 #1 ^ 2 - 165 #1 + 41232181 #1 ^ 5&, (Log[3748721 - 3512322106304 #1 ^ 4 + 453522741 #1 + 16326568676 #1 ^ 2 + 79825502416 #1 ^ 3] - 4 Log[5]) #1&]
>> N[%, 50]
= 0.051278805184286949884270940103072421286139857550894

>> RootSum[#^5 - 11 # + 1 &, (#^2 - 1)/(#^3 - 2 # + c) &]
= (538 - 88 c + 396 c ^ 2 + 5 c ^ 3 - 5 c ^ 4) / (97 - 529 c - 53 c ^ 2 + 88 c ^ 3 + c ^ 5)

>> RootSum[#^5 - 3 # - 7 &, Sin] //N//Chop
= 0.292188

Use Normal to expand RootSum:
>> RootSum[1+#+#^2+#^3+#^4 &, Log[x + #] &]
= RootSum[1 + #1 ^ 2 + #1 ^ 3 + #1 ^ 4 + #1&, Log[x + #1]&]
>> %//Normal
= Log[-1 / 4 - Sqrt[5] / 4 - I Sqrt[5 / 8 - Sqrt[5] / 8] + x] + Log[-1 / 4 - Sqrt[5] / 4 + I Sqrt[5 / 8 - Sqrt[5] / 8] + x] + Log[-1 / 4 - I Sqrt[5 / 8 + Sqrt[5] / 8] + Sqrt[5] / 4 + x] + Log[-1 / 4 + I Sqrt[5 / 8 + Sqrt[5] / 8] + Sqrt[5] / 4 + x]
"""

summary_text = "sum polynomial roots"

def eval(self, f, form, evaluation: Evaluation): # type: ignore[override]
"RootSum[f_, form_]"
return from_sympy(Expression(SymbolRootSum, f, form).to_sympy())

def to_sympy(self, expr: Expression, **kwargs):
func = expr.elements[1]
if not isinstance(func.to_sympy(), sympy.Lambda):
# eta conversion
func = Expression(
SymbolFunction, Expression(func, Expression(SymbolSlot, Integer1))
)

poly = expr.elements[0].to_sympy()
poly_x = sympy.Symbol("poly_x")
return sympy.RootSum(poly(poly_x), func.to_sympy(), x=poly_x)


class Series(Builtin):
"""
<url>:WMA link:https://reference.wolfram.com/language/ref/Series.html</url>
Expand Down
3 changes: 2 additions & 1 deletion mathics/core/convert/sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def __new__(cls, *exprs):
if all(isinstance(expr, BasicSympy) for expr in exprs):
# called with SymPy arguments
obj = super().__new__(cls, *exprs)
obj.expr = None
elif len(exprs) == 1 and isinstance(exprs[0], Expression):
# called with Mathics argument
expr = exprs[0]
Expand Down Expand Up @@ -460,7 +461,7 @@ def old_from_sympy(expr) -> BaseElement:
result.append(Expression(SymbolTimes, *factors))
else:
result.append(Integer1)
return Expression(SymbolFunction, Expression(SymbolPlus, *result))
return Expression(SymbolFunction, Expression(SymbolPlus, *sorted(result)))
if isinstance(expr, sympy.CRootOf):
try:
e_root, indx = expr.args
Expand Down
2 changes: 1 addition & 1 deletion mathics/eval/nevaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def eval_NValues(

# Special case for the Root builtin
# This should be implemented as an NValue
if expr.has_form("Root", 2):
if expr.has_form("Root", 2) or expr.has_form("RootSum", 2):
return from_sympy(sympy.N(expr.to_sympy(), d))

# Here we look for the NValues associated to the
Expand Down
3 changes: 2 additions & 1 deletion mathics/eval/numbers/algebra/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ def _default_complexity_function(x):

# At this point, ``complexity_function`` is a function that takes a
# sympy expression and returns an integer.
sympy_result = simplify(sympy_expr, measure=complexity_function)
sympy_result = simplify(sympy_expr, measure=complexity_function, doit=False)
sympy_result = sympy_result.doit(roots=False) # Don't expand RootSum

# and bring it back
result = from_sympy(sympy_result).evaluate(evaluation)
Expand Down