forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support Empty Input Tensors and > 5 Cat Inputs
Differential Revision: D68523312 Pull Request resolved: pytorch#7855
- Loading branch information
Showing
6 changed files
with
257 additions
and
51 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import logging | ||
|
||
import torch | ||
from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant | ||
from executorch.exir.dialects._ops import ops as exir_ops | ||
|
||
from executorch.exir.pass_base import ExportPass, PassResult | ||
|
||
# Module-level logger for this pass; the WARNING level set here filters out the pass's debug messages. | ||
logger = logging.getLogger(__name__) | ||
logger.setLevel(logging.WARNING) | ||
|
||
|
||
class DecomposeConcatenate(ExportPass):
    """
    Split wide concatenations into a chain of narrower ones.

    XNNPACK's Concatenate operation only supports concatenation for <= 5 tensors
    at a time. As a result, to support concatenates with > 5 tensors, we can decompose
    concatenates into sequences of cats each with <= 5 tensors.

    Example:
    Before Pass:
        cat: "f32" = torch.ops.aten.cat.default([t1, t2, t3, t4, t5, t6], 1);

    After Pass:
        cat: "f32" = torch.ops.aten.cat.default([t1, t2, t3, t4, t5], 1);
        cat_1: "f32" = torch.ops.aten.cat.default([cat, t6], 1);
    """

    # Maximum number of inputs a single cat node keeps after this pass.
    # Must be at least 2: each remainder cat takes the previous cat's output
    # as its first input, so a smaller limit would never shrink the list.
    MAX_CAT_INPUTS = 5

    @staticmethod
    def _insert_q_dq_after(graph, node, q_params, q_kwargs):
        """Insert a quantize -> dequantize pair right after ``node``.

        Both new nodes reuse ``q_params`` / ``q_kwargs``; the quantize node
        consumes ``node`` and the dequantize node consumes the quantize node.

        Returns:
            The newly created dequantize node.
        """
        with graph.inserting_after(node):
            q_node = graph.create_node(
                "call_function",
                target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
                args=(node,) + q_params,
                kwargs=q_kwargs,
            )
        with graph.inserting_after(q_node):
            dq_node = graph.create_node(
                "call_function",
                target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
                args=(q_node,) + q_params,
                kwargs=q_kwargs,
            )
        return dq_node

    def call(self, graph_module: torch.fx.GraphModule):
        """Decompose every cat node with more than ``MAX_CAT_INPUTS`` inputs.

        Args:
            graph_module: The graph module to rewrite; its graph is mutated
                in place before being handed to ``ExportPass.call``.

        Returns:
            A ``PassResult`` wrapping the graph module produced by
            ``ExportPass.call``; the modified flag is always ``True``.
        """
        gm = graph_module
        graph = gm.graph
        limit = self.MAX_CAT_INPUTS

        for node in graph.nodes:
            if node.op != "call_function":
                continue
            # getattr: tolerate call targets that do not expose __name__.
            if getattr(node.target, "__name__", None) != "aten.cat.default":
                continue

            cat_args = node.args
            cat_inputs = cat_args[0]
            if len(cat_inputs) <= limit:
                continue

            # Must be evaluated before the users of `node` are rewired below.
            is_quantized = all(is_dequant(inp) for inp in cat_inputs) and all(
                is_quant(user) for user in node.users
            )

            # Everything after the input list (e.g. the dim) is shared by both cats.
            trailing_args = cat_args[1:]
            remainder_inputs = cat_inputs[limit:]

            # Shrink the original cat to its first `limit` inputs.
            node.args = (cat_inputs[:limit],) + trailing_args

            with graph.inserting_after(node):
                logger.debug("Decomposing cat node %s", node)
                remainder_cat = graph.create_node(
                    "call_function",
                    target=exir_ops.edge.aten.cat.default,
                    args=([],),  # real inputs are filled in below
                    kwargs=node.kwargs,
                )
            # The remainder cat has no inputs yet, so this cannot rewire it.
            node.replace_all_uses_with(remainder_cat)

            first_input = node
            if is_quantized:
                # If quantized we need to enforce the q/dq pattern for the newly
                # inserted concat node. Quantizer enforces all the inputs and
                # output to a concat node must share the same qparams, this means
                # the newly inserted q/dq pair must share the same qparams as the
                # first quantized input in the concat node.
                template = cat_inputs[0]
                logger.debug(
                    "Inserting Q/DQ pair for new cat node %s", remainder_cat
                )
                first_input = self._insert_q_dq_after(
                    graph, node, template.args[1:], template.kwargs
                )

            # The remainder cat is itself a cat placed after `node`; if it still
            # exceeds the limit it is expected to be split again when the
            # traversal reaches it.
            remainder_cat.args = ([first_input, *remainder_inputs],) + trailing_args

        gm.recompile()
        new_gm = super().call(gm).graph_module
        return PassResult(new_gm, True)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
109 changes: 109 additions & 0 deletions
109
backends/xnnpack/test/passes/test_decompose_cat_pass.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import math | ||
import unittest | ||
|
||
import torch | ||
from executorch.backends.xnnpack._passes.decompose_cat import DecomposeConcatenate | ||
from executorch.backends.xnnpack.test.tester import RunPasses, Tester | ||
|
||
|
||
class TestDecomposeCatPass(unittest.TestCase):
    """Checks that DecomposeConcatenate splits cats with more than 5 inputs."""

    PassStage = RunPasses([DecomposeConcatenate])
    cat_name = "executorch_exir_dialects_edge__ops_aten_cat_default"

    class Cat(torch.nn.Module):
        """Concatenates all positional inputs and doubles the result."""

        def forward(self, *args):
            xs = [*args]
            x = torch.cat(xs)
            return x + x  # Quantize by propagation.

    def _check_decomposition(self, num_inputs, quantize=False):
        """Run the pass on a cat of ``num_inputs`` tensors and verify the result.

        Asserts that the edge graph has exactly one cat before the pass and the
        expected number of cats afterwards, then compares outputs.

        Args:
            num_inputs: Number of (1, 2, 3) random tensors fed to the cat.
            quantize: When true, add the Tester's quantize stage before export.
        """
        inputs = tuple(torch.randn(1, 2, 3) for _ in range(num_inputs))

        # The first cat keeps 5 inputs; every further cat takes the previous
        # cat's output plus up to 4 of the remaining inputs.
        num_cats = int(num_inputs > 5)
        num_cats += math.ceil((num_inputs - 5) / 4)

        tester = Tester(self.Cat(), inputs)
        if quantize:
            tester = tester.quantize()
        (
            tester.export()
            .to_edge()
            .check_count({self.cat_name: 1})
            .run_passes(self.PassStage)
            .check_count({self.cat_name: num_cats})
            .run_method_and_compare_outputs()
        )

    def test_cat_gt_5(self):
        for num_inputs in range(6, 10):
            with self.subTest(num_inputs=num_inputs):
                self._check_decomposition(num_inputs)

    def test_cat_gt_10(self):
        for num_inputs in [11, 16, 18]:
            with self.subTest(num_inputs=num_inputs):
                self._check_decomposition(num_inputs)

    def test_qs8_cat_gt_5(self):
        for num_inputs in range(6, 10):
            with self.subTest(num_inputs=num_inputs):
                self._check_decomposition(num_inputs, quantize=True)

    def test_qs8_cat_gt_10(self):
        for num_inputs in [11, 16, 18]:
            with self.subTest(num_inputs=num_inputs):
                self._check_decomposition(num_inputs, quantize=True)