diff --git a/backends/arm/operators/op_avg_pool2d.py b/backends/arm/operators/op_avg_pool2d.py index d84fe40d99..e6d07610c8 100644 --- a/backends/arm/operators/op_avg_pool2d.py +++ b/backends/arm/operators/op_avg_pool2d.py @@ -10,8 +10,8 @@ NodeVisitor, register_node_visitor, ) -from executorch.backends.arm.operators.op_common import build_avg_pool_2d_common from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_utils import build_avg_pool_2d_common @register_node_visitor diff --git a/backends/arm/operators/op_common.py b/backends/arm/operators/op_common.py deleted file mode 100644 index eadf00c294..0000000000 --- a/backends/arm/operators/op_common.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2024 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import serializer.tosa_serializer as ts -import torch -from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import get_quant_node_args -from serializer.tosa_serializer import TosaOp - - -def build_avg_pool_2d_common( - node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, - input_tensor: TosaArg, - kernel_size: list, - stride: list, - padding: list, - is_quant_node: bool, - output: TosaArg, -): - accumulator_type = input_tensor.dtype - - if is_quant_node: - # Accumulator type always is int32 when input tensor is an integer type. - accumulator_type = ts.DType.INT32 - - # Initilize zero point to zero. - input_zp = 0 - output_zp = 0 - - if is_quant_node: - input_zp = get_quant_node_args(node.args[0]).zp - output_zp = get_quant_node_args(list(node.users)[0]).zp - - attr = ts.TosaSerializerAttribute() - attr.PoolAttribute( - kernel=kernel_size, - stride=stride, - pad=padding, - input_zp=input_zp, - output_zp=output_zp, - accum_dtype=accumulator_type, - ) - - tosa_graph.addOperator( - TosaOp.Op().AVG_POOL2D, - [input_tensor.name], - [output.name], - attr, - ) diff --git a/backends/arm/operators/op_mean_dim.py b/backends/arm/operators/op_mean_dim.py index 5e8e3d74c0..20e1b2b8d7 100644 --- a/backends/arm/operators/op_mean_dim.py +++ b/backends/arm/operators/op_mean_dim.py @@ -10,8 +10,8 @@ NodeVisitor, register_node_visitor, ) -from executorch.backends.arm.operators.op_common import build_avg_pool_2d_common from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_utils import build_avg_pool_2d_common @register_node_visitor diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py index 68d090653a..a692b3a270 100644 --- a/backends/arm/tosa_utils.py +++ b/backends/arm/tosa_utils.py @@ -6,11 +6,12 @@ import logging import os -import executorch.backends.arm.tosa_quant_utils as tosa_quant_utils - import numpy as np import serializer.tosa_serializer as ts +import torch from executorch.backends.arm.tosa_mapping import TosaArg + +from executorch.backends.arm.tosa_quant_utils import get_quant_node_args, q_op from executorch.exir.dialects._ops import ops as exir_ops from serializer.tosa_serializer import TosaOp @@ -158,7 +159,7 @@ def is_bias_node_for_addmm(node): # consumer node is addmm is_rank2_linear_bias = ( consumer_node.target == exir_ops.edge.aten.addmm.default - and list(consumer_node.users)[0].target == tosa_quant_utils.q_op + and list(consumer_node.users)[0].target == q_op ) # rank>2 linear layers @@ -170,7 +171,7 @@ def is_bias_node_for_addmm(node): ): consumer_consumer_node = list(consumer_node.users)[0] is_rank_greater_than_2_linear_bias = ( - list(consumer_consumer_node.users)[0].target == tosa_quant_utils.q_op + list(consumer_consumer_node.users)[0].target == q_op ) return is_rank2_linear_bias or is_rank_greater_than_2_linear_bias @@ -189,3 +190,45 @@ def is_consumer_node_depthwise_conv2d(node): return True return False + + +def build_avg_pool_2d_common( + node: torch.fx.Node, + tosa_graph: ts.TosaSerializer, + input_tensor: TosaArg, + kernel_size: list, + stride: list, + padding: list, + is_quant_node: bool, + output: TosaArg, +): + accumulator_type = input_tensor.dtype + + if is_quant_node: + # Accumulator type always is int32 when input tensor is an integer type. + accumulator_type = ts.DType.INT32 + + # Initilize zero point to zero. + input_zp = 0 + output_zp = 0 + + if is_quant_node: + input_zp = get_quant_node_args(node.args[0]).zp + output_zp = get_quant_node_args(list(node.users)[0]).zp + + attr = ts.TosaSerializerAttribute() + attr.PoolAttribute( + kernel=kernel_size, + stride=stride, + pad=padding, + input_zp=input_zp, + output_zp=output_zp, + accum_dtype=accumulator_type, + ) + + tosa_graph.addOperator( + TosaOp.Op().AVG_POOL2D, + [input_tensor.name], + [output.name], + attr, + )