diff --git a/backends/arm/arm_partitioner.py b/backends/arm/arm_partitioner.py index d03e4a1385..7309287998 100644 --- a/backends/arm/arm_partitioner.py +++ b/backends/arm/arm_partitioner.py @@ -62,6 +62,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: exir_ops.edge.aten.relu.default, exir_ops.edge.aten.rsqrt.default, exir_ops.edge.aten._softmax.default, + exir_ops.edge.aten.select_copy.int, exir_ops.edge.aten._log_softmax.default, exir_ops.edge.aten.slice_copy.Tensor, exir_ops.edge.aten.sub.Tensor, diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index babfbcfea0..a8ddf1c8f0 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -28,6 +28,7 @@ op_relu, op_repeat, op_rsqrt, + op_select, op_sigmoid, op_slice, op_squeeze, diff --git a/backends/arm/operators/op_select.py b/backends/arm/operators/op_select.py new file mode 100644 index 0000000000..6037ed000c --- /dev/null +++ b/backends/arm/operators/op_select.py @@ -0,0 +1,69 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+
+from typing import List
+
+import serializer.tosa_serializer as ts
+from executorch.backends.arm.operators.node_visitor import (
+    NodeVisitor,
+    register_node_visitor,
+)
+
+from executorch.backends.arm.tosa_mapping import TosaArg
+
+from executorch.backends.arm.tosa_utils import build_reshape, tosa_shape
+from serializer.tosa_serializer import TosaOp
+from torch.fx import Node
+
+
+@register_node_visitor
+class SelectVisitor(NodeVisitor):
+    target = "aten.select_copy.int"
+
+    def __init__(self, *args):
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: Node,
+        tosa_graph: ts.TosaSerializer,
+        inputs: List[TosaArg],
+        output: TosaArg,
+        is_quant_node: bool,
+    ) -> None:
+        """Lower aten.select_copy.int to a TOSA SLICE followed by a reshape.
+
+        inputs holds (input tensor, dim, index); negative dim and index
+        values are wrapped before use. The slice is reshaped to output.shape.
+        """
+        assert len(inputs) == 3
+        input_node, dim, index = inputs
+        shape = input_node.shape
+        rank = len(shape)
+
+        # dim wraps around the rank; index wraps around the size of that axis.
+        dim = dim.number % rank if dim.number < 0 else dim.number
+        index = index.number % shape[dim] if index.number < 0 else index.number
+
+        # aten.select_copy drops the selected axis, but TOSA SLICE keeps
+        # rank(in) == rank(out), so slice into a same-rank intermediate first.
+        expanded_shape = tuple(1 if i == dim else shape[i] for i in range(rank))
+        expanded_shape = tosa_shape(expanded_shape, input_node.dim_order)
+        output_reshaped = tosa_graph.addIntermediate(
+            expanded_shape, ts.DType.INT8 if is_quant_node else output.dtype
+        )
+
+        attr_slice = ts.TosaSerializerAttribute()
+        start_attr = [index if i == dim else 0 for i in input_node.dim_order]
+        size_attr = [
+            1 if i == dim else input_node.shape[i] for i in input_node.dim_order
+        ]
+        attr_slice.SliceAttribute(start_attr, size_attr)
+        tosa_graph.addOperator(
+            TosaOp.Op().SLICE, [input_node.name], [output_reshaped.name], attr_slice
+        )
+
+        # Reshape back to original rank of output.
+        build_reshape(tosa_graph, output_reshaped.name, output.shape, output.name)
diff --git a/backends/arm/quantizer/quantization_annotation/generic_annotator.py b/backends/arm/quantizer/quantization_annotation/generic_annotator.py
index a490991693..f91df1398e 100644
--- a/backends/arm/quantizer/quantization_annotation/generic_annotator.py
+++ b/backends/arm/quantizer/quantization_annotation/generic_annotator.py
@@ -34,6 +34,8 @@
     # torch.ops.aten.view_as_real.default,
     # torch.ops.aten.view_as_real_copy.default,
     torch.ops.aten.view_copy.default,
+    torch.ops.aten.select.int,
+    torch.ops.aten.select_copy.int,
     torch.ops.aten.slice.Tensor,
     torch.ops.aten.slice_copy.Tensor,
     # 'concat' should be handled separately as it has a sequence of inputs and
