From 6d3d7c7c1b125cecb60f92104a64c9c16e0b5519 Mon Sep 17 00:00:00 2001 From: BlankShrimp Date: Wed, 13 Sep 2023 21:04:52 +0800 Subject: [PATCH] reconstruct Solve logic alternate old plain solve logic with a recursive one. Now Solve is capable with nested logics --- mathics/builtin/numbers/calculus.py | 301 ++++++++++++++-------------- 1 file changed, 156 insertions(+), 145 deletions(-) diff --git a/mathics/builtin/numbers/calculus.py b/mathics/builtin/numbers/calculus.py index a7b6b5958..286a6ff42 100644 --- a/mathics/builtin/numbers/calculus.py +++ b/mathics/builtin/numbers/calculus.py @@ -330,9 +330,9 @@ def summand(element, index): Expression( SymbolDerivative, *( - [Integer0] * (index) - + [Integer1] - + [Integer0] * (len(f.elements) - index - 1) + [Integer0] * (index) + + [Integer1] + + [Integer0] * (len(f.elements) - index - 1) ), ), f.head, @@ -664,8 +664,8 @@ def eval(self, f, x, x0, evaluation: Evaluation, options: dict): # Determine the "jacobian"s if ( - method in ("Newton", "Automatic") - and options["System`Jacobian"] is SymbolAutomatic + method in ("Newton", "Automatic") and + options["System`Jacobian"] is SymbolAutomatic ): def diff(evaluation): @@ -1321,16 +1321,16 @@ class NIntegrate(Builtin): messages = { "bdmtd": "The Method option should be a built-in method name.", "inumr": ( - "The integrand `1` has evaluated to non-numerical " - + "values for all sampling points in the region " - + "with boundaries `2`" + "The integrand `1` has evaluated to non-numerical " + + "values for all sampling points in the region " + + "with boundaries `2`" ), "nlim": "`1` = `2` is not a valid limit of integration.", "ilim": "Invalid integration variable or limit(s) in `1`.", "mtdfail": ( - "The specified method failed to return a " - + "number. Falling back into the internal " - + "evaluator." + "The specified method failed to return a " + + "number. Falling back into the internal " + + "evaluator." ), "cmpint": ("Integration over a complex domain is not " + "implemented yet"), } @@ -1373,10 +1373,10 @@ class NIntegrate(Builtin): messages.update( { - "bdmtd": "The Method option should be a " - + "built-in method name in {`" - + "`, `".join(list(methods)) - + "`}. Using `Automatic`" + "bdmtd": "The Method option should be a " + + "built-in method name in {`" + + "`, `".join(list(methods)) + + "`}. Using `Automatic`" } ) @@ -1396,7 +1396,7 @@ def eval_with_func_domain( elif isinstance(method, Symbol): method = method.get_name() # strip context - method = method[method.rindex("`") + 1 :] + method = method[method.rindex("`") + 1:] else: evaluation.message("NIntegrate", "bdmtd", method) return @@ -2235,146 +2235,157 @@ def eval(self, eqs, vars, evaluation: Evaluation): vars = [vars] for var in vars: if ( - (isinstance(var, Atom) and not isinstance(var, Symbol)) - or head_name in ("System`Plus", "System`Times", "System`Power") # noqa - or A_CONSTANT & var.get_attributes(evaluation.definitions) + (isinstance(var, Atom) and not isinstance(var, Symbol)) or + head_name in ("System`Plus", "System`Times", "System`Power") or # noqa + A_CONSTANT & var.get_attributes(evaluation.definitions) ): evaluation.message("Solve", "ivar", vars_original) return - if eqs.get_head_name() in ("System`List", "System`And"): - eq_list = eqs.elements - else: - eq_list = [eqs] - sympy_conditions = [] - sympy_eqs = [] - sympy_denoms = [] - for eq in eq_list: - if eq is SymbolTrue: - pass - elif eq is SymbolFalse: - return ListExpression() - elif not eq.has_form("Equal", 2): - sympy_conditions.append(eq.to_sympy()) - else: - left, right = eq.elements - left = left.to_sympy() - right = right.to_sympy() - if left is None or right is None: - return - eq = left - right - eq = sympy.together(eq) - eq = sympy.cancel(eq) - sympy_eqs.append(eq) - numer, denom = eq.as_numer_denom() - sympy_denoms.append(denom) - - if not sympy_eqs: - evaluation.message("Solve", "eqf", eqs) - return vars_sympy = [var.to_sympy() for var in vars] if None in vars_sympy: + evaluation.message("Solve", "ivar") return - - # delete unused variables to avoid SymPy's - # PolynomialError: Not a zero-dimensional system - # in e.g. Solve[x^2==1&&z^2==-1,{x,y,z}] - all_vars = vars[:] - all_vars_sympy = vars_sympy[:] - vars = [] - vars_sympy = [] - for var, var_sympy in zip(all_vars, all_vars_sympy): - pattern = Pattern.create(var) - if not eqs.is_free(pattern, evaluation): - vars.append(var) - vars_sympy.append(var_sympy) - - def transform_dict(sols): - if not sols: - yield sols - for var, sol in sols.items(): - rest = sols.copy() - del rest[var] - rest = transform_dict(rest) - if not isinstance(sol, (tuple, list)): - sol = [sol] - if not sol: - for r in rest: - yield r - else: - for r in rest: - for item in sol: - new_sols = r.copy() - new_sols[var] = item - yield new_sols - break - - def transform_solution(sol): - if not isinstance(sol, dict): - if not isinstance(sol, (list, tuple)): - sol = [sol] - sol = dict(list(zip(vars_sympy, sol))) - return transform_dict(sol) - - if not sympy_eqs: - sympy_eqs = True - elif len(sympy_eqs) == 1: - sympy_eqs = sympy_eqs[0] - - try: - if isinstance(sympy_eqs, bool): - result = sympy_eqs + all_var_tuples = list(zip(vars, vars_sympy)) + + def cut_var_dimension(expressions: Expression | list[Expression]): + '''delete unused variables to avoid SymPy's PolynomialError + : Not a zero-dimensional system in e.g. Solve[x^2==1&&z^2==-1,{x,y,z}]''' + if not isinstance(expressions, list): + expressions = [expressions] + subset_vars = set() + subset_vars_sympy = set() + for var, var_sympy in all_var_tuples: + pattern = Pattern.create(var) + for equation in expressions: + if not equation.is_free(pattern, evaluation): + subset_vars.add(var) + subset_vars_sympy.add(var_sympy) + return subset_vars, subset_vars_sympy + + def solve_sympy(equations: Expression | list[Expression]): + if not isinstance(equations, list): + equations = [equations] + equations_sympy = [] + denoms_sympy = [] + subset_vars, subset_vars_sympy = cut_var_dimension(equations) + for equation in equations: + if equation is SymbolTrue: + continue + elif equation is SymbolFalse: + return [] + elements = equation.elements + for left, right in [(elements[index], elements[index + 1]) for index in range(len(elements) - 1)]: + # ↑ to deal with things like a==b==c==d + left = left.to_sympy() + right = right.to_sympy() + if left is None or right is None: + return [] + equation_sympy = left - right + equation_sympy = sympy.together(equation_sympy) + equation_sympy = sympy.cancel(equation_sympy) + numer, denom = equation_sympy.as_numer_denom() + denoms_sympy.append(denom) + try: + results = sympy.solve(equations_sympy, subset_vars_sympy, dict=True) # no transform needed with dict=True + # Filter out results for which denominator is 0 + # (SymPy should actually do that itself, but it doesn't!) + results = [ + sol + for sol in results + if all(sympy.simplify(denom.subs(sol)) != 0 for denom in denoms_sympy) + ] + return results + except sympy.PolynomialError: + # raised for e.g. Solve[x^2==1&&z^2==-1,{x,y,z}] when not deleting + # unused variables beforehand + return [] + except NotImplementedError: + return [] + except TypeError as exc: + if str(exc).startswith("expected Symbol, Function or Derivative"): + evaluation.message("Solve", "ivar", vars_original) + + def solve_recur(expression: Expression): + '''solve And, Or and List within the scope of sympy, + but including the translation from Mathics to sympy + + returns: + solutions: a list of sympy solution dictionaries + conditions: a sympy condition object + + note: + for And and List, should always return either (solutions, None) or ([], conditions) + for Or, all combinations are possible. if Or is root, should be handled outside''' + head = expression.get_head_name() + if head in ("System`And", "System`List"): + solutions = [] + equations: list[Expression] = [] + inequations = [] + for child in expression.elements: + if child.has_form("Equal", 2): + equations.append(child) + elif child.get_head_name() in ("System`And", "System`Or"): + sub_solution, sub_condition = solve_recur(child) + solutions.extend(sub_solution) + if sub_condition is not None: + inequations.append(sub_condition) + else: + inequations.append(child.to_sympy()) + solutions.extend(solve_sympy(equations)) + conditions = sympy.And(*inequations) + result = [sol for sol in solutions if conditions.subs(sol)] + return result, None if solutions else conditions + else: # should be System`Or then + assert head == "System`Or" + solutions = [] + conditions = [] + for child in expression.elements: + if child.has_form("Equal", 2): + solutions.extend(solve_sympy(child)) + elif child.get_head_name() in ("System`And", "System`Or"): # List wouldn't be in here + sub_solution, sub_condition = solve_recur(child) + solutions.extend(sub_solution) + if sub_condition is not None: + conditions.append(sub_condition) + else: + # SymbolTrue and SymbolFalse are allowed here since it's subtree context + # FIXME: None is not allowed, not sure what to do here + conditions.append(child.to_sympy()) + conditions = sympy.Or(*conditions) + return solutions, conditions + + if eqs.get_head_name() in ("System`List", "System`And", "System`Or"): + solutions, conditions = solve_recur(eqs) + # non True conditions are only accepted in subtrees, not root + if conditions is not None: + evaluation.message("Solve", "fulldim") + return ListExpression(ListExpression()) + else: + if eqs.has_form("Equal", 2): + solutions = solve_sympy(eqs) else: - result = sympy.solve(sympy_eqs, vars_sympy) - if not isinstance(result, list): - result = [result] - if isinstance(result, list) and len(result) == 1 and result[0] is True: + evaluation.message("Solve", "fulldim") return ListExpression(ListExpression()) - if result == [None]: - return ListExpression() - results = [] - for sol in result: - results.extend(transform_solution(sol)) - result = results - # filter with conditions before further translation - conditions = sympy.And(*sympy_conditions) - result = [sol for sol in result if conditions.subs(sol)] - - if any( - sol and any(var not in sol for var in all_vars_sympy) for sol in result - ): - evaluation.message("Solve", "svars") - # Filter out results for which denominator is 0 - # (SymPy should actually do that itself, but it doesn't!) - result = [ - sol - for sol in result - if all(sympy.simplify(denom.subs(sol)) != 0 for denom in sympy_denoms) - ] - - return ListExpression( - *( - ListExpression( - *( - Expression(SymbolRule, var, from_sympy(sol[var_sympy])) - for var, var_sympy in zip(vars, vars_sympy) - if var_sympy in sol - ), - ) - for sol in result - ), - ) - except sympy.PolynomialError: - # raised for e.g. Solve[x^2==1&&z^2==-1,{x,y,z}] when not deleting - # unused variables beforehand - pass - except NotImplementedError: - pass - except TypeError as exc: - if str(exc).startswith("expected Symbol, Function or Derivative"): - evaluation.message("Solve", "ivar", vars_original) + if any( + sol and any(var not in sol for var in vars_sympy) for sol in solutions + ): + evaluation.message("Solve", "svars") + + return ListExpression( + *( + ListExpression( + *( + Expression(SymbolRule, var, from_sympy(sol[var_sympy])) + for var, var_sympy in zip(vars, all_var_tuples) + if var_sympy in sol + ), + ) + for sol in solutions + ), + ) # Auxiliary routines. Maybe should be moved to another module.