From 11b2fcb7ba065700c2efce895c834900be242317 Mon Sep 17 00:00:00 2001
From: Erik Lundell
Date: Tue, 23 Jul 2024 21:29:44 -0700
Subject: [PATCH] Add sub operator for Arm backend (#4074)

Summary:
Implemented node visitor, quantizer, and test.
TOSA MI and BI pass; U55 BI fails (as expected) during Vela compilation.

Refactored code shared between sub and add.

Change-Id: Ifc9fc4ae083f3ed868ad763e4301e5fe87468a25

Pull Request resolved: https://github.com/pytorch/executorch/pull/4074

Reviewed By: mergennachin

Differential Revision: D59259308

Pulled By: digantdesai

fbshipit-source-id: ce10e9b1a583e6374e5f1c5815dc11c0d0e7aa5b
---
 backends/arm/arm_partitioner.py               |   1 +
 backends/arm/operators/__init__.py            |   1 +
 backends/arm/operators/op_add.py              |  79 ++++--------
 backends/arm/operators/op_sub.py              |  69 +++++++++++
 backends/arm/quantizer/arm_quantizer.py       |   2 +
 backends/arm/quantizer/arm_quantizer_utils.py |  50 +++++++-
 .../quantization_annotation/__init__.py       |   1 +
 .../quantization_annotation/add_annotator.py  |  40 ++----
 .../quantization_annotation/sub_annotator.py  |  46 +++++++
 backends/arm/test/ops/test_add.py             |   7 --
 backends/arm/test/ops/test_sub.py             | 117 ++++++++++++++++++
 backends/arm/tosa_quant_utils.py              |  64 ++++++++++
 backends/arm/tosa_utils.py                    |  62 +++++++++-
 13 files changed, 443 insertions(+), 96 deletions(-)
 create mode 100644 backends/arm/operators/op_sub.py
 create mode 100644 backends/arm/quantizer/quantization_annotation/sub_annotator.py
 create mode 100644 backends/arm/test/ops/test_sub.py

diff --git a/backends/arm/arm_partitioner.py b/backends/arm/arm_partitioner.py
index d375a46c16..54cfafcc9b 100644
--- a/backends/arm/arm_partitioner.py
+++ b/backends/arm/arm_partitioner.py
@@ -47,6 +47,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             exir_ops.edge.aten.avg_pool2d.default,
             exir_ops.edge.aten.sigmoid.default,
             exir_ops.edge.aten._softmax.default,
+            exir_ops.edge.aten.sub.Tensor,
             exir_ops.edge.aten.view_copy.default,
             exir_ops.edge.aten.clone.default,
             exir_ops.edge.aten.mean.dim,
diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py
index bc7785da59..79c507816d 100644
--- a/backends/arm/operators/__init__.py
+++ b/backends/arm/operators/__init__.py
@@ -20,5 +20,6 @@
     op_quant,
     op_sigmoid,
     op_softmax,
+    op_sub,
     op_view,
 )
diff --git a/backends/arm/operators/op_add.py b/backends/arm/operators/op_add.py
index 76f9f996b8..33c0c49744 100644
--- a/backends/arm/operators/op_add.py
+++ b/backends/arm/operators/op_add.py
@@ -5,19 +5,17 @@

 from typing import List

+import executorch.backends.arm.tosa_quant_utils as tqutils
+import executorch.backends.arm.tosa_utils as tutils
+
 import serializer.tosa_serializer as ts
-import torch
 from executorch.backends.arm.operators.node_visitor import (
     NodeVisitor,
     register_node_visitor,
 )
 from executorch.backends.arm.tosa_mapping import TosaArg
-from executorch.backends.arm.tosa_quant_utils import (
-    build_rescale_from_int32,
-    build_rescale_to_int32,
-)
-from executorch.backends.arm.tosa_utils import broadcast_shapes, getNodeArgs, tosa_shape
 from serializer.tosa_serializer import TosaOp
+from torch.fx import Node


 @register_node_visitor
