diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index d8eb88d4f..333cb489d 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -417,7 +417,7 @@ def clone(self, node_map: dict[NodePattern, NodePattern]) -> ValuePattern: def name(self) -> str | None: return self._name - def producer(self) -> None | NodePattern: + def producer(self) -> NodePattern | None: return None def uses(self) -> Sequence[tuple[NodePattern, int]]: @@ -970,6 +970,7 @@ def __str__(self) -> str: class SimplePatternMatcher(PatternMatcher): def __init__(self, pattern: GraphPattern) -> None: super().__init__(pattern) + self._current_node: ir.Node | None = None def fail(self, reason: str, node: ir.Node | None = None) -> bool: if self._verbose: @@ -1128,7 +1129,7 @@ def _init_match(self, verbose: int) -> None: self._verbose = verbose self._matched: dict[NodePattern, ir.Node] = {} self._match: MatchResult = MatchResult() - self._current_node: ir.Node | None = None + self._current_node = None def _get_output_values(self) -> list[ir.Value] | None: """Get values bound to the output variables of the pattern."""