diff --git a/mathics/builtin/exp_structure/general.py b/mathics/builtin/exp_structure/general.py index 70404a036..49b9fcc85 100644 --- a/mathics/builtin/exp_structure/general.py +++ b/mathics/builtin/exp_structure/general.py @@ -114,7 +114,7 @@ class FreeQ(Builtin): def eval(self, expr, form, evaluation: Evaluation): "FreeQ[expr_, form_]" - form = BasePattern.create(form) + form = BasePattern.create(form, evaluation=evaluation) if expr.is_free(form, evaluation): return SymbolTrue else: diff --git a/mathics/builtin/list/constructing.py b/mathics/builtin/list/constructing.py index ff135e53a..014f467d5 100644 --- a/mathics/builtin/list/constructing.py +++ b/mathics/builtin/list/constructing.py @@ -431,7 +431,10 @@ def eval(self, expr, patterns, f, evaluation: Evaluation): "Reap[expr_, {patterns___}, f_]" patterns = patterns.get_sequence() - sown = [(BasePattern.create(pattern), []) for pattern in patterns] + sown = [ + (BasePattern.create(pattern, evaluation=evaluation), []) + for pattern in patterns + ] def listener(e, tag): result = False diff --git a/mathics/builtin/list/eol.py b/mathics/builtin/list/eol.py index 01ef261eb..9d12e1b62 100644 --- a/mathics/builtin/list/eol.py +++ b/mathics/builtin/list/eol.py @@ -213,7 +213,7 @@ def eval(self, items, pattern, ls, evaluation, options): results = [] if pattern.has_form("Rule", 2) or pattern.has_form("RuleDelayed", 2): - match = Matcher(pattern.elements[0]).match + match = Matcher(pattern.elements[0], evaluation).match rule = Rule(pattern.elements[0], pattern.elements[1]) def callback(level): @@ -224,7 +224,7 @@ def callback(level): return level else: - match = Matcher(pattern).match + match = Matcher(pattern, evaluation).match def callback(level): if match(level, evaluation): @@ -467,7 +467,7 @@ def eval_ls_n(self, items, pattern, levelspec, n, evaluation): return deletecases_with_levelspec(items, pattern, evaluation, levelspec, n) # A more efficient way to proceed if levelspec == 1 - match = Matcher(pattern).match + match = Matcher(pattern, evaluation).match if n == -1: def cond(element): @@ -1187,7 +1187,7 @@ def eval(self, items, sel, evaluation): def eval_pattern(self, items, sel, pattern, evaluation): "Pick[items_, sel_, pattern_]" - match = Matcher(pattern).match + match = Matcher(pattern, evaluation).match return self._do(items, sel, lambda s: match(s, evaluation), evaluation) @@ -1245,7 +1245,7 @@ def eval_level(self, expr, patt, ls, evaluation, options={}): evaluation.message("Position", "level", ls) return - match = Matcher(patt).match + match = Matcher(patt, evaluation).match result = [] def callback(level, pos): diff --git a/mathics/builtin/numbers/calculus.py b/mathics/builtin/numbers/calculus.py index 389997715..16a0ea3e4 100644 --- a/mathics/builtin/numbers/calculus.py +++ b/mathics/builtin/numbers/calculus.py @@ -228,7 +228,7 @@ def eval(self, f, x, evaluation: Evaluation): if f == x: return Integer1 - x_pattern = BasePattern.create(x) + x_pattern = BasePattern.create(x, evaluation=evaluation) if f.is_free(x_pattern, evaluation): return Integer0 @@ -1919,7 +1919,7 @@ def eval_times( nummax.get_int_value(), den.get_int_value(), ) - x_pattern = BasePattern.create(x) + x_pattern = BasePattern.create(x, evaluation=evaluation) incompat_series = [] max_exponent = Integer(int(series[2] / series[3] + 1)) if coeff.get_head() is SymbolSequence: @@ -2265,7 +2265,7 @@ def eval(self, eqs, vars, evaluation: Evaluation): vars = [] vars_sympy = [] for var, var_sympy in zip(all_vars, all_vars_sympy): - pattern = BasePattern.create(var) + pattern = BasePattern.create(var, evaluation=evaluation) if not eqs.is_free(pattern, evaluation): vars.append(var) vars_sympy.append(var_sympy) diff --git a/mathics/builtin/options.py b/mathics/builtin/options.py index c1a4e874d..3ba3922e5 100644 --- a/mathics/builtin/options.py +++ b/mathics/builtin/options.py @@ -159,7 +159,7 @@ class FilterRules(Builtin): def eval(self, rules, pattern, evaluation): "FilterRules[rules_List, pattern_]" - match = Matcher(pattern).match + match = Matcher(pattern, evaluation).match def matched(): for rule in rules.elements: diff --git a/mathics/builtin/patterns.py b/mathics/builtin/patterns.py index d57d140e5..addd2a600 100644 --- a/mathics/builtin/patterns.py +++ b/mathics/builtin/patterns.py @@ -827,7 +827,7 @@ def init( self, expr: Expression, evaluation: OptionalType[Evaluation] = None ) -> None: super(Except, self).init(expr, evaluation=evaluation) - self.c = BasePattern.create(expr.elements[0]) + self.c = BasePattern.create(expr.elements[0], evaluation=evaluation) if len(expr.elements) == 2: self.p = BasePattern.create(expr.elements[1], evaluation=evaluation) else: diff --git a/mathics/core/builtin.py b/mathics/core/builtin.py index 61bc1887d..abd5df3e3 100644 --- a/mathics/core/builtin.py +++ b/mathics/core/builtin.py @@ -222,6 +222,7 @@ def contribute(self, definitions, is_pymodule=False): if not self.context: self.context = "Pymathics`" if is_pymodule else "System`" name = self.get_name() + attributes = self.attributes options = {} # - 'Strict': warn and fail with unsupported options @@ -268,19 +269,41 @@ def contribute(self, definitions, is_pymodule=False): for pattern, function in self.get_functions( prefix="eval", is_pymodule=is_pymodule ): + pat_attr = attributes if pattern.get_head_name() == name else None rules.append( - FunctionApplyRule(name, pattern, function, check_options, system=True) + FunctionApplyRule( + name, + pattern, + function, + check_options, + attributes=pat_attr, + system=True, + ) ) for pattern, function in self.get_functions(is_pymodule=is_pymodule): + pat_attr = attributes if pattern.get_head_name() == name else None rules.append( - FunctionApplyRule(name, pattern, function, check_options, system=True) + FunctionApplyRule( + name, + pattern, + function, + check_options, + attributes=pat_attr, + system=True, + ) ) for pattern_str, replace_str in self.rules.items(): pattern_str = pattern_str % {"name": name} pattern = parse_builtin_rule(pattern_str, definition_class) replace_str = replace_str % {"name": name} + pat_attr = attributes if pattern.get_head_name() == name else None rules.append( - Rule(pattern, parse_builtin_rule(replace_str), system=not is_pymodule) + Rule( + pattern, + parse_builtin_rule(replace_str), + attributes=pat_attr, + system=not is_pymodule, + ) ) box_rules = [] @@ -321,11 +344,14 @@ def contextify_form_name(f): formatvalues = {"": []} for pattern, function in self.get_functions("format_"): forms, pattern = extract_forms(pattern) + pat_attr = attributes if pattern.get_head_name() == name else None for form in forms: if form not in formatvalues: formatvalues[form] = [] formatvalues[form].append( - FunctionApplyRule(name, pattern, function, None, system=True) + FunctionApplyRule( + name, pattern, function, None, attributes=pat_attr, system=True + ) ) for pattern, replace in self.formats.items(): forms, pattern = extract_forms(pattern) @@ -377,7 +403,7 @@ def contextify_form_name(f): rules=rules, formatvalues=formatvalues, messages=messages, - attributes=self.attributes, + attributes=attributes, options=options, defaultvalues=defaults, builtin=self, diff --git a/mathics/core/pattern.py b/mathics/core/pattern.py index 8e3433efa..c96166b28 100644 --- a/mathics/core/pattern.py +++ b/mathics/core/pattern.py @@ -92,6 +92,11 @@ class BasePattern(ABC): expr: BaseElement + # this attribute allows for a faster match algorithm based on sameq. + # Probably we should split ExpressionPattern into two different classes, + # one for literal patterns and the other for "Regular" ExpressionPatterns. + isliteral: bool = False + # TODO: In WMA, when a BasePattern is created, the attributes # from the head are read from the evaluation context and # stored as a part of a rule. @@ -168,7 +173,9 @@ class BasePattern(ABC): # @staticmethod def create( - expr: BaseElement, evaluation: Optional[Evaluation] = None + expr: BaseElement, + attributes: Optional[int] = None, + evaluation: Optional[Evaluation] = None, ) -> "BasePattern": """ If ``expr`` is listed in ``pattern_object`` return the pattern found there. @@ -181,7 +188,7 @@ def create( return pattern_object(expr, evaluation=evaluation) if isinstance(expr, Atom): return AtomPattern(expr, evaluation) - return ExpressionPattern(expr, evaluation) + return ExpressionPattern(expr, attributes, evaluation) def get_attributes(self, definitions): """The attributes of the expression""" @@ -320,6 +327,9 @@ class AtomPattern(BasePattern): A pattern that matches with an atom. """ + # Atoms are always literals + isliteral: bool = True + def __init__(self, expr: Atom, evaluation: Optional[Evaluation] = None) -> None: self.expr = expr self.atom = expr @@ -405,15 +415,22 @@ class ExpressionPattern(BasePattern): attributes: Optional[int] = None - def __init__(self, expr: Expression, evaluation: Optional[Evaluation] = None): + def __init__( + self, + expr: Expression, + attributes: Optional[int] = None, + evaluation: Optional[Evaluation] = None, + ): self.expr = expr head = expr.head - attributes = ( - None if evaluation is None else head.get_attributes(evaluation.definition) - ) + if attributes is None and evaluation: + attributes = head.get_attributes(evaluation.definitions) + self.head = BasePattern.create(head, evaluation=evaluation) + self.elements = [ + BasePattern.create(element, evaluation=evaluation) + for element in expr.elements + ] self.__set_pattern_attributes__(attributes) - self.head = BasePattern.create(head) - self.elements = [BasePattern.create(element) for element in expr.elements] def __set_pattern_attributes__(self, attributes): if attributes is None or self.attributes is not None: @@ -425,6 +442,10 @@ def __set_pattern_attributes__(self, attributes): self.get_pre_choices = get_pre_choices_orderless else: self.get_pre_choices = get_pre_choices_with_order + if not (A_ONE_IDENTITY + A_FLAT) & attributes: + self.isliteral = self.head.isliteral and all( + element.isliteral for element in self.elements + ) def match( self, @@ -439,6 +460,12 @@ def match( ): """Try to match the pattern against an Expression""" evaluation.check_stopped() + if self.isliteral: + if expression.sameQ(self.expr): + # yield vars, None + yield_func(vars_dict, None) + return + if self.attributes is None: self.__set_pattern_attributes__( self.head.get_attributes(evaluation.definitions) diff --git a/mathics/core/rules.py b/mathics/core/rules.py index f22cf1c04..ddf946eaa 100644 --- a/mathics/core/rules.py +++ b/mathics/core/rules.py @@ -97,8 +97,11 @@ def __init__( pattern: Expression, system: bool = False, evaluation: Optional[Evaluation] = None, + attributes: Optional[int] = None, ) -> None: - self.pattern = BasePattern.create(pattern, evaluation=evaluation) + self.pattern = BasePattern.create( + pattern, attributes=attributes, evaluation=evaluation + ) self.system = system def apply( @@ -222,8 +225,11 @@ def __init__( replace: Expression, system=False, evaluation: Optional[Evaluation] = None, + attributes: Optional[int] = None, ) -> None: - super(Rule, self).__init__(pattern, system=system, evaluation=evaluation) + super(Rule, self).__init__( + pattern, system=system, evaluation=evaluation, attributes=attributes + ) self.replace = replace def apply_rule( @@ -310,9 +316,10 @@ def __init__( check_options: Optional[Callable], system: bool = False, evaluation: Optional[Evaluation] = None, + attributes: Optional[int] = None, ) -> None: super(FunctionApplyRule, self).__init__( - pattern, system=system, evaluation=evaluation + pattern, system=system, attributes=attributes, evaluation=evaluation ) self.name = name self.function = function diff --git a/mathics/eval/numbers/calculus/series.py b/mathics/eval/numbers/calculus/series.py index c61b4d0a9..ff6d803b3 100644 --- a/mathics/eval/numbers/calculus/series.py +++ b/mathics/eval/numbers/calculus/series.py @@ -372,7 +372,7 @@ def build_series(f, x, x0, n, evaluation): vars = { x_name: x0, } - x_pattern = BasePattern.create(x) + x_pattern = BasePattern.create(x, evaluation=evaluation) if f.is_free(x_pattern, evaluation): print(x, " not in ", f) diff --git a/mathics/eval/parts.py b/mathics/eval/parts.py index 910f96ed5..a5009a65c 100644 --- a/mathics/eval/parts.py +++ b/mathics/eval/parts.py @@ -560,7 +560,7 @@ def deletecases_with_levelspec(expr, pattern, evaluation, levelspec=1, n=-1): """ nothing = SymbolNothing - match = Matcher(pattern) + match = Matcher(pattern, evaluation) match = match.match if type(levelspec) is int: lsmin = 1 @@ -631,9 +631,8 @@ def find_matching_indices_with_levelspec(expr, pattern, evaluation, levelspec=1, n indicates the number of occurrences to return. By default, it returns all the occurrences. """ - from mathics.builtin.patterns import Matcher - match = Matcher(pattern) + match = Matcher(pattern, evaluation) match = match.match if type(levelspec) is int: lsmin = 0 diff --git a/mathics/eval/patterns.py b/mathics/eval/patterns.py index 649ced81f..0b9e7ee44 100644 --- a/mathics/eval/patterns.py +++ b/mathics/eval/patterns.py @@ -7,11 +7,11 @@ class _StopGeneratorMatchQ(StopGenerator): class Matcher: - def __init__(self, form): + def __init__(self, form, evaluation): if isinstance(form, BasePattern): self.form = form else: - self.form = BasePattern.create(form) + self.form = BasePattern.create(form, evaluation=evaluation) def match(self, expr, evaluation: Evaluation): def yield_func(vars, rest): @@ -25,4 +25,4 @@ def yield_func(vars, rest): def match(expr, form, evaluation: Evaluation): - return Matcher(form).match(expr, evaluation) + return Matcher(form, evaluation).match(expr, evaluation) diff --git a/mathics/eval/testing_expressions.py b/mathics/eval/testing_expressions.py index 0285de12d..5edfe8b8f 100644 --- a/mathics/eval/testing_expressions.py +++ b/mathics/eval/testing_expressions.py @@ -112,7 +112,7 @@ def is_number(sympy_value) -> bool: def check_ArrayQ(expr, pattern, test, evaluation: Evaluation): "Check if expr is an Array which test yields true for each of its elements." - pattern = BasePattern.create(pattern) + pattern = BasePattern.create(pattern, evaluation=evaluation) dims = [len(expr.get_elements())] # to ensure an atom is not an array @@ -152,7 +152,7 @@ def check_SparseArrayQ(expr, pattern, test, evaluation: Evaluation): if not expr.head.sameQ(SymbolSparseArray): return SymbolFalse - pattern = BasePattern.create(pattern) + pattern = BasePattern.create(pattern, evaluation=evaluation) dims, default_value, rules = expr.elements[1:] if not pattern.does_match(Integer(len(dims.elements)), evaluation): return SymbolFalse