@@ -29,75 +27,44 @@ def __init__(self, *args):

     def define_node(
         self,
-        node: torch.fx.Node,
+        node: Node,
         tosa_graph: ts.TosaSerializer,
         inputs: List[TosaArg],
         output: TosaArg,
         is_quant_node: bool,
     ) -> None:
         if is_quant_node:
-            # Single input or not
-            if len(node.all_input_nodes) == 1:
-                input_node_A = node.all_input_nodes[0]
-                input_node_B = node.all_input_nodes[0]
-            else:
-                input_node_A, input_node_B = node.all_input_nodes
-
-            # Get input scale_factor and zero_points for A, B
-            input_A, input_A_scale, input_A_zp, _, _, _ = getNodeArgs(input_node_A)
-            input_B, input_B_scale, input_B_zp, _, _, _ = getNodeArgs(input_node_B)
-
-            # Scale the int8 quantized input to a common scale in the integer
-            # domain.
-            min_scale = min(input_A_scale.number, input_B_scale.number)
-            inputA_rescale_scale = input_A_scale.number / min_scale
-            inputB_rescale_scale = input_B_scale.number / min_scale
-
-            input_A.shape = tosa_shape(input_A.shape, input_A.dim_order)
-            input_B.shape = tosa_shape(input_B.shape, input_B.dim_order)
-            broadcasted_shape = broadcast_shapes(input_A.shape, input_B.shape)
+            input_nodes = tutils.get_two_inputs(node)

-            input_A_rescaled_to_int32 = build_rescale_to_int32(
-                tosa_graph,
-                input_A,
-                input_A_zp.number,
-                inputA_rescale_scale,
+            # Rescale inputs to 32 bit
+            rescaled_inputs, scale = tqutils.rescale_nodes_to_int32(
+                input_nodes, tosa_graph
             )

-            input_B_rescaled_to_int32 = build_rescale_to_int32(
-                tosa_graph,
-                input_B,
-                input_B_zp.number,
-                inputB_rescale_scale,
+            # Prepare add output tensor
+            broadcasted_shape = tutils.broadcast_shapes(
+                rescaled_inputs[0].shape, rescaled_inputs[1].shape
             )
+            add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)

-            ## Do the INT32 Add
-            add_res = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
+            # Do the INT32 Add
             tosa_graph.addOperator(
                 TosaOp.Op().ADD,
                 [
-                    input_A_rescaled_to_int32.name,
-                    input_B_rescaled_to_int32.name,
+                    rescaled_inputs[0].name,
+                    rescaled_inputs[1].name,
                 ],
-                [add_res.name],
+                [add_output.name],
                 None,
             )

-            # Output
-            output_node = list(node.users)[0]
-            _, output_scale, output_zp, _, _, _ = getNodeArgs(output_node)
-            output_rescale_scale = min_scale / output_scale.number
-
-            # Rescale Back to INT8
-            build_rescale_from_int32(
-                tosa_graph,
-                add_res.name,
-                output.name,
-                output_zp.number,
-                output_rescale_scale,
-            )
+            # Scale output back to 8 bit
+            tqutils.rescale_node_back_to_int8(node, add_output, scale, tosa_graph)
         else:
             # FP32 Add lowering
             tosa_graph.addOperator(
-                TosaOp.Op().ADD, [inputs[0].name, inputs[1].name], [output.name], None
+                TosaOp.Op().ADD,
+                [inputs[0].name, inputs[1].name],
+                [output.name],
+                None,
             )
diff --git a/backends/arm/operators/op_sub.py b/backends/arm/operators/op_sub.py
new file mode 100644
index 0000000000..3dc1519f37
--- /dev/null
+++ b/backends/arm/operators/op_sub.py
@@ -0,0 +1,69 @@
+# Copyright 2023-2024 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List
+
+import executorch.backends.arm.tosa_quant_utils as tqutils
+import executorch.backends.arm.tosa_utils as tutils
+
+import serializer.tosa_serializer as ts
+from executorch.backends.arm.operators.node_visitor import (
+    NodeVisitor,
+    register_node_visitor,
+)
+from executorch.backends.arm.tosa_mapping import TosaArg
+from serializer.tosa_serializer import TosaOp
+from torch.fx import Node
+
+
+@register_node_visitor
+class SubVisitor(NodeVisitor):
+    target = "aten.sub.Tensor"
+
+    def __init__(self, *args):
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: Node,
+        tosa_graph: ts.TosaSerializer,
+        inputs: List[TosaArg],
+        output: TosaArg,
+        is_quant_node: bool,
+    ) -> None:
+        if is_quant_node:
+            input_nodes = tutils.get_two_inputs(node)
+
+            # Rescale inputs to 32 bit
+            rescaled_inputs, scale = tqutils.rescale_nodes_to_int32(
+                input_nodes, tosa_graph
+            )
+
+            # Prepare sub output tensor
+            broadcasted_shape = tutils.broadcast_shapes(
+                rescaled_inputs[0].shape, rescaled_inputs[1].shape
+            )
+            sub_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
+
+            # Do the INT32 Sub
+            tosa_graph.addOperator(
+                TosaOp.Op().SUB,
+                [
+                    rescaled_inputs[0].name,
+                    rescaled_inputs[1].name,
+                ],
+                [sub_output.name],
+            )
+
+            # Scale output back to 8 bit
+            tqutils.rescale_node_back_to_int8(node, sub_output, scale, tosa_graph)
+        else:
+            # FP32 Sub lowering
+            tosa_graph.addOperator(
+                TosaOp.Op().SUB,
+                [inputs[0].name, inputs[1].name],
+                [output.name],
+                None,
+            )
diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py
index b094067fe0..3e1aceefe1 100644
--- a/backends/arm/quantizer/arm_quantizer.py
+++ b/backends/arm/quantizer/arm_quantizer.py
@@ -66,6 +66,7 @@ def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPattern
             [torch.nn.AdaptiveAvgPool2d],
             [F.adaptive_avg_pool2d],
         ],
+        "sub": [[torch.sub]],
     }
     return copy.deepcopy(supported_operators)

@@ -254,6 +255,7 @@ class ArmQuantizer(Quantizer):
         "adaptive_avg_pool2d",
         "max_pool2d",
         "add",
+        "sub",
         "mul",
         "sigmoid",
     ]
diff --git a/backends/arm/quantizer/arm_quantizer_utils.py b/backends/arm/quantizer/arm_quantizer_utils.py
index bfb0dd7b23..ee2844e668 100644
--- a/backends/arm/quantizer/arm_quantizer_utils.py
+++ b/backends/arm/quantizer/arm_quantizer_utils.py
@@ -12,6 +12,7 @@
 from typing import Callable, cast, List

 import torch
+from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
 from torch._subclasses import FakeTensor

 from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
@@ -54,7 +55,53 @@ def mark_nodes_as_annotated(nodes: List[Node]) -> None:
                 node.meta["quantization_annotation"]._annotated = True


-def is_input_large_scalar(node: Node, gm: GraphModule) -> bool:
+def get_shared_qspec(
+    node: Node, gm: GraphModule, quantization_config: QuantizationConfig
+):
+    """Returns a quantization constellation with a SharedQuantizationSpec for the
+    inputs and output of the parameter 'node'.
+
+    Parameters:
+        node: a node with two inputs that should share quantization parameters.
+        gm: the GraphModule containing the node, used to inspect global graph features.
+        quantization_config: a QuantizationConfig with the input QuantizationSpec to share.
+
+    Returns:
+        input_qspec_map: a dict[Node, QuantizationSpec] that maps the inputs of 'node'
+            to the correct QuantizationSpec.
+        shared_with_input0_qspec: the SharedQuantizationSpec to be used as the output
+            QuantizationSpec.
+
+    Both outputs are None if one of the inputs is a node that can't be quantized.
+    """
+    input_act0 = node.args[0]
+    input_act1 = node.args[1]
+
+    input_act_qspec = quantization_config.get_input_act_qspec()
+    shared_with_input0_qspec = SharedQuantizationSpec((input_act0, node))
+
+    input_qspec_map = {}
+    if isinstance(input_act0, Node):
+        if not is_input_ok_for_quantization(input_act0, gm):
+            return None, None
+        input_qspec_map[input_act0] = input_act_qspec
+
+    if isinstance(input_act1, Node):
+        if not is_input_ok_for_quantization(input_act1, gm):
+            return None, None
+        if input_act0 is not input_act1:
+            input_qspec_map[input_act1] = shared_with_input0_qspec
+    return input_qspec_map, shared_with_input0_qspec
+
+
+def is_input_ok_for_quantization(input_act: Node, gm: GraphModule):
+    """Check if an input can be quantized. The input cannot be quantized if:
+    - the node does not output a float tensor, or
+    - the node outputs a large scalar.
+    """
+    return not (
+        is_input_non_float_tensor(input_act) or is_input_large_scalar(input_act, gm)
+    )
+
+
+def is_input_large_scalar(node: Node, gm: GraphModule):
     """Check if input is a large scalar value, so that we can skip quantization for
     the node, since the histc op (in HistogramObserver) only works for values up to a
     certain upper bound.
     """
@@ -142,6 +189,7 @@ def convert_scalars_to_attrs(model: GraphModule) -> GraphModule:
     """
     targeted_ops = [
         torch.ops.aten.add.Tensor,
+        torch.ops.aten.sub.Tensor,
         torch.ops.aten.mul.Tensor,
     ]
     for n in model.graph.nodes:
diff --git a/backends/arm/quantizer/quantization_annotation/__init__.py b/backends/arm/quantizer/quantization_annotation/__init__.py
index 5e372fe5b7..d162bfd479 100644
--- a/backends/arm/quantizer/quantization_annotation/__init__.py
+++ b/backends/arm/quantizer/quantization_annotation/__init__.py
@@ -54,4 +54,5 @@ def decorator(annotator: AnnotatorType):
     max_pool2d_annotator,
     mul_annotator,
     sigmoid_annotator,
+    sub_annotator,
 )
diff --git a/backends/arm/quantizer/quantization_annotation/add_annotator.py b/backends/arm/quantizer/quantization_annotation/add_annotator.py
index f01301ea39..2926e92f24 100644
--- a/backends/arm/quantizer/quantization_annotation/add_annotator.py
+++ b/backends/arm/quantizer/quantization_annotation/add_annotator.py
@@ -12,10 +12,7 @@
 from executorch.backends.arm.quantizer import arm_quantizer_utils
 from executorch.backends.arm.quantizer.quantization_annotation import register_annotator
 from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
-from torch.ao.quantization.quantizer import (
-    QuantizationAnnotation,
-    SharedQuantizationSpec,
-)
+from torch.ao.quantization.quantizer import QuantizationAnnotation
 from torch.fx import Node
 from torch.fx.passes.utils.source_matcher_utils import get_source_partitions

@@ -37,32 +34,13 @@ def _annotate_add(
         if arm_quantizer_utils.is_annotated(add_node):
             continue

-        input_act0 = add_node.args[0]
-        input_act_qspec = quantization_config.get_input_act_qspec()
-        shared_with_input0_qspec = SharedQuantizationSpec((input_act0, add_node))
-
-        input_qspec_map = {}
-        if isinstance(input_act0, Node):
-            if arm_quantizer_utils.is_input_large_scalar(input_act0, gm):
-                continue
-            if arm_quantizer_utils.is_input_non_float_tensor(input_act0):
-                continue
-            input_qspec_map[input_act0] = input_act_qspec
-
-        input_act1 = add_node.args[1]
-        if isinstance(input_act1, Node):
-            if arm_quantizer_utils.is_input_large_scalar(input_act1, gm):
-                continue
-            if arm_quantizer_utils.is_input_non_float_tensor(input_act1):
-                continue
-            if input_act0 is not input_act1:
-                input_qspec_map[input_act1] = shared_with_input0_qspec
-            else:
-                input_qspec_map[input_act1] = input_act_qspec
-
-        add_node.meta["quantization_annotation"] = QuantizationAnnotation(
-            input_qspec_map=input_qspec_map,
-            output_qspec=shared_with_input0_qspec,
-            _annotated=True,
+        input_qspec_map, output_qspec = arm_quantizer_utils.get_shared_qspec(
+            add_node, gm, quantization_config
         )
+        if input_qspec_map is not None:
+            add_node.meta["quantization_annotation"] = QuantizationAnnotation(
+                input_qspec_map=input_qspec_map,
+                output_qspec=output_qspec,
+                _annotated=True,
+            )
     return annotated_partitions
diff --git a/backends/arm/quantizer/quantization_annotation/sub_annotator.py b/backends/arm/quantizer/quantization_annotation/sub_annotator.py
new file mode 100644
index 0000000000..4686d480ed
--- /dev/null
+++ b/backends/arm/quantizer/quantization_annotation/sub_annotator.py
@@ -0,0 +1,46 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# Copyright 2024 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import itertools
+import operator
+from typing import Callable, List, Optional
+
+import torch
+from executorch.backends.arm.quantizer import arm_quantizer_utils
+from executorch.backends.arm.quantizer.quantization_annotation import register_annotator
+from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
+from torch.ao.quantization.quantizer import QuantizationAnnotation
+from torch.fx import GraphModule, Node
+from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
+
+
+@register_annotator("sub")
+def _annotate_sub(
+    gm: GraphModule,
+    quantization_config: QuantizationConfig,
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[List[List[Node]]]:
+    sub_partitions = get_source_partitions(
+        gm.graph, [operator.sub, torch.sub, operator.isub], filter_fn
+    )
+    sub_partitions = list(itertools.chain.from_iterable(sub_partitions.values()))
+    annotated_partitions = []
+    for sub_partition in sub_partitions:
+        annotated_partitions.append(sub_partition.nodes)
+        sub_node = sub_partition.output_nodes[0]
+        if arm_quantizer_utils.is_annotated(sub_node):
+            continue
+
+        input_qspec_map, output_qspec = arm_quantizer_utils.get_shared_qspec(
+            sub_node, gm, quantization_config
+        )
+        if input_qspec_map is not None:
+            sub_node.meta["quantization_annotation"] = QuantizationAnnotation(
+                input_qspec_map=input_qspec_map,
+                output_qspec=output_qspec,
+                _annotated=True,
+            )
+    return annotated_partitions
diff --git a/backends/arm/test/ops/test_add.py b/backends/arm/test/ops/test_add.py
index 698e96466d..622d811822 100644
--- a/backends/arm/test/ops/test_add.py
+++ b/backends/arm/test/ops/test_add.py
@@ -5,7 +5,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-import logging
 import unittest

 from typing import Tuple
@@ -16,9 +15,6 @@
 from executorch.exir import EdgeCompileConfig
 from parameterized import parameterized

-logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
-

 class TestSimpleAdd(unittest.TestCase):
     class Add(torch.nn.Module):
@@ -30,9 +26,6 @@ class Add(torch.nn.Module):
             (torch.ones(1, 3, 4, 2),),
         ]

-        def __init__(self):
-            super().__init__()
-
         def forward(self, x):
             return x + x
diff --git a/backends/arm/test/ops/test_sub.py b/backends/arm/test/ops/test_sub.py
new file mode 100644
index 0000000000..2ae7c3ab36
--- /dev/null
+++ b/backends/arm/test/ops/test_sub.py
@@ -0,0 +1,117 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# Copyright 2024 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+from typing import Tuple
+
+import torch
+from executorch.backends.arm.test import common
+
+from executorch.backends.arm.test.tester.arm_tester import ArmTester
+from parameterized import parameterized
+
+
+class TestSimpleSub(unittest.TestCase):
+    class Sub(torch.nn.Module):
+        test_parameters = [
+            (torch.ones(5),),
+            (3 * torch.ones(8),),
+            (10 * torch.randn(8),),
+        ]
+
+        def forward(self, x):
+            return x - x
+
+    class Sub2(torch.nn.Module):
+        test_parameters = [
+            (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1)),
+        ]
+
+        def forward(self, x, y):
+            return x - y
+
+    def _test_sub_tosa_MI_pipeline(
+        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
+    ):
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=common.get_tosa_compile_spec(),
+            )
+            .export()
+            .check_count({"torch.ops.aten.sub.Tensor": 1})
+            .check_not(["torch.ops.quantized_decomposed"])
+            .to_edge()
+            .partition()
+            .check_not(["torch.ops.aten.sub.Tensor"])
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+            .run_method_and_compare_outputs(inputs=test_data)
+        )
+
+    def _test_sub_tosa_BI_pipeline(
+        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
+    ):
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=common.get_tosa_compile_spec(),
+            )
+            .quantize()
+            .export()
+            .check_count({"torch.ops.aten.sub.Tensor": 1})
+            .check(["torch.ops.quantized_decomposed"])
+            .to_edge()
+            .partition()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+            .run_method_and_compare_outputs(inputs=test_data, qtol=1)
+        )
+
+    def _test_sub_u55_BI_pipeline(
+        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
+    ):
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=common.get_u55_compile_spec(),
+            )
+            .quantize()
+            .export()
+            .check_count({"torch.ops.aten.sub.Tensor": 1})
+            .check(["torch.ops.quantized_decomposed"])
+            .to_edge()
+            .partition()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+        )
+
+    @parameterized.expand(Sub.test_parameters)
+    def test_sub_tosa_MI(self, test_data: torch.Tensor):
+        test_data = (test_data,)
+        self._test_sub_tosa_MI_pipeline(self.Sub(), test_data)
+
+    @parameterized.expand(Sub.test_parameters)
+    def test_sub_tosa_BI(self, test_data: torch.Tensor):
+        test_data = (test_data,)
+        self._test_sub_tosa_BI_pipeline(self.Sub(), test_data)
+
+    # Expected to fail since RESCALE cannot be fused with SUB in Vela.
+    @parameterized.expand(Sub.test_parameters)
+    @unittest.expectedFailure
+    def test_sub_u55_BI(self, test_data: torch.Tensor):
+        test_data = (test_data,)
+        self._test_sub_u55_BI_pipeline(self.Sub(), test_data)
+
+    @parameterized.expand(Sub2.test_parameters)
+    def test_sub2_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor):
+        test_data = (operand1, operand2)
+        self._test_sub_tosa_MI_pipeline(self.Sub2(), test_data)
diff --git a/backends/arm/tosa_quant_utils.py b/backends/arm/tosa_quant_utils.py
index c29279a661..55649f4bef 100644
--- a/backends/arm/tosa_quant_utils.py
+++ b/backends/arm/tosa_quant_utils.py
@@ -15,6 +15,7 @@
 from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
 from executorch.exir.dialects._ops import ops as exir_ops
 from serializer.tosa_serializer import TosaOp, TosaSerializerTensor
+from torch.fx import Node

 q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
 dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
@@ -252,6 +253,69 @@ def build_rescale_from_int32(
     return


+def rescale_nodes_to_int32(
+    nodes: list[Node], tosa_graph: ts.TosaSerializer
+) -> tuple[list[TosaSerializerTensor], float]:
+    """Rescales all 'nodes' to int32, adding suitable RESCALE ops to 'tosa_graph'.
+    The scales are adjusted using the smallest scale of all 'nodes'.
+
+    Returns a list of the rescaled nodes and the scale factor used,
+    needed by rescale_node_back_to_int8.
+    """
+
+    tensors = [TosaArg(node.args[0]) for node in nodes]
+
+    # Reshape tensors according to the TOSA dim order
+    for tensor in tensors:
+        dim_order = tensor.dim_order
+        tensor.shape = [tensor.shape[i] for i in dim_order]
+
+    qargs = [get_quant_node_args(node) for node in nodes]
+
+    # Scale the int8 quantized inputs to a common scale in the integer
+    # domain
+    min_scale = min([qarg.scale for qarg in qargs])
+    scales = [qarg.scale / min_scale for qarg in qargs]
+
+    rescaled_nodes: list[TosaSerializerTensor] = []
+    for tensor, qarg, scale in zip(tensors, qargs, scales):
+        rescaled_nodes.append(
+            build_rescale_to_int32(
+                tosa_graph,
+                tensor,
+                qarg.zp,
+                scale,
+            )
+        )
+    return rescaled_nodes, min_scale
+
+
+def rescale_node_back_to_int8(
+    node: Node,
+    last_tensor: TosaSerializerTensor,
+    scale: float,
+    tosa_graph: ts.TosaSerializer,
+):
+    """Rescales the node back to int8, adding a suitable RESCALE op to 'tosa_graph'.
+
+    Parameters:
+        node: the original node that is being handled by the rescales.
+        last_tensor: the TOSA tensor to rescale back.
+        scale: the scaling factor used to rescale to int32, from 'rescale_nodes_to_int32'.
+        tosa_graph: the TOSA graph to manipulate.
+    """
+    qargs_out = get_quant_node_args(list(node.users)[0])
+    output_rescale_scale = scale / qargs_out.scale
+
+    # Rescale back to INT8
+    build_rescale_from_int32(
+        tosa_graph,
+        last_tensor.name,
+        node.name,
+        qargs_out.zp,
+        output_rescale_scale,
+    )
+
+
 """ Creates a TOSA rescale op based on conv2d parameters. """
""" diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py index f31819f3d3..4dc0204516 100644 --- a/backends/arm/tosa_utils.py +++ b/backends/arm/tosa_utils.py @@ -21,6 +21,7 @@ ) from executorch.exir.dialects._ops import ops as exir_ops from serializer.tosa_serializer import TosaOp +from torch.fx import Node logger = logging.getLogger(__name__) logger.setLevel(logging.WARNING) @@ -88,10 +89,45 @@ def promote_shape(tosa_fb, arg, promoted_shape, out_dtype): return reshape_res -def getNodeArgs(node): +# Helper transpose function to match TOSA's shape requirements +# E.g., TOSA 0.80.0 specification - 2.3.3 CONV2D shapes: +# https://www.mlplatform.org/tosa/tosa_spec.html#_conv2d +def transpose_helper(tosa_fb, input, new_order, out_dtype): + # Check new_order's length is equal to input rank + assert len(input.shape) == len(new_order), "Wrong shape order length" + + # Check no duplications + assert len(set(new_order)) == len(new_order), "Contain duplicated dim numbers" + + # Check all dims are valid + for idx in new_order: + if idx < 0: + assert True, "Negative dim number" + elif idx >= len(input.shape): + assert True, "Dim is greater than input rank" + + input_shape_transpoed = [input.shape[i] for i in new_order] + attr = ts.TosaSerializerAttribute() + attr.TransposeAttribute(new_order) + input_transposed = tosa_fb.addIntermediate(input_shape_transpoed, out_dtype) + tosa_fb.addOperator( + TosaOp.Op().TRANSPOSE, [input.name], [input_transposed.name], attr + ) + return input_transposed + + +def getNodeArgs(node: Node) -> list[TosaArg]: return [TosaArg(arg) for arg in node.args] +def get_input_tensor(node: Node) -> TosaArg: + return TosaArg(node.args[0]) + + +def get_output_node(node: Node) -> Node: + return list(node.users)[0] + + # Helper function to do broadcasting # Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_broadcasting def broadcast_shapes(shape1, shape2): @@ -220,6 +256,30 @@ def build_avg_pool_2d_common( ) +def get_two_inputs(node: Node, check: bool = False) -> tuple[Node, Node]: + """Returns two input nodes to 'node' in order. If 'node' only has one input, + it is returned twice. + + Fails if there are no input nodes. + Fails if there are >2 input nodes and 'check' is True, + """ + + num_inputs = len(node.all_input_nodes) + assert num_inputs > 0, f"Node '{node.name}' requires >0 input, got {num_inputs}." + + input1 = node.all_input_nodes[0] + if num_inputs == 1: + input2 = node.all_input_nodes[0] + else: + input2 = node.all_input_nodes[1] + if check: + assert ( + num_inputs <= 2 + ), f"Node '{node.name}' requires <=2 inputs, got {num_inputs}." + + return input1, input2 + + def tosa_shape(shape, dim_order): return tuple([shape[dim] for dim in dim_order])