diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 66d9b3196..2b02bab1b 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -1040,19 +1040,19 @@ def _match_node(self, pattern_node: NodePattern, node: ir.Node) -> bool: self._matched[pattern_node] = node - # TODO: Revisit this to handle optional trailing inputs better. - if pattern_node.allow_other_inputs: - if len(node.inputs) < len(pattern_node.inputs): + if len(node.inputs) > len(pattern_node.inputs): + if pattern_node.allow_other_inputs: + # Ignore extraneous inputs + to_match = zip(node.inputs, pattern_node.inputs) + else: return self.fail( - f"Number of inputs ({len(node.inputs)}) is less than expected ({len(pattern_node.inputs)})" + f"Number of inputs ({len(node.inputs)}) is more than expected ({len(pattern_node.inputs)})" ) else: - if len(node.inputs) != len(pattern_node.inputs): - return self.fail( - f"Input nums mismatch. {len(node.inputs)} vs {len(pattern_node.inputs)}" - ) + # Inputs are padded with Nones to match against pattern + to_match = itertools.zip_longest(node.inputs, pattern_node.inputs, fillvalue=None) - for arg_value, arg_pattern in zip(node.inputs, pattern_node.inputs): + for arg_value, arg_pattern in to_match: # arg_pattern could be a Var, if it's the original arg. if arg_pattern is None: if arg_value is None: diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index 0247949f5..37545349b 100644 --- a/onnxscript/rewriter/pattern_test.py +++ b/onnxscript/rewriter/pattern_test.py @@ -476,6 +476,38 @@ def test_model(x: FLOAT[1024]) -> FLOAT[1024]: self.assertEqual(model.graph.node(0).op_type, "ReplacedNone") self.assertEqual(model.graph.node(1).op_type, "ReplacedNotNone") + def test_match_trailing_optional_input(self): + def none_pattern(op, optional_input, x): + # match against a call to Original where the first input may or may not be None + return op.Original(x, optional_input) + + def replacement(op, optional_input, x): + if optional_input is None: + return op.ReplacedNone(x) + return op.ReplacedNotNone(x) + + rule = pattern.RewriteRule(none_pattern, replacement) + + @script() + def test_model(x: FLOAT[1024]) -> FLOAT[1024]: + # Pattern should match following call (with optional_input == None) + t1 = op.Original(x, None) + # as well as this one (with optional_input != None) + t2 = op.Original(x, t1) + # as well as this one (with optional_input == None) + z = op.Original(t2) + return z + + model_proto = test_model.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + + count = rule.apply_to_model(model) + self.assertEqual(count, 3) + self.assertEqual(len(model.graph), 3) + self.assertEqual(model.graph.node(0).op_type, "ReplacedNone") + self.assertEqual(model.graph.node(1).op_type, "ReplacedNotNone") + self.assertEqual(model.graph.node(2).op_type, "ReplacedNone") + class PatternBuilderTest(unittest.TestCase): def test_pattern_builder_context(self):