Support Empty Input Tensors and > 5 Cat Inputs
Differential Revision: D68523312

Pull Request resolved: pytorch#7855
mcr229 authored Jan 24, 2025
1 parent b1ffa1e commit b522084
Showing 6 changed files with 257 additions and 51 deletions.
15 changes: 1 addition & 14 deletions backends/xnnpack/_passes/TARGETS
@@ -4,20 +4,7 @@ oncall("executorch")

python_library(
name = "xnnpack_passes",
srcs = [
"__init__.py",
"channels_last_tagged_reshape_pass.py",
"conv1d_unsqueeze_pass.py",
"convert_to_linear.py",
"convert_to_sdpa.py",
"convert_to_upsample_bilinear2d.py",
"fuse_activation_pass.py",
"fuse_batch_norm_with_conv.py",
"prelu_reshape_pass.py",
"remove_getitem_op.py",
"tag_implicit_q_dq_pass.py",
"xnnpack_pass.py",
],
srcs = native.glob(["*.py"]),
deps = [
"//caffe2:torch",
"//executorch/backends/transforms:addmm_mm_to_linear",
2 changes: 2 additions & 0 deletions backends/xnnpack/_passes/__init__.py
@@ -17,6 +17,7 @@
from executorch.backends.xnnpack._passes.convert_to_upsample_bilinear2d import (
ConvertToUpsampleBilinear2d,
)
from executorch.backends.xnnpack._passes.decompose_cat import DecomposeConcatenate
from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass
from executorch.backends.xnnpack._passes.fuse_batch_norm_with_conv import (
FuseBatchNormWithConvPass,
@@ -63,6 +64,7 @@ def __init__(
ConstPropPass,
FuseBatchNormWithConvPass,
FuseActivationPass,
DecomposeConcatenate,
RemoveGetItemPass,
Conv1dUnsqueezePass,
PReLUReshapePass,
99 changes: 99 additions & 0 deletions backends/xnnpack/_passes/decompose_cat.py
@@ -0,0 +1,99 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch
from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant
from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.pass_base import ExportPass, PassResult

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)


class DecomposeConcatenate(ExportPass):
"""
    XNNPACK's concatenate operation supports at most 5 input tensors at a time.
    To support cat nodes with more than 5 inputs, we decompose them into chains
    of cat nodes, each taking at most 5 inputs.
Example:
Before Pass:
cat: "f32" = torch.ops.aten.cat.default([t1, t2, t3, t4, t5, t6], 1);
After Pass:
cat: "f32" = torch.ops.aten.cat.default([t1, t2, t3, t4, t5], 1);
cat_1: "f32" = torch.ops.aten.cat.default([cat, t6], 1);
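    Each follow-up cat folds the previous result in with up to 4 additional
    inputs, so an n-input cat (n > 5) decomposes into 1 + ceil((n - 5) / 4)
    cat nodes.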
"""

def call(self, graph_module: torch.fx.GraphModule):
gm = graph_module
for node in gm.graph.nodes:
if (
node.op == "call_function"
and node.target.__name__ == "aten.cat.default"
):
concat_args = node.args
nodes_to_concat = node.args[0]
if len(nodes_to_concat) <= 5:
continue

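                # The cat is statically quantized when every input comes from a
                # dequantize and every user is a quantize (the implicit q/dq
                # pattern around delegated ops).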
                is_quantized = all(
                    is_dequant(input_node) for input_node in nodes_to_concat
                ) and all(is_quant(user) for user in node.users)

                # Trim the original cat node to its first 5 inputs; the remainder
                # is handled by a new cat node inserted below.
new_concat_args = (nodes_to_concat[:5],) + concat_args[1:]
node.args = new_concat_args

remainder_nodes_to_concat = nodes_to_concat[5:]
with gm.graph.inserting_after(node):
logger.debug(f"Decomposing cat node {node}")
remainder_concat_node = gm.graph.create_node(
"call_function",
target=exir_ops.edge.aten.cat.default,
                        args=([],),  # placeholder; replaced with the remainder nodes below
kwargs=node.kwargs,
)
node.replace_all_uses_with(remainder_concat_node)
if is_quantized:
# if quantized we need to enforce the q/dq pattern for the newly inserted
# concat node
q_params = nodes_to_concat[0].args[1:]
q_kwargs = nodes_to_concat[0].kwargs
                    # The quantizer enforces that all inputs and the output of a
                    # concat node share the same qparams, so the newly inserted
                    # q/dq pair must reuse the qparams of the first quantized input.
with gm.graph.inserting_after(node):
logger.debug(
f"Inserting Q/DQ pair for new cat node {remainder_concat_node}"
)
q_node = gm.graph.create_node(
"call_function",
target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
args=(node,) + q_params,
kwargs=q_kwargs,
)
with gm.graph.inserting_after(q_node):
dq_node = gm.graph.create_node(
"call_function",
target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
args=(q_node,) + q_params,
kwargs=q_kwargs,
)
remainder_concat_node.args = (
[dq_node] + remainder_nodes_to_concat,
) + node.args[1:]
else:
remainder_concat_node.args = (
[node] + remainder_nodes_to_concat,
) + node.args[1:]

gm.recompile()
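        # Retrace through the parent ExportPass to re-validate the mutated graph.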
new_gm = super().call(gm).graph_module
return PassResult(new_gm, True)
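
To make the rewrite concrete, here is a minimal eager-mode sketch of the same chunking scheme, independent of the pass infrastructure. The names chunked_cat and MAX_CAT_INPUTS are illustrative only, not part of the ExecuTorch API: the function chains torch.cat calls so that no single cat sees more than 5 inputs, mirroring what DecomposeConcatenate does at the graph level.

import torch

MAX_CAT_INPUTS = 5  # XNNPACK's per-cat input limit

def chunked_cat(tensors, dim=0):
    # First cat takes up to 5 raw inputs; each follow-up cat folds the
    # running result in with up to 4 more inputs.
    result = torch.cat(list(tensors[:MAX_CAT_INPUTS]), dim=dim)
    remaining = list(tensors[MAX_CAT_INPUTS:])
    while remaining:
        batch, remaining = remaining[:MAX_CAT_INPUTS - 1], remaining[MAX_CAT_INPUTS - 1:]
        result = torch.cat([result] + batch, dim=dim)
    return result

# Concatenation copies data exactly, so the chunked form matches one big cat.
xs = [torch.randn(2, 3) for _ in range(9)]
assert torch.equal(chunked_cat(xs, dim=0), torch.cat(xs, dim=0))

For 9 inputs this yields two cats (5 inputs, then 1 + 4), matching the 1 + ceil((n - 5) / 4) node counts asserted in the pass tests below.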
4 changes: 2 additions & 2 deletions backends/xnnpack/partition/config/generic_node_configs.py
@@ -181,10 +181,10 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:

num_tensors = len(node.all_input_nodes)

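        # Only the lower bound is checked here: cats with more than 5 inputs
        # are split by the DecomposeConcatenate pass before delegation.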
if not (num_tensors >= 2 and num_tensors <= 5):
if not (num_tensors >= 2):
why(
node,
reason=f"only support concatenation of 2 - 5 tensors, got {num_tensors} tensors",
                reason=f"only support concatenation of >= 2 tensors, got {num_tensors} tensors",
)
return False

79 changes: 44 additions & 35 deletions backends/xnnpack/test/ops/test_cat.py
@@ -14,9 +14,13 @@

class TestCat(unittest.TestCase):
class Cat(torch.nn.Module):
def __init__(self, dim=0):
super().__init__()
self.dim = dim

def forward(self, *args):
xs = [*args]
x = torch.cat(xs)
x = torch.cat(xs, dim=self.dim)
return x + x # Quantize by propagation.

def _test_cat(self, module, inputs, cat_num=1, quant=False, quant_ops=2):
@@ -27,7 +31,6 @@ def _test_cat(self, module, inputs, cat_num=1, quant=False, quant_ops=2):
tester.quantize()

tester.export().check_count({"torch.ops.aten.cat": 1})
tester.dump_artifact()

if quant:
# Expect multiple quantize ops - one per input, cat, and add.
@@ -93,6 +96,29 @@ def test_fp16_cat4(self):
)
self._test_cat(self.Cat(), inputs)

    def test_fp16_cat5(self):
        """
        Note: fp16 add is currently executed in fp32; revisit once that is fixed.
        """
inputs = (
torch.randn(1, 2, 3).to(torch.float16),
torch.randn(3, 2, 3).to(torch.float16),
torch.randn(2, 2, 3).to(torch.float16),
torch.randn(5, 2, 3).to(torch.float16),
torch.randn(5, 2, 3).to(torch.float16),
)
self._test_cat(self.Cat(), inputs)

    def test_fp16_cat_gt_5(self):
        """
        Note: fp16 add is currently executed in fp32; revisit once that is fixed.
        """
for num_inputs in range(6, 10):
inputs = []
for _ in range(num_inputs):
inputs.append(torch.randn(1, 2, 3).to(torch.float16))
self._test_cat(self.Cat(), tuple(inputs))

def test_fp32_cat2(self):
inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3))
self._test_cat(self.Cat(), inputs)
@@ -120,6 +146,13 @@ def test_fp32_cat5(self):
)
self._test_cat(self.Cat(), inputs)

def test_fp32_cat_gt_5(self):
for num_inputs in range(6, 10):
inputs = []
for _ in range(num_inputs):
inputs.append(torch.randn(1, 2, 3))
self._test_cat(self.Cat(), tuple(inputs))

def test_qs8_cat2(self):
inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3))
self._test_cat(self.Cat(), inputs, cat_num=2, quant=True)
@@ -137,46 +170,22 @@ def test_qs8_cat4(self):
)
self._test_cat(self.Cat(), inputs, cat_num=4, quant=True)

def test_fp32_cat_unsupported(self):
"""
XNNPACK only supports concatenating up to 4 values, so it should not delegate here.
"""
def test_qs8_cat5(self):
inputs = (
torch.randn(1, 2, 3),
torch.randn(3, 2, 3),
torch.randn(2, 2, 3),
torch.randn(5, 2, 3),
torch.randn(1, 2, 3),
torch.randn(2, 2, 3),
)
(
Tester(self.Cat(), inputs)
.export()
.check_count({"torch.ops.aten.cat": 1})
.to_edge_transform_and_lower()
.check_count({"executorch_exir_dialects_edge__ops_aten_cat": 1})
)

def test_fp32_cat_unsupported_legacy_mode(self):
"""
XNNPACK only supports concatenating up to 5 values, so it should not delegate here.
"""
inputs = (
torch.randn(1, 2, 3),
torch.randn(3, 2, 3),
torch.randn(2, 2, 3),
torch.randn(5, 2, 3),
torch.randn(1, 2, 3),
torch.randn(6, 2, 3),
)
(
Tester(self.Cat(), inputs)
.export()
.check_count({"torch.ops.aten.cat": 1})
.to_edge()
.partition()
.check_count({"executorch_exir_dialects_edge__ops_aten_cat": 1})
)
self._test_cat(self.Cat(), inputs, cat_num=5, quant=True)

def test_qs8_cat_gt_5(self):
for num_inputs in range(6, 10):
inputs = []
for _ in range(num_inputs):
inputs.append(torch.randn(1, 2, 3))
self._test_cat(self.Cat(), tuple(inputs), cat_num=num_inputs, quant=True)

class CatNegativeDim(torch.nn.Module):
def __init__(self):
109 changes: 109 additions & 0 deletions backends/xnnpack/test/passes/test_decompose_cat_pass.py
@@ -0,0 +1,109 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math
import unittest

import torch
from executorch.backends.xnnpack._passes.decompose_cat import DecomposeConcatenate
from executorch.backends.xnnpack.test.tester import RunPasses, Tester


class TestDecomposeCatPass(unittest.TestCase):
PassStage = RunPasses([DecomposeConcatenate])
cat_name = "executorch_exir_dialects_edge__ops_aten_cat_default"
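    # Name that cat nodes carry in the edge dialect after to_edge().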

class Cat(torch.nn.Module):
def forward(self, *args):
xs = [*args]
x = torch.cat(xs)
return x + x # Quantize by propagation.

def test_cat_gt_5(self):
for num_inputs in range(6, 10):
inputs = []
for _ in range(num_inputs):
inputs.append(torch.randn(1, 2, 3))

num_cats = int(len(inputs) > 5)
num_cats += math.ceil((len(inputs) - 5) / 4)
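            # One cat covers the first 5 inputs; each additional cat folds the
            # previous result in with up to 4 more (6 inputs -> 2 cats,
            # 10 inputs -> 3 cats).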
(
Tester(self.Cat(), tuple(inputs))
.export()
.to_edge()
.check_count({self.cat_name: 1})
.run_passes(self.PassStage)
.check_count({self.cat_name: num_cats})
.run_method_and_compare_outputs()
)

def test_cat_gt_10(self):
for num_inputs in [11, 16, 18]:
inputs = []
for _ in range(num_inputs):
inputs.append(torch.randn(1, 2, 3))

num_cats = int(len(inputs) > 5)
num_cats += math.ceil((len(inputs) - 5) / 4)
(
Tester(self.Cat(), tuple(inputs))
.export()
.to_edge()
.check_count({self.cat_name: 1})
.run_passes(self.PassStage)
.check_count({self.cat_name: num_cats})
.run_method_and_compare_outputs()
)

def test_qs8_cat_gt_5(self):
for num_inputs in range(6, 10):
inputs = []
for _ in range(num_inputs):
inputs.append(torch.randn(1, 2, 3))

num_cats = int(len(inputs) > 5)
num_cats += math.ceil((len(inputs) - 5) / 4)
(
Tester(self.Cat(), tuple(inputs))
.quantize()
.export()
.to_edge()
.check_count({self.cat_name: 1})
.run_passes(self.PassStage)
.check_count({self.cat_name: num_cats})
.run_method_and_compare_outputs()
)

def test_qs8_cat_gt_10(self):
for num_inputs in [11, 16, 18]:
inputs = []
for _ in range(num_inputs):
inputs.append(torch.randn(1, 2, 3))

num_cats = int(len(inputs) > 5)
num_cats += math.ceil((len(inputs) - 5) / 4)
(
Tester(self.Cat(), tuple(inputs))
.quantize()
.export()
.to_edge()
.check_count({self.cat_name: 1})
.run_passes(self.PassStage)
.check_count({self.cat_name: num_cats})
.run_method_and_compare_outputs()
)
