From b522084d40f7af46d7ee06888572a3d9b6fe532d Mon Sep 17 00:00:00 2001
From: Max Ren <40742183+mcr229@users.noreply.github.com>
Date: Fri, 24 Jan 2025 13:58:35 -0800
Subject: [PATCH] Support Empty Input Tensors and > 5 Cat Inputs

Differential Revision: D68523312

Pull Request resolved: https://github.com/pytorch/executorch/pull/7855
---
 backends/xnnpack/_passes/TARGETS              |  15 +--
 backends/xnnpack/_passes/__init__.py          |   2 +
 backends/xnnpack/_passes/decompose_cat.py     |  99 ++++++++++++++++
 .../partition/config/generic_node_configs.py  |   4 +-
 backends/xnnpack/test/ops/test_cat.py         |  79 +++++++------
 .../test/passes/test_decompose_cat_pass.py    | 109 ++++++++++++++++++
 6 files changed, 257 insertions(+), 51 deletions(-)
 create mode 100644 backends/xnnpack/_passes/decompose_cat.py
 create mode 100644 backends/xnnpack/test/passes/test_decompose_cat_pass.py

diff --git a/backends/xnnpack/_passes/TARGETS b/backends/xnnpack/_passes/TARGETS
index 6bc3742abe..a199e1aab0 100644
--- a/backends/xnnpack/_passes/TARGETS
+++ b/backends/xnnpack/_passes/TARGETS
@@ -4,20 +4,7 @@ oncall("executorch")
 
 python_library(
     name = "xnnpack_passes",
-    srcs = [
-        "__init__.py",
-        "channels_last_tagged_reshape_pass.py",
-        "conv1d_unsqueeze_pass.py",
-        "convert_to_linear.py",
-        "convert_to_sdpa.py",
-        "convert_to_upsample_bilinear2d.py",
-        "fuse_activation_pass.py",
-        "fuse_batch_norm_with_conv.py",
-        "prelu_reshape_pass.py",
-        "remove_getitem_op.py",
-        "tag_implicit_q_dq_pass.py",
-        "xnnpack_pass.py",
-    ],
+    srcs = native.glob(["*.py"]),
     deps = [
         "//caffe2:torch",
         "//executorch/backends/transforms:addmm_mm_to_linear",
diff --git a/backends/xnnpack/_passes/__init__.py b/backends/xnnpack/_passes/__init__.py
index 00e1ba0358..36a7833dca 100644
--- a/backends/xnnpack/_passes/__init__.py
+++ b/backends/xnnpack/_passes/__init__.py
@@ -17,6 +17,7 @@ from executorch.backends.xnnpack._passes.convert_to_upsample_bilinear2d import (
     ConvertToUpsampleBilinear2d,
 )
+from executorch.backends.xnnpack._passes.decompose_cat import DecomposeConcatenate
 from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass
 from executorch.backends.xnnpack._passes.fuse_batch_norm_with_conv import (
     FuseBatchNormWithConvPass,
 )
@@ -63,6 +64,7 @@ def __init__(
             ConstPropPass,
             FuseBatchNormWithConvPass,
             FuseActivationPass,
+            DecomposeConcatenate,
             RemoveGetItemPass,
             Conv1dUnsqueezePass,
             PReLUReshapePass,
diff --git a/backends/xnnpack/_passes/decompose_cat.py b/backends/xnnpack/_passes/decompose_cat.py
new file mode 100644
index 0000000000..b9057c43e1
--- /dev/null
+++ b/backends/xnnpack/_passes/decompose_cat.py
@@ -0,0 +1,99 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+
+import torch
+from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant
+from executorch.exir.dialects._ops import ops as exir_ops
+
+from executorch.exir.pass_base import ExportPass, PassResult
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.WARNING)
+
+
+class DecomposeConcatenate(ExportPass):
+    """
+    XNNPACK's Concatenate operation only supports concatenating <= 5 tensors
+    at a time. To support concatenation of more than 5 tensors, we decompose
+    oversized cat nodes into chains of cat nodes, each with <= 5 inputs.
+
+    Example:
+        Before Pass:
+            cat: "f32" = torch.ops.aten.cat.default([t1, t2, t3, t4, t5, t6], 1);
+
+        After Pass:
+            cat: "f32" = torch.ops.aten.cat.default([t1, t2, t3, t4, t5], 1);
+            cat_1: "f32" = torch.ops.aten.cat.default([cat, t6], 1);
+    """
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        gm = graph_module
+        for node in gm.graph.nodes:
+            if (
+                node.op == "call_function"
+                and node.target.__name__ == "aten.cat.default"
+            ):
+                concat_args = node.args
+                nodes_to_concat = node.args[0]
+                if len(nodes_to_concat) <= 5:
+                    continue
+
+                is_quantized = all(
+                    is_dequant(n) for n in nodes_to_concat
+                ) and all(is_quant(n) for n in node.users.keys())
+
+                # Keep only the first 5 inputs on the original cat node; the rest move to a new cat below.
+                new_concat_args = (nodes_to_concat[:5],) + concat_args[1:]
+                node.args = new_concat_args
+
+                remainder_nodes_to_concat = nodes_to_concat[5:]
+                with gm.graph.inserting_after(node):
+                    logger.debug(f"Decomposing cat node {node}")
+                    remainder_concat_node = gm.graph.create_node(
+                        "call_function",
+                        target=exir_ops.edge.aten.cat.default,
+                        args=([],),  # placeholder; the real inputs are filled in below
+                        kwargs=node.kwargs,
+                    )
+                    node.replace_all_uses_with(remainder_concat_node)
+                if is_quantized:
+                    # If quantized, we must preserve the q/dq pattern around the
+                    # newly inserted concat node.
+                    q_params = nodes_to_concat[0].args[1:]
+                    q_kwargs = nodes_to_concat[0].kwargs
+                    # The quantizer enforces that all inputs and the output of a
+                    # concat node share the same qparams, so the newly inserted
+                    # q/dq pair must reuse the qparams of the first quantized input.
+                    with gm.graph.inserting_after(node):
+                        logger.debug(
+                            f"Inserting Q/DQ pair for new cat node {remainder_concat_node}"
+                        )
+                        q_node = gm.graph.create_node(
+                            "call_function",
+                            target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+                            args=(node,) + q_params,
+                            kwargs=q_kwargs,
+                        )
+                        with gm.graph.inserting_after(q_node):
+                            dq_node = gm.graph.create_node(
+                                "call_function",
+                                target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+                                args=(q_node,) + q_params,
+                                kwargs=q_kwargs,
+                            )
+                    remainder_concat_node.args = (
+                        [dq_node] + remainder_nodes_to_concat,
+                    ) + node.args[1:]
+                else:
+                    remainder_concat_node.args = (
+                        [node] + remainder_nodes_to_concat,
+                    ) + node.args[1:]
+
+        gm.recompile()
+        new_gm = super().call(gm).graph_module
+        return PassResult(new_gm, True)
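For quantized graphs, the pass above also preserves the q/dq pattern: the output of each intermediate cat is re-quantized before it feeds the next cat, reusing the qparams of the first dequantized input. Schematically, for a six-input quantized cat (node names here are illustrative, not the exact names the pass emits):

    Before Pass:
        cat    = aten.cat.default([dq_1, dq_2, dq_3, dq_4, dq_5, dq_6], 1)
        q_out  = quantize_per_tensor(cat, ...)

    After Pass:
        cat    = aten.cat.default([dq_1, dq_2, dq_3, dq_4, dq_5], 1)
        q_new  = quantize_per_tensor(cat, *qparams_of_dq_1)
        dq_new = dequantize_per_tensor(q_new, *qparams_of_dq_1)
        cat_1  = aten.cat.default([dq_new, dq_6], 1)
        q_out  = quantize_per_tensor(cat_1, ...)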
diff --git a/backends/xnnpack/partition/config/generic_node_configs.py b/backends/xnnpack/partition/config/generic_node_configs.py
index 9bee4925b9..dbcb5c9203 100644
--- a/backends/xnnpack/partition/config/generic_node_configs.py
+++ b/backends/xnnpack/partition/config/generic_node_configs.py
@@ -181,10 +181,10 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
 
         num_tensors = len(node.all_input_nodes)
 
-        if not (num_tensors >= 2 and num_tensors <= 5):
+        if not (num_tensors >= 2):
             why(
                 node,
-                reason=f"only support concatenation of 2 - 5 tensors, got {num_tensors} tensors",
+                reason=f"only support concatenation of at least 2 tensors, got {num_tensors} tensors",
             )
             return False
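The relaxed constraint above is safe because the decomposition is numerically exact: concatenation is associative, so chaining cats reproduces the single large cat bit-for-bit. A minimal standalone sketch of the chaining scheme in plain PyTorch (chained_cat is a hypothetical helper for illustration, not part of this PR):

import torch

def chained_cat(tensors, dim=0, max_inputs=5):
    # Mirrors DecomposeConcatenate: cat the first 5 tensors, then fold in
    # the remainder 4 at a time alongside the previous cat's output.
    out = torch.cat(list(tensors[:max_inputs]), dim=dim)
    rest = list(tensors[max_inputs:])
    while rest:
        out = torch.cat([out] + rest[: max_inputs - 1], dim=dim)
        rest = rest[max_inputs - 1 :]
    return out

xs = [torch.randn(1, 2, 3) for _ in range(8)]
assert torch.equal(chained_cat(xs, dim=0), torch.cat(xs, dim=0))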
diff --git a/backends/xnnpack/test/ops/test_cat.py b/backends/xnnpack/test/ops/test_cat.py
index 9a7adaeb0f..377bf62aa7 100644
--- a/backends/xnnpack/test/ops/test_cat.py
+++ b/backends/xnnpack/test/ops/test_cat.py
@@ -14,9 +14,13 @@
 class TestCat(unittest.TestCase):
     class Cat(torch.nn.Module):
+        def __init__(self, dim=0):
+            super().__init__()
+            self.dim = dim
+
         def forward(self, *args):
             xs = [*args]
-            x = torch.cat(xs)
+            x = torch.cat(xs, dim=self.dim)
             return x + x  # Quantize by propagation.
 
     def _test_cat(self, module, inputs, cat_num=1, quant=False, quant_ops=2):
@@ -27,7 +31,6 @@ def _test_cat(self, module, inputs, cat_num=1, quant=False, quant_ops=2):
             tester.quantize()
 
         tester.export().check_count({"torch.ops.aten.cat": 1})
-        tester.dump_artifact()
 
         if quant:
             # Expect multiple quantize ops - one per input, cat, and add.
@@ -93,6 +96,29 @@ def test_fp16_cat4(self):
         )
         self._test_cat(self.Cat(), inputs)
 
+    def test_fp16_cat5(self):
+        """
+        fp16 add is currently done in fp32, so this mainly exercises the fp16 cat path.
+        """
+        inputs = (
+            torch.randn(1, 2, 3).to(torch.float16),
+            torch.randn(3, 2, 3).to(torch.float16),
+            torch.randn(2, 2, 3).to(torch.float16),
+            torch.randn(5, 2, 3).to(torch.float16),
+            torch.randn(5, 2, 3).to(torch.float16),
+        )
+        self._test_cat(self.Cat(), inputs)
+
+    def test_fp16_cat_gt_5(self):
+        """
+        fp16 add is currently done in fp32, so this mainly exercises the fp16 cat path.
+        """
+        for num_inputs in range(6, 10):
+            inputs = []
+            for _ in range(num_inputs):
+                inputs.append(torch.randn(1, 2, 3).to(torch.float16))
+            self._test_cat(self.Cat(), tuple(inputs))
+
     def test_fp32_cat2(self):
         inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3))
         self._test_cat(self.Cat(), inputs)
@@ -120,6 +146,13 @@ def test_fp32_cat5(self):
         )
         self._test_cat(self.Cat(), inputs)
 
+    def test_fp32_cat_gt_5(self):
+        for num_inputs in range(6, 10):
+            inputs = []
+            for _ in range(num_inputs):
+                inputs.append(torch.randn(1, 2, 3))
+            self._test_cat(self.Cat(), tuple(inputs))
+
     def test_qs8_cat2(self):
         inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3))
         self._test_cat(self.Cat(), inputs, cat_num=2, quant=True)
@@ -137,46 +170,22 @@ def test_qs8_cat4(self):
         )
         self._test_cat(self.Cat(), inputs, cat_num=4, quant=True)
 
-    def test_fp32_cat_unsupported(self):
-        """
-        XNNPACK only supports concatenating up to 4 values, so it should not delegate here.
-        """
+    def test_qs8_cat5(self):
         inputs = (
             torch.randn(1, 2, 3),
             torch.randn(3, 2, 3),
             torch.randn(2, 2, 3),
             torch.randn(5, 2, 3),
-            torch.randn(1, 2, 3),
-            torch.randn(2, 2, 3),
-        )
-        (
-            Tester(self.Cat(), inputs)
-            .export()
-            .check_count({"torch.ops.aten.cat": 1})
-            .to_edge_transform_and_lower()
-            .check_count({"executorch_exir_dialects_edge__ops_aten_cat": 1})
-        )
-
-    def test_fp32_cat_unsupported_legacy_mode(self):
-        """
-        XNNPACK only supports concatenating up to 5 values, so it should not delegate here.
-        """
-        inputs = (
-            torch.randn(1, 2, 3),
-            torch.randn(3, 2, 3),
-            torch.randn(2, 2, 3),
             torch.randn(5, 2, 3),
-            torch.randn(1, 2, 3),
-            torch.randn(6, 2, 3),
-        )
-        (
-            Tester(self.Cat(), inputs)
-            .export()
-            .check_count({"torch.ops.aten.cat": 1})
-            .to_edge()
-            .partition()
-            .check_count({"executorch_exir_dialects_edge__ops_aten_cat": 1})
         )
+        self._test_cat(self.Cat(), inputs, cat_num=5, quant=True)
+
+    def test_qs8_cat_gt_5(self):
+        for num_inputs in range(6, 10):
+            inputs = []
+            for _ in range(num_inputs):
+                inputs.append(torch.randn(1, 2, 3))
+            self._test_cat(self.Cat(), tuple(inputs), cat_num=num_inputs, quant=True)
 
     class CatNegativeDim(torch.nn.Module):
         def __init__(self):
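The pass-level tests below count the cat nodes left after decomposition: for n > 5 inputs the chain needs 1 + ceil((n - 5) / 4) cats, since the first cat consumes 5 inputs and every subsequent cat consumes the previous result plus up to 4 new inputs. A quick standalone check of that arithmetic (expected_cats is a hypothetical helper, shown only to validate the formula):

import math

def expected_cats(num_inputs: int) -> int:
    # One cat suffices for <= 5 inputs; past that, each extra cat absorbs
    # up to 4 new inputs alongside the previous cat's output. For
    # num_inputs > 5 this matches the tests' num_cats computation.
    if num_inputs <= 5:
        return 1
    return 1 + math.ceil((num_inputs - 5) / 4)

assert [expected_cats(n) for n in (5, 6, 9, 11, 16, 18)] == [1, 2, 2, 3, 4, 5]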
diff --git a/backends/xnnpack/test/passes/test_decompose_cat_pass.py b/backends/xnnpack/test/passes/test_decompose_cat_pass.py
new file mode 100644
index 0000000000..beb1761aec
--- /dev/null
+++ b/backends/xnnpack/test/passes/test_decompose_cat_pass.py
@@ -0,0 +1,109 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import unittest
+
+import torch
+from executorch.backends.xnnpack._passes.decompose_cat import DecomposeConcatenate
+from executorch.backends.xnnpack.test.tester import RunPasses, Tester
+
+
+class TestDecomposeCatPass(unittest.TestCase):
+    PassStage = RunPasses([DecomposeConcatenate])
+    cat_name = "executorch_exir_dialects_edge__ops_aten_cat_default"
+
+    class Cat(torch.nn.Module):
+        def forward(self, *args):
+            xs = [*args]
+            x = torch.cat(xs)
+            return x + x  # Quantize by propagation.
+
+    def test_cat_gt_5(self):
+        inputs = [
+            torch.randn(1, 2, 3),
+        ]
+        for num_inputs in range(6, 10):
+            inputs = []
+            for _ in range(num_inputs):
+                inputs.append(torch.randn(1, 2, 3))
+
+            num_cats = int(len(inputs) > 5)
+            num_cats += math.ceil((len(inputs) - 5) / 4)
+            (
+                Tester(self.Cat(), tuple(inputs))
+                .export()
+                .to_edge()
+                .check_count({self.cat_name: 1})
+                .run_passes(self.PassStage)
+                .check_count({self.cat_name: num_cats})
+                .run_method_and_compare_outputs()
+            )
+
+    def test_cat_gt_10(self):
+        inputs = [
+            torch.randn(1, 2, 3),
+        ]
+        for num_inputs in [11, 16, 18]:
+            inputs = []
+            for _ in range(num_inputs):
+                inputs.append(torch.randn(1, 2, 3))
+
+            num_cats = int(len(inputs) > 5)
+            num_cats += math.ceil((len(inputs) - 5) / 4)
+            (
+                Tester(self.Cat(), tuple(inputs))
+                .export()
+                .to_edge()
+                .check_count({self.cat_name: 1})
+                .run_passes(self.PassStage)
+                .check_count({self.cat_name: num_cats})
+                .run_method_and_compare_outputs()
+            )
+
+    def test_qs8_cat_gt_5(self):
+        inputs = [
+            torch.randn(1, 2, 3),
+        ]
+        for num_inputs in range(6, 10):
+            inputs = []
+            for _ in range(num_inputs):
+                inputs.append(torch.randn(1, 2, 3))
+
+            num_cats = int(len(inputs) > 5)
+            num_cats += math.ceil((len(inputs) - 5) / 4)
+            (
+                Tester(self.Cat(), tuple(inputs))
+                .quantize()
+                .export()
+                .to_edge()
+                .check_count({self.cat_name: 1})
+                .run_passes(self.PassStage)
+                .check_count({self.cat_name: num_cats})
+                .run_method_and_compare_outputs()
+            )
+
+    def test_qs8_cat_gt_10(self):
+        inputs = [
+            torch.randn(1, 2, 3),
+        ]
+        for num_inputs in [11, 16, 18]:
+            inputs = []
+            for _ in range(num_inputs):
+                inputs.append(torch.randn(1, 2, 3))
+
+            num_cats = int(len(inputs) > 5)
+            num_cats += math.ceil((len(inputs) - 5) / 4)
+            (
+                Tester(self.Cat(), tuple(inputs))
+                .quantize()
+                .export()
+                .to_edge()
+                .check_count({self.cat_name: 1})
+                .run_passes(self.PassStage)
+                .check_count({self.cat_name: num_cats})
+                .run_method_and_compare_outputs()
+            )
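For completeness, a rough end-to-end sketch of exercising the new behavior through the normal XNNPACK lowering flow. The import paths and export APIs below follow the executorch layout around the time of this PR and are assumptions, not part of the change; DecomposeConcatenate itself runs inside the delegate's preprocess pass pipeline, so the oversized cat is split before it reaches XNNPACK.

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge

class BigCat(torch.nn.Module):
    def forward(self, *xs):
        return torch.cat(xs, dim=0)

# An 8-input cat: previously rejected by the partitioner, now delegated.
inputs = tuple(torch.randn(1, 2, 3) for _ in range(8))
ep = torch.export.export(BigCat(), inputs)
edge = to_edge(ep).to_backend(XnnpackPartitioner())
print(edge.exported_program().graph_module.graph)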