diff --git a/backends/arm/test/ops/test_select.py b/backends/arm/test/ops/test_select.py
new file mode 100644
index 0000000000..fdb2fa1463
--- /dev/null
+++ b/backends/arm/test/ops/test_select.py
@@ -0,0 +1,198 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# Copyright 2024 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.arm_tester import ArmTester
+from executorch.exir.backend.compile_spec_schema import CompileSpec
+from parameterized import parameterized
+
+test_data_t = tuple[torch.Tensor, int, int]
+
+test_data_suite: list[tuple[test_data_t]] = [
+    # (test_data, dim, index)
+    ((torch.zeros(5, 3, 20), -1, 0),),
+    ((torch.zeros(5, 3, 20), 0, -1),),
+    ((torch.zeros(5, 3, 20), 0, 4),),
+    ((torch.ones(10, 10, 10), 0, 2),),
+    ((torch.rand(5, 3, 20, 2), 0, 2),),
+    ((torch.rand(10, 10) - 0.5, 0, 0),),
+    ((torch.randn(10) + 10, 0, 1),),
+    ((torch.randn(10) - 10, 0, 2),),
+    ((torch.arange(-16, 16, 0.2), 0, 1),),
+]
+
+
+class TestSelect(unittest.TestCase):
+    """Lowering tests for torch.select / torch.select_copy on the Arm backend."""
+
+    class SelectCopy(torch.nn.Module):
+        """Module whose forward calls torch.select_copy."""
+
+        def forward(self, x, dim: int, index: int):
+            return torch.select_copy(x, dim=dim, index=index)
+
+    class SelectInt(torch.nn.Module):
+        """Module whose forward calls torch.select."""
+
+        def forward(self, x, dim: int, index: int):
+            return torch.select(x, dim=dim, index=index)
+
+    @staticmethod
+    def _permute_to_nhwc(test_data: test_data_t) -> bool:
+        """Return False for 4D inputs (no NHWC permute), True otherwise."""
+        return len(test_data[0].shape) != 4
+
+    def _test_select_tosa_MI_pipeline(
+        self,
+        module: torch.nn.Module,
+        test_data: test_data_t,
+        export_target: str,
+    ):
+        """Lower to TOSA without quantization and compare outputs."""
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=common.get_tosa_compile_spec(
+                    permute_memory_to_nhwc=self._permute_to_nhwc(test_data)
+                ),
+            )
+            .export()
+            .check([export_target])
+            .check_not(["torch.ops.quantized_decomposed"])
+            .to_edge()
+            .partition()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+            .run_method_and_compare_outputs(inputs=test_data)
+        )
+
+    def _test_select_tosa_BI_pipeline(
+        self,
+        module: torch.nn.Module,
+        test_data: test_data_t,
+        export_target: str,
+    ):
+        """Quantize, lower to TOSA and compare outputs."""
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=common.get_tosa_compile_spec(
+                    permute_memory_to_nhwc=self._permute_to_nhwc(test_data)
+                ),
+            )
+            .quantize()
+            .export()
+            .check([export_target])
+            .check(["torch.ops.quantized_decomposed"])
+            .to_edge()
+            .partition()
+            .dump_artifact()
+            .dump_operator_distribution()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+            .run_method_and_compare_outputs(inputs=test_data)
+        )
+
+    def _test_select_ethos_BI_pipeline(
+        self,
+        compile_spec: list[CompileSpec],
+        module: torch.nn.Module,
+        test_data: test_data_t,
+        export_target: str,
+    ):
+        """Quantize and lower with compile_spec; outputs are not compared."""
+        (
+            ArmTester(module, example_inputs=test_data, compile_spec=compile_spec)
+            .quantize()
+            .export()
+            .check([export_target])
+            .check(["torch.ops.quantized_decomposed"])
+            .to_edge()
+            .partition()
+            .dump_artifact()
+            .dump_operator_distribution()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+        )
+
+    def _test_select_tosa_u55_BI_pipeline(
+        self, module: torch.nn.Module, test_data: test_data_t, export_target: str
+    ):
+        """Run the Ethos BI pipeline with the U55 compile spec."""
+        permute = self._permute_to_nhwc(test_data)
+        self._test_select_ethos_BI_pipeline(
+            common.get_u55_compile_spec(permute_memory_to_nhwc=permute),
+            module,
+            test_data,
+            export_target,
+        )
+
+    def _test_select_tosa_u85_BI_pipeline(
+        self, module: torch.nn.Module, test_data: test_data_t, export_target: str
+    ):
+        """Run the Ethos BI pipeline with the U85 compile spec."""
+        permute = self._permute_to_nhwc(test_data)
+        self._test_select_ethos_BI_pipeline(
+            common.get_u85_compile_spec(permute_memory_to_nhwc=permute),
+            module,
+            test_data,
+            export_target,
+        )
+
+    @parameterized.expand(test_data_suite)
+    def test_select_copy_tosa_MI(self, test_data: test_data_t):
+        self._test_select_tosa_MI_pipeline(
+            self.SelectCopy(), test_data, export_target="torch.ops.aten.select_copy.int"
+        )
+
+    @parameterized.expand(test_data_suite)
+    def test_select_int_tosa_MI(self, test_data: test_data_t):
+        self._test_select_tosa_MI_pipeline(
+            self.SelectInt(), test_data, export_target="torch.ops.aten.select.int"
+        )
+
+    @parameterized.expand(test_data_suite)
+    def test_select_copy_tosa_BI(self, test_data: test_data_t):
+        self._test_select_tosa_BI_pipeline(
+            self.SelectCopy(), test_data, export_target="torch.ops.aten.select_copy.int"
+        )
+
+    @parameterized.expand(test_data_suite)
+    def test_select_int_tosa_BI(self, test_data: test_data_t):
+        self._test_select_tosa_BI_pipeline(
+            self.SelectInt(), test_data, export_target="torch.ops.aten.select.int"
+        )
+
+    @parameterized.expand(test_data_suite)
+    def test_select_copy_tosa_u55_BI(self, test_data: test_data_t):
+        self._test_select_tosa_u55_BI_pipeline(
+            self.SelectCopy(), test_data, export_target="torch.ops.aten.select_copy.int"
+        )
+
+    @parameterized.expand(test_data_suite)
+    def test_select_int_tosa_u55_BI(self, test_data: test_data_t):
+        self._test_select_tosa_u55_BI_pipeline(
+            self.SelectInt(), test_data, export_target="torch.ops.aten.select.int"
+        )
+
+    @parameterized.expand(test_data_suite)
+    def test_select_copy_tosa_u85_BI(self, test_data: test_data_t):
+        self._test_select_tosa_u85_BI_pipeline(
+            self.SelectCopy(), test_data, export_target="torch.ops.aten.select_copy.int"
+        )
+
+    @parameterized.expand(test_data_suite)
+    def test_select_int_tosa_u85_BI(self, test_data: test_data_t):
+        self._test_select_tosa_u85_BI_pipeline(
+            self.SelectInt(), test_data, export_target="torch.ops.aten.select.int"
+        )