From b70c2a96abaf150a26cc693e68d4c08a2f528548 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 15 Nov 2024 06:55:00 -0800 Subject: [PATCH 1/3] Improve handling of trailing optional inputs in pattern matching --- onnxscript/rewriter/pattern.py | 19 +++++++---------- onnxscript/rewriter/pattern_test.py | 32 +++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 12 deletions(-) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 66d9b3196..2ad26eded 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -1040,19 +1040,14 @@ def _match_node(self, pattern_node: NodePattern, node: ir.Node) -> bool: self._matched[pattern_node] = node - # TODO: Revisit this to handle optional trailing inputs better. - if pattern_node.allow_other_inputs: - if len(node.inputs) < len(pattern_node.inputs): - return self.fail( - f"Number of inputs ({len(node.inputs)}) is less than expected ({len(pattern_node.inputs)})" - ) - else: - if len(node.inputs) != len(pattern_node.inputs): - return self.fail( - f"Input nums mismatch. {len(node.inputs)} vs {len(pattern_node.inputs)}" - ) + if len(node.inputs) > len(pattern_node.inputs) and not pattern_node.allow_other_inputs: + return self.fail( + f"Number of inputs ({len(node.inputs)}) is more than expected ({len(pattern_node.inputs)})" + ) - for arg_value, arg_pattern in zip(node.inputs, pattern_node.inputs): + for arg_value, arg_pattern in itertools.zip_longest( + node.inputs, pattern_node.inputs, fillvalue=None + ): # arg_pattern could be a Var, if it's the original arg. if arg_pattern is None: if arg_value is None: diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index 0247949f5..57bbd8dc3 100644 --- a/onnxscript/rewriter/pattern_test.py +++ b/onnxscript/rewriter/pattern_test.py @@ -476,6 +476,38 @@ def test_model(x: FLOAT[1024]) -> FLOAT[1024]: self.assertEqual(model.graph.node(0).op_type, "ReplacedNone") self.assertEqual(model.graph.node(1).op_type, "ReplacedNotNone") + def test_match_trailing_optional_input(self): + def none_pattern(op, optional_input, x): + # match against a call to Original where the first input may or may not be None + return op.Original(x, optional_input) + + def replacement(op, optional_input, x): + if optional_input is None: + return op.ReplacedNone(x) + return op.ReplacedNotNone(x) + + rule = pattern.RewriteRule(none_pattern, replacement) + + @script() + def test_model(x: FLOAT[1024]) -> FLOAT[1024]: + # Pattern should match following call (with optional_input == None) + t1 = op.Original(x, None) + # as well as this one (with optional_input != None) + z = op.Original(x, t1) + # as well as this one (with optional_input == None) + z = op.Original(x) + return z + + model_proto = test_model.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + + count = rule.apply_to_model(model) + self.assertEqual(count, 3) + self.assertEqual(len(model.graph), 3) + self.assertEqual(model.graph.node(0).op_type, "ReplacedNone") + self.assertEqual(model.graph.node(1).op_type, "ReplacedNotNone") + self.assertEqual(model.graph.node(2).op_type, "ReplacedNone") + class PatternBuilderTest(unittest.TestCase): def test_pattern_builder_context(self): From 80e2ed401158e5b8210584fba00312469de565c8 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 15 Nov 2024 07:10:19 -0800 Subject: [PATCH 2/3] Improve variable naming --- onnxscript/rewriter/pattern_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index 57bbd8dc3..37545349b 100644 --- a/onnxscript/rewriter/pattern_test.py +++ b/onnxscript/rewriter/pattern_test.py @@ -493,9 +493,9 @@ def test_model(x: FLOAT[1024]) -> FLOAT[1024]: # Pattern should match following call (with optional_input == None) t1 = op.Original(x, None) # as well as this one (with optional_input != None) - z = op.Original(x, t1) + t2 = op.Original(x, t1) # as well as this one (with optional_input == None) - z = op.Original(x) + z = op.Original(t2) return z model_proto = test_model.to_model_proto() From d5899ed309bea5a495f655975a7d70642a1bd909 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 15 Nov 2024 07:48:43 -0800 Subject: [PATCH 3/3] Fix padding --- onnxscript/rewriter/pattern.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 2ad26eded..2b02bab1b 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -1040,14 +1040,19 @@ def _match_node(self, pattern_node: NodePattern, node: ir.Node) -> bool: self._matched[pattern_node] = node - if len(node.inputs) > len(pattern_node.inputs) and not pattern_node.allow_other_inputs: - return self.fail( - f"Number of inputs ({len(node.inputs)}) is more than expected ({len(pattern_node.inputs)})" - ) + if len(node.inputs) > len(pattern_node.inputs): + if pattern_node.allow_other_inputs: + # Ignore extraneous inputs + to_match = zip(node.inputs, pattern_node.inputs) + else: + return self.fail( + f"Number of inputs ({len(node.inputs)}) is more than expected ({len(pattern_node.inputs)})" + ) + else: + # Inputs are padded with Nones to match against pattern + to_match = itertools.zip_longest(node.inputs, pattern_node.inputs, fillvalue=None) - for arg_value, arg_pattern in itertools.zip_longest( - node.inputs, pattern_node.inputs, fillvalue=None - ): + for arg_value, arg_pattern in to_match: # arg_pattern could be a Var, if it's the original arg. if arg_pattern is None: if arg_value is None: