Add sub operator for Arm backend (pytorch#4074)
Summary: Implemented node visitor, quantizer, and test. TOSA MI (Main Inference, floating point) and BI (Base Inference, integer) pass; U55 BI fails (expectedly) at Vela compilation. Refactored code shared between sub and add.

Change-Id: Ifc9fc4ae083f3ed868ad763e4301e5fe87468a25
Pull Request resolved: pytorch#4074
Reviewed By: mergennachin
Differential Revision: D59259308
Pulled By: digantdesai
fbshipit-source-id: ce10e9b1a583e6374e5f1c5815dc11c0d0e7aa5b
Commit 11b2fcb (parent 47d309a)
13 changed files with 443 additions and 96 deletions.
The operators package import list gains the new module (one addition):

@@ -20,5 +20,6 @@
     op_quant,
     op_sigmoid,
     op_softmax,
+    op_sub,
     op_view,
 )
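Importing op_sub here is what activates the new visitor: the module's @register_node_visitor decorator runs at import time and records the SubVisitor class (defined in the new file below) in a registry keyed by its target op. A minimal sketch of that common decorator-registry pattern; the names here are illustrative assumptions, not the repository's actual node_visitor.py:

# Hypothetical sketch of a decorator-based visitor registry; the real
# node_visitor.py may differ in names and details.
_node_visitors = {}

def register_node_visitor(visitor_cls):
    # Key the registry by the op the visitor handles, e.g. "aten.sub.Tensor".
    _node_visitors[visitor_cls.target] = visitor_cls
    return visitor_cls

def get_node_visitor(target):
    return _node_visitors[target]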
New file — the node visitor that lowers aten.sub.Tensor to TOSA:

@@ -0,0 +1,69 @@
# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List

import executorch.backends.arm.tosa_quant_utils as tqutils
import executorch.backends.arm.tosa_utils as tutils

import serializer.tosa_serializer as ts
from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from serializer.tosa_serializer import TosaOp
from torch.fx import Node


@register_node_visitor
class SubVisitor(NodeVisitor):
    target = "aten.sub.Tensor"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        if is_quant_node:
            input_nodes = tutils.get_two_inputs(node)

            # Rescale inputs to 32 bit
            rescaled_inputs, scale = tqutils.rescale_nodes_to_int32(
                input_nodes, tosa_graph
            )

            # Prepare sub output tensor, broadcasting the two input shapes
            broadcasted_shape = tutils.broadcast_shapes(
                rescaled_inputs[0].shape, rescaled_inputs[1].shape
            )
            sub_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)

            # Do the INT32 Sub
            tosa_graph.addOperator(
                TosaOp.Op().SUB,
                [
                    rescaled_inputs[0].name,
                    rescaled_inputs[1].name,
                ],
                [sub_output.name],
            )

            # Scale output back to 8 bit
            tqutils.rescale_node_back_to_int8(node, sub_output, scale, tosa_graph)
        else:
            # FP32 Sub lowering
            tosa_graph.addOperator(
                TosaOp.Op().SUB,
                [inputs[0].name, inputs[1].name],
                [output.name],
                None,
            )
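The quantized branch follows the standard integer-subtract recipe: rescale both int8 operands to int32 at a common scale, subtract exactly in int32, then rescale the result back to int8. A minimal numeric sketch of that idea in plain Python, for illustration only: the actual tqutils helpers emit TOSA RESCALE operators rather than computing anything in Python, and their exact rounding rules are not shown in this diff; zero points are omitted, matching symmetric quantization.

# Illustrative only: subtract two int8-quantized values via a shared
# int32 representation, then requantize to int8.
def quantized_sub(qa, scale_a, qb, scale_b, out_scale):
    # Bring both operands to a common (smaller) scale so the integer
    # subtraction is exact in int32.
    common_scale = min(scale_a, scale_b)
    a32 = round(qa * scale_a / common_scale)
    b32 = round(qb * scale_b / common_scale)
    diff32 = a32 - b32  # exact integer subtraction
    # Requantize to the output scale and clamp to the int8 range.
    q_out = round(diff32 * common_scale / out_scale)
    return max(-128, min(127, q_out))

# Example: a = 1.5 (scale 0.1), b = 0.5 (scale 0.05), output scale 0.05.
print(quantized_sub(15, 0.1, 10, 0.05, 0.05))  # -> 20, i.e. 20 * 0.05 = 1.0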
backends/arm/quantizer/quantization_annotation/sub_annotator.py (46 additions, 0 deletions)

New file — the quantization annotator for sub:
@@ -0,0 +1,46 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import operator
from typing import Callable, List, Optional

import torch
from executorch.backends.arm.quantizer import arm_quantizer_utils
from executorch.backends.arm.quantizer.quantization_annotation import register_annotator
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from torch.ao.quantization.quantizer import QuantizationAnnotation
from torch.fx import GraphModule, Node
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


@register_annotator("sub")
def _annotate_sub(
    gm: GraphModule,
    quantization_config: QuantizationConfig,
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    sub_partitions = get_source_partitions(
        gm.graph, [operator.sub, torch.sub, operator.isub], filter_fn
    )
    sub_partitions = list(itertools.chain.from_iterable(sub_partitions.values()))
    annotated_partitions = []
    for sub_partition in sub_partitions:
        annotated_partitions.append(sub_partition.nodes)
        sub_node = sub_partition.output_nodes[0]
        if arm_quantizer_utils.is_annotated(sub_node):
            continue

        input_qspec_map, output_qspec = arm_quantizer_utils.get_shared_qspec(
            sub_node, gm, quantization_config
        )
        if input_qspec_map is not None:
            sub_node.meta["quantization_annotation"] = QuantizationAnnotation(
                input_qspec_map=input_qspec_map,
                output_qspec=output_qspec,
                _annotated=True,
            )
    return annotated_partitions
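For context, annotators registered this way are not called directly; the Arm quantizer runs them over a captured FX graph during prepare. A hedged end-to-end sketch, assuming the ArmQuantizer entry points this codebase exposes (the graph-capture API in particular has changed across PyTorch versions):

import torch
from executorch.backends.arm.quantizer.arm_quantizer import (
    ArmQuantizer,
    get_symmetric_quantization_config,
)
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


class Sub(torch.nn.Module):
    def forward(self, x, y):
        return x - y  # source-matched by the "sub" annotator via operator.sub


example_inputs = (torch.randn(1, 4), torch.randn(1, 4))
gm = capture_pre_autograd_graph(Sub(), example_inputs)

quantizer = ArmQuantizer().set_global(get_symmetric_quantization_config())
prepared = prepare_pt2e(gm, quantizer)  # annotation pass runs _annotate_sub, inserts observers
prepared(*example_inputs)               # calibration
quantized = convert_pt2e(prepared)      # materialize quantize/dequantize ops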