From e24d503dc9261374a490d50d1effd2f13470f69f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Per=20=C3=85strand?= Date: Mon, 18 Nov 2024 14:20:35 +0100 Subject: [PATCH] Convert more NodeVisitors to folding DQ/Q pass usage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Per Åstrand Change-Id: I9201d8bafd543204b697c7276d6929ad3aa09f25 --- backends/arm/_passes/arm_pass_manager.py | 2 + backends/arm/operators/op_avg_pool2d.py | 100 ++++++++++++++++++++--- backends/arm/operators/op_batch_norm.py | 5 ++ backends/arm/operators/op_conv2d.py | 48 ++++++----- backends/arm/operators/op_div.py | 6 ++ backends/arm/operators/op_max_pool2d.py | 2 +- backends/arm/process_node.py | 10 ++- backends/arm/tosa_utils.py | 60 ++------------ 8 files changed, 143 insertions(+), 90 deletions(-) diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index f16a34a211..e1c903302c 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -90,6 +90,8 @@ def transform_to_backend_pipeline( exir_ops.edge.aten.minimum.default, exir_ops.edge.aten.maximum.default, exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.avg_pool2d.default, + exir_ops.edge.aten.convolution.default, ] ) ) diff --git a/backends/arm/operators/op_avg_pool2d.py b/backends/arm/operators/op_avg_pool2d.py index 4caaad9202..6665a99a7b 100644 --- a/backends/arm/operators/op_avg_pool2d.py +++ b/backends/arm/operators/op_avg_pool2d.py @@ -8,30 +8,41 @@ import serializer.tosa_serializer as ts import torch +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_input_qparams, + get_output_qparams, +) from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_utils import build_avg_pool_2d_common +from 
executorch.backends.arm.tosa_specification import TosaSpecification @register_node_visitor -class AvgPool2dVisitor(NodeVisitor): +class AvgPool2dVisitor_0_80_BI(NodeVisitor): target = "aten.avg_pool2d.default" + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80.0+BI"), + ] + def __init__(self, *args): super().__init__(*args) - def define_node( + def _build_generic_avgpool2d( self, node: torch.fx.Node, tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, + input_zp: int, + output_zp: int, + accumulator_type, ) -> None: input_tensor = inputs[0] + kernel_size_list = inputs[1].special stride_size_list = inputs[2].special try: @@ -39,13 +50,76 @@ def define_node( except IndexError: pad_size_list = [0, 0, 0, 0] - build_avg_pool_2d_common( - node, - tosa_graph, - input_tensor, - kernel_size_list, - stride_size_list, - pad_size_list, - is_quant_node, - output, + attr = ts.TosaSerializerAttribute() + attr.PoolAttribute( + kernel=kernel_size_list, + stride=stride_size_list, + pad=pad_size_list, + input_zp=input_zp, + output_zp=output_zp, + accum_dtype=accumulator_type, + ) + + tosa_graph.addOperator( + ts.TosaOp.Op().AVG_POOL2D, + [input_tensor.name], + [output.name], + attr, + ) + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + is_quant_node: bool, + ) -> None: + input_tensor = inputs[0] + assert input_tensor.dtype == ts.DType.INT8 + + accumulator_type = ts.DType.INT32 + + input_qargs = get_input_qparams(node) + input_zp = input_qargs[0].zp + + output_qargs = get_output_qparams(node) + output_zp = output_qargs[0].zp + + self._build_generic_avgpool2d( + node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type ) + + +@register_node_visitor +class AvgPool2dVisitor_0_80_MI(AvgPool2dVisitor_0_80_BI): + # inheriting 'target' from BI class + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80.0+MI"), + ] + + def 
define_node( + self, + node: torch.fx.Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + is_quant_node: bool, + ) -> None: + assert ( + inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.FP32 + ), "Only FP32 and INT8 supported" + + if inputs[0].dtype == ts.DType.INT8: + super().define_node(node, tosa_graph, inputs, output, is_quant_node) + + if inputs[0].dtype == ts.DType.FP32: + accumulator_type = ts.DType.FP32 + # Initialize zero point to zero. + input_zp = 0 + output_zp = 0 + + self._build_generic_avgpool2d( + node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type + ) diff --git a/backends/arm/operators/op_batch_norm.py b/backends/arm/operators/op_batch_norm.py index d17c3a1b81..ee773949d1 100644 --- a/backends/arm/operators/op_batch_norm.py +++ b/backends/arm/operators/op_batch_norm.py @@ -13,6 +13,7 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.backends.arm.tosa_utils import promote_shape, tosa_shape from serializer.tosa_serializer import TosaOp @@ -21,6 +22,10 @@ class BatchNormVisitor(NodeVisitor): target = "aten._native_batch_norm_legit_no_training.default" + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80.0+MI"), + ] + def __init__(self, *args): super().__init__(*args) diff --git a/backends/arm/operators/op_conv2d.py b/backends/arm/operators/op_conv2d.py index ffbeee7306..dc64e16936 100644 --- a/backends/arm/operators/op_conv2d.py +++ b/backends/arm/operators/op_conv2d.py @@ -8,16 +8,16 @@ import serializer.tosa_serializer as ts import torch +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_input_qparams, + get_output_qparams, +) from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from 
executorch.backends.arm.tosa_quant_utils import ( - build_rescale_conv_output, - get_quant_arg_downstream, - get_quant_arg_upstream, -) +from executorch.backends.arm.tosa_quant_utils import build_rescale_conv_output from executorch.backends.arm.tosa_utils import build_reshape, tosa_shape from serializer.tosa_serializer import TosaOp @@ -57,9 +57,6 @@ def define_node( ) -> None: input, weight, bias, stride, pad, dilation, _, _, group = inputs - # Currently only int8 is supported in quantized types. - actual_out_type = ts.DType.INT8 if is_quant_node else output.dtype - # Get the attributes of convolution. attr = ts.TosaSerializerAttribute() pad_attr = [val for val in pad.special for _ in (0, 1)] @@ -82,9 +79,11 @@ def define_node( dilation_attr[1], ) - input_zp = ( - get_quant_arg_upstream(node.all_input_nodes[0]).zp if is_quant_node else 0 - ) + input_zp = 0 + if inputs[0].dtype == ts.DType.INT8: + # int8 input requires quantization information + input_qparams = get_input_qparams(node) + input_zp = input_qparams[0].zp attr.ConvAttribute( pad=pad_attr, @@ -100,16 +99,22 @@ def define_node( # Create a zero bias tensor if not presented out_channels = weight.shape[0] bias_name = "bias" + node.name.split("default", 1)[1] + bias_type = output.dtype + if output.dtype == ts.DType.INT8: + # Conv is quantized to int8, but the TOSA operator has + # output type int32, and the bias must be the same type + # as the TOSA output type + bias_type = ts.DType.INT32 bias = tosa_graph.addConst( [out_channels], - ts.DType.INT32 if is_quant_node else output.dtype, + bias_type, [0] * out_channels, name=bias_name, ) # The output type is int32 when input type is int8. 
conv2d_output_name = output.name - if is_quant_node: + if output.dtype == ts.DType.INT8: conv2d_res = tosa_graph.addIntermediate( tosa_shape(output.shape, output.dim_order), ts.DType.INT32 ) @@ -132,7 +137,7 @@ def define_node( weight_reshaped = tosa_graph.addIntermediate( weight_post_shape, - ts.DType.INT8 if is_quant_node else weight.dtype, + weight.dtype, ) build_reshape( tosa_graph, weight.name, weight_post_shape, weight_reshaped.name @@ -157,20 +162,19 @@ def define_node( # For quantized convolution, rescale the output value back to the same # integer value domain of the next op. Otherwise return float32 output. - if is_quant_node: + if inputs[0].dtype == ts.DType.INT8: # Get scale_factor from input, weight, and output. - input_scale = get_quant_arg_upstream(node.all_input_nodes[0]).scale - weight_scale = get_quant_arg_upstream(node.all_input_nodes[1]).scale - output_qargs = get_quant_arg_downstream(list(node.users)[0]) - + input_scale = input_qparams[0].scale + weight_scale = input_qparams[1].scale + output_qargs = get_output_qparams(node) build_rescale_conv_output( tosa_graph, # pyre-fixme[61]: Uninitialized local [61]: Local variable `conv2d_res` is undefined, or not always defined. 
conv2d_res, output.name, - actual_out_type, + output.dtype, input_scale, weight_scale, - output_qargs.scale, - output_qargs.zp, + output_qargs[0].scale, + output_qargs[0].zp, ) diff --git a/backends/arm/operators/op_div.py b/backends/arm/operators/op_div.py index 0857e0ed32..339833c329 100644 --- a/backends/arm/operators/op_div.py +++ b/backends/arm/operators/op_div.py @@ -13,6 +13,7 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.backends.arm.tosa_utils import tosa_shape from serializer.tosa_serializer import TosaOp @@ -21,6 +22,11 @@ class DivVisitor(NodeVisitor): target = "aten.div.Tensor" + # Only supported for MI + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80.0+MI"), + ] + def __init__(self, *args): super().__init__(*args) diff --git a/backends/arm/operators/op_max_pool2d.py b/backends/arm/operators/op_max_pool2d.py index 74e33ddb02..0a4092e3a9 100644 --- a/backends/arm/operators/op_max_pool2d.py +++ b/backends/arm/operators/op_max_pool2d.py @@ -13,7 +13,7 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_utils import ( +from executorch.backends.arm.tosa_quant_utils import ( get_quant_arg_downstream, get_quant_arg_upstream, ) diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py index 2d3a0c2786..3b1ea9d70f 100644 --- a/backends/arm/process_node.py +++ b/backends/arm/process_node.py @@ -11,10 +11,12 @@ import serializer.tosa_serializer as ts import torch import torch.fx +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_input_qparams, +) from executorch.backends.arm.operators.node_visitor import NodeVisitor from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg from executorch.backends.arm.tosa_quant_utils import ( - get_quant_arg_upstream, get_quantized_node_output_dtype, 
is_node_quantized, ) @@ -110,8 +112,10 @@ def process_quantized_bias( _, ) = consumer_node.all_input_nodes - input_node_scale = get_quant_arg_upstream(input_node).scale - weight_node_scale = get_quant_arg_upstream(weight_node).scale + input_qargs = get_input_qparams(consumer_node) + + input_node_scale = input_qargs[0].scale + weight_node_scale = input_qargs[1].scale bias_values_quantized = ( (parameter_values / (input_node_scale * weight_node_scale)) .round() diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py index 1ae319e0cd..dd28105a63 100644 --- a/backends/arm/tosa_utils.py +++ b/backends/arm/tosa_utils.py @@ -7,18 +7,13 @@ import logging import os -from typing import Any, cast +from typing import Any import numpy as np import serializer.tosa_serializer as ts import torch from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import ( - get_quant_arg_downstream, - get_quant_arg_upstream, - q_op, -) from executorch.exir.dialects._ops import ops as exir_ops from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -140,10 +135,15 @@ def build_reshape(tosa_fb, input_name, new_shape, output_name): def is_bias_node_for_quantized_conv(node): consumer_node = list(node.users)[0] - return ( + + if ( consumer_node.target == exir_ops.edge.aten.convolution.default - and list(consumer_node.users)[0].target == q_op - ) + and consumer_node.args[2] == node + and consumer_node.meta["val"].dtype == torch.int8 + ): + return True + + return False def is_consumer_node_depthwise_conv2d(node): @@ -159,48 +159,6 @@ def is_consumer_node_depthwise_conv2d(node): return False -def build_avg_pool_2d_common( - node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, - input_tensor: TosaArg, - kernel_size: list, - stride: list, - padding: list, - is_quant_node: bool, - output: TosaArg, -): - accumulator_type = input_tensor.dtype - - if is_quant_node: - # Accumulator type always is int32 when input tensor is an 
integer type. - accumulator_type = ts.DType.INT32 - - # Initilize zero point to zero. - input_zp = 0 - output_zp = 0 - - if is_quant_node: - input_zp = get_quant_arg_upstream(cast(torch.fx.Node, node.args[0])).zp - output_zp = get_quant_arg_downstream(list(node.users)[0]).zp - - attr = ts.TosaSerializerAttribute() - attr.PoolAttribute( - kernel=kernel_size, - stride=stride, - pad=padding, - input_zp=input_zp, - output_zp=output_zp, - accum_dtype=accumulator_type, - ) - - tosa_graph.addOperator( - TosaOp.Op().AVG_POOL2D, - [input_tensor.name], - [output.name], - attr, - ) - - def get_two_inputs(node: Node, check: bool = False) -> tuple[Node, Node]: """Returns two input nodes to 'node' in order. If 'node' only has one input, it is returned twice.