Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1 / Fix issues with pattern matching for Rubi #1176

Merged
merged 10 commits into from
Nov 23, 2024
17 changes: 14 additions & 3 deletions mathics/core/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,12 @@ def match_expression_with_one_identity(
isinstance(pat_elem, PatternObject)
and pat_elem.get_head() == SymbolOptional
):
if len(pat_elem.elements) == 2:
if optionals:
# A default pattern already exists
# Do not use the second one
if new_pattern is None:
new_pattern = pat_elem
elif len(pat_elem.elements) == 2:
pat, value = pat_elem.elements
if isinstance(pat, Pattern):
key = pat.elements[0].atom.name # type: ignore[attr-defined]
Expand All @@ -724,8 +729,12 @@ def match_expression_with_one_identity(
result = defaultvalue_expr.evaluate(evaluation)
assert result is not None
if result.sameQ(defaultvalue_expr):
return
optionals[key] = result
if new_pattern is None:
# The optional pattern has no default value
# for the given position
new_pattern = pat_elem
else:
optionals[key] = result
else:
return
elif new_pattern is not None:
Expand Down Expand Up @@ -757,6 +766,8 @@ def match_expression_with_one_identity(
del parms["attributes"]
assert new_pattern is not None
new_pattern.match(expression=expression, pattern_context=parms)
for optional in optionals:
vars_dict.pop(optional)


def basic_match_expression(
Expand Down
52 changes: 52 additions & 0 deletions test/core/test_patterns.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
"""
Unit tests for mathics pattern matching
"""

import sys
import time
from test.helper import check_evaluation, evaluate

import pytest


@pytest.mark.parametrize(
("str_expr", "msgs", "str_expected", "fail_msg"),
[
# Two default arguments (linear)
("MatchQ[1, a_.+b_.*x_]", None, "True", None),
("MatchQ[x, a_.+b_.*x_]", None, "True", None),
("MatchQ[2*x, a_.+b_.*x_]", None, "True", None),
("MatchQ[1+x, a_.+b_.*x_]", None, "True", None),
("MatchQ[1+2*x, a_.+b_.*x_]", None, "True", None),
# Default argument (power)
("MatchQ[1, x_^m_.]", None, "True", None),
("MatchQ[x, x_^m_.]", None, "True", None),
("MatchQ[x^1, x_^m_.]", None, "True", None),
("MatchQ[x^2, x_^m_.]", None, "True", None),
# Two default arguments (power)
("MatchQ[1, x_.^m_.]", None, "True", None),
("MatchQ[x, x_.^m_.]", None, "True", None),
("MatchQ[x^1, x_.^m_.]", None, "True", None),
("MatchQ[x^2, x_.^m_.]", None, "True", None),
# Two default arguments (no non-head)
("MatchQ[1, a_.+b_.]", None, "True", None),
("MatchQ[x, a_.+b_.]", None, "True", None),
("MatchQ[1+x, a_.+b_.]", None, "True", None),
("MatchQ[1+2*x, a_.+b_.]", None, "True", None),
("MatchQ[1, a_.+b_.]", None, "True", None),
("MatchQ[x, a_.*b_.]", None, "True", None),
("MatchQ[2*x, a_.*b_.]", None, "True", None),
],
)
def test_patterns(str_expr, msgs, str_expected, fail_msg):
"""patterns"""
check_evaluation(
str_expr,
str_expected,
to_string_expr=True,
to_string_expected=True,
hold_expected=True,
failure_message=fail_msg,
expected_messages=msgs,
)