Skip to content

Commit

Permalink
reconstruct Solve logic
Browse files Browse the repository at this point in the history
alternate old plain solve logic with a recursive one. Now Solve is capable with nested logics
  • Loading branch information
BlankShrimp committed Sep 13, 2023
1 parent 00d39ee commit 6d3d7c7
Showing 1 changed file with 156 additions and 145 deletions.
301 changes: 156 additions & 145 deletions mathics/builtin/numbers/calculus.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,9 +330,9 @@ def summand(element, index):
Expression(
SymbolDerivative,
*(
[Integer0] * (index)
+ [Integer1]
+ [Integer0] * (len(f.elements) - index - 1)
[Integer0] * (index) +
[Integer1] +
[Integer0] * (len(f.elements) - index - 1)
),
),
f.head,
Expand Down Expand Up @@ -664,8 +664,8 @@ def eval(self, f, x, x0, evaluation: Evaluation, options: dict):

# Determine the "jacobian"s
if (
method in ("Newton", "Automatic")
and options["System`Jacobian"] is SymbolAutomatic
method in ("Newton", "Automatic") and
options["System`Jacobian"] is SymbolAutomatic
):

def diff(evaluation):
Expand Down Expand Up @@ -1321,16 +1321,16 @@ class NIntegrate(Builtin):
messages = {
"bdmtd": "The Method option should be a built-in method name.",
"inumr": (
"The integrand `1` has evaluated to non-numerical "
+ "values for all sampling points in the region "
+ "with boundaries `2`"
"The integrand `1` has evaluated to non-numerical " +
"values for all sampling points in the region " +
"with boundaries `2`"
),
"nlim": "`1` = `2` is not a valid limit of integration.",
"ilim": "Invalid integration variable or limit(s) in `1`.",
"mtdfail": (
"The specified method failed to return a "
+ "number. Falling back into the internal "
+ "evaluator."
"The specified method failed to return a " +
"number. Falling back into the internal " +
"evaluator."
),
"cmpint": ("Integration over a complex domain is not " + "implemented yet"),
}
Expand Down Expand Up @@ -1373,10 +1373,10 @@ class NIntegrate(Builtin):

messages.update(
{
"bdmtd": "The Method option should be a "
+ "built-in method name in {`"
+ "`, `".join(list(methods))
+ "`}. Using `Automatic`"
"bdmtd": "The Method option should be a " +
"built-in method name in {`" +
"`, `".join(list(methods)) +
"`}. Using `Automatic`"
}
)

Expand All @@ -1396,7 +1396,7 @@ def eval_with_func_domain(
elif isinstance(method, Symbol):
method = method.get_name()
# strip context
method = method[method.rindex("`") + 1 :]
method = method[method.rindex("`") + 1:]
else:
evaluation.message("NIntegrate", "bdmtd", method)
return
Expand Down Expand Up @@ -2235,146 +2235,157 @@ def eval(self, eqs, vars, evaluation: Evaluation):
vars = [vars]
for var in vars:
if (
(isinstance(var, Atom) and not isinstance(var, Symbol))
or head_name in ("System`Plus", "System`Times", "System`Power") # noqa
or A_CONSTANT & var.get_attributes(evaluation.definitions)
(isinstance(var, Atom) and not isinstance(var, Symbol)) or
head_name in ("System`Plus", "System`Times", "System`Power") or # noqa
A_CONSTANT & var.get_attributes(evaluation.definitions)
):

evaluation.message("Solve", "ivar", vars_original)
return
if eqs.get_head_name() in ("System`List", "System`And"):
eq_list = eqs.elements
else:
eq_list = [eqs]
sympy_conditions = []
sympy_eqs = []
sympy_denoms = []
for eq in eq_list:
if eq is SymbolTrue:
pass
elif eq is SymbolFalse:
return ListExpression()
elif not eq.has_form("Equal", 2):
sympy_conditions.append(eq.to_sympy())
else:
left, right = eq.elements
left = left.to_sympy()
right = right.to_sympy()
if left is None or right is None:
return
eq = left - right
eq = sympy.together(eq)
eq = sympy.cancel(eq)
sympy_eqs.append(eq)
numer, denom = eq.as_numer_denom()
sympy_denoms.append(denom)

if not sympy_eqs:
evaluation.message("Solve", "eqf", eqs)
return

vars_sympy = [var.to_sympy() for var in vars]
if None in vars_sympy:
evaluation.message("Solve", "ivar")
return

# delete unused variables to avoid SymPy's
# PolynomialError: Not a zero-dimensional system
# in e.g. Solve[x^2==1&&z^2==-1,{x,y,z}]
all_vars = vars[:]
all_vars_sympy = vars_sympy[:]
vars = []
vars_sympy = []
for var, var_sympy in zip(all_vars, all_vars_sympy):
pattern = Pattern.create(var)
if not eqs.is_free(pattern, evaluation):
vars.append(var)
vars_sympy.append(var_sympy)

def transform_dict(sols):
if not sols:
yield sols
for var, sol in sols.items():
rest = sols.copy()
del rest[var]
rest = transform_dict(rest)
if not isinstance(sol, (tuple, list)):
sol = [sol]
if not sol:
for r in rest:
yield r
else:
for r in rest:
for item in sol:
new_sols = r.copy()
new_sols[var] = item
yield new_sols
break

def transform_solution(sol):
if not isinstance(sol, dict):
if not isinstance(sol, (list, tuple)):
sol = [sol]
sol = dict(list(zip(vars_sympy, sol)))
return transform_dict(sol)

if not sympy_eqs:
sympy_eqs = True
elif len(sympy_eqs) == 1:
sympy_eqs = sympy_eqs[0]

try:
if isinstance(sympy_eqs, bool):
result = sympy_eqs
all_var_tuples = list(zip(vars, vars_sympy))

def cut_var_dimension(expressions: Expression | list[Expression]):
'''delete unused variables to avoid SymPy's PolynomialError
: Not a zero-dimensional system in e.g. Solve[x^2==1&&z^2==-1,{x,y,z}]'''
if not isinstance(expressions, list):
expressions = [expressions]
subset_vars = set()
subset_vars_sympy = set()
for var, var_sympy in all_var_tuples:
pattern = Pattern.create(var)
for equation in expressions:
if not equation.is_free(pattern, evaluation):
subset_vars.add(var)
subset_vars_sympy.add(var_sympy)
return subset_vars, subset_vars_sympy

def solve_sympy(equations: Expression | list[Expression]):
if not isinstance(equations, list):
equations = [equations]
equations_sympy = []
denoms_sympy = []
subset_vars, subset_vars_sympy = cut_var_dimension(equations)
for equation in equations:
if equation is SymbolTrue:
continue
elif equation is SymbolFalse:
return []
elements = equation.elements
for left, right in [(elements[index], elements[index + 1]) for index in range(len(elements) - 1)]:
# ↑ to deal with things like a==b==c==d
left = left.to_sympy()
right = right.to_sympy()
if left is None or right is None:
return []
equation_sympy = left - right
equation_sympy = sympy.together(equation_sympy)
equation_sympy = sympy.cancel(equation_sympy)
numer, denom = equation_sympy.as_numer_denom()
denoms_sympy.append(denom)
try:
results = sympy.solve(equations_sympy, subset_vars_sympy, dict=True) # no transform needed with dict=True
# Filter out results for which denominator is 0
# (SymPy should actually do that itself, but it doesn't!)
results = [
sol
for sol in results
if all(sympy.simplify(denom.subs(sol)) != 0 for denom in denoms_sympy)
]
return results
except sympy.PolynomialError:
# raised for e.g. Solve[x^2==1&&z^2==-1,{x,y,z}] when not deleting
# unused variables beforehand
return []
except NotImplementedError:
return []
except TypeError as exc:
if str(exc).startswith("expected Symbol, Function or Derivative"):
evaluation.message("Solve", "ivar", vars_original)

def solve_recur(expression: Expression):
'''solve And, Or and List within the scope of sympy,
but including the translation from Mathics to sympy
returns:
solutions: a list of sympy solution dictionaries
conditions: a sympy condition object
note:
for And and List, should always return either (solutions, None) or ([], conditions)
for Or, all combinations are possible. if Or is root, should be handled outside'''
head = expression.get_head_name()
if head in ("System`And", "System`List"):
solutions = []
equations: list[Expression] = []
inequations = []
for child in expression.elements:
if child.has_form("Equal", 2):
equations.append(child)
elif child.get_head_name() in ("System`And", "System`Or"):
sub_solution, sub_condition = solve_recur(child)
solutions.extend(sub_solution)
if sub_condition is not None:
inequations.append(sub_condition)
else:
inequations.append(child.to_sympy())
solutions.extend(solve_sympy(equations))
conditions = sympy.And(*inequations)
result = [sol for sol in solutions if conditions.subs(sol)]
return result, None if solutions else conditions
else: # should be System`Or then
assert head == "System`Or"
solutions = []
conditions = []
for child in expression.elements:
if child.has_form("Equal", 2):
solutions.extend(solve_sympy(child))
elif child.get_head_name() in ("System`And", "System`Or"): # List wouldn't be in here
sub_solution, sub_condition = solve_recur(child)
solutions.extend(sub_solution)
if sub_condition is not None:
conditions.append(sub_condition)
else:
# SymbolTrue and SymbolFalse are allowed here since it's subtree context
# FIXME: None is not allowed, not sure what to do here
conditions.append(child.to_sympy())
conditions = sympy.Or(*conditions)
return solutions, conditions

if eqs.get_head_name() in ("System`List", "System`And", "System`Or"):
solutions, conditions = solve_recur(eqs)
# non True conditions are only accepted in subtrees, not root
if conditions is not None:
evaluation.message("Solve", "fulldim")
return ListExpression(ListExpression())
else:
if eqs.has_form("Equal", 2):
solutions = solve_sympy(eqs)
else:
result = sympy.solve(sympy_eqs, vars_sympy)
if not isinstance(result, list):
result = [result]
if isinstance(result, list) and len(result) == 1 and result[0] is True:
evaluation.message("Solve", "fulldim")
return ListExpression(ListExpression())
if result == [None]:
return ListExpression()
results = []
for sol in result:
results.extend(transform_solution(sol))
result = results
# filter with conditions before further translation
conditions = sympy.And(*sympy_conditions)
result = [sol for sol in result if conditions.subs(sol)]

if any(
sol and any(var not in sol for var in all_vars_sympy) for sol in result
):
evaluation.message("Solve", "svars")

# Filter out results for which denominator is 0
# (SymPy should actually do that itself, but it doesn't!)
result = [
sol
for sol in result
if all(sympy.simplify(denom.subs(sol)) != 0 for denom in sympy_denoms)
]

return ListExpression(
*(
ListExpression(
*(
Expression(SymbolRule, var, from_sympy(sol[var_sympy]))
for var, var_sympy in zip(vars, vars_sympy)
if var_sympy in sol
),
)
for sol in result
),
)
except sympy.PolynomialError:
# raised for e.g. Solve[x^2==1&&z^2==-1,{x,y,z}] when not deleting
# unused variables beforehand
pass
except NotImplementedError:
pass
except TypeError as exc:
if str(exc).startswith("expected Symbol, Function or Derivative"):
evaluation.message("Solve", "ivar", vars_original)
if any(
sol and any(var not in sol for var in vars_sympy) for sol in solutions
):
evaluation.message("Solve", "svars")

return ListExpression(
*(
ListExpression(
*(
Expression(SymbolRule, var, from_sympy(sol[var_sympy]))
for var, var_sympy in zip(vars, all_var_tuples)
if var_sympy in sol
),
)
for sol in solutions
),
)


# Auxiliary routines. Maybe should be moved to another module.
Expand Down

0 comments on commit 6d3d7c7

Please sign in to comment.