diff --git a/backends/arm/arm_partitioner.py b/backends/arm/arm_partitioner.py
index 5261f71e60..6098332ea4 100644
--- a/backends/arm/arm_partitioner.py
+++ b/backends/arm/arm_partitioner.py
@@ -65,6 +65,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             exir_ops.edge.aten.slice_copy.Tensor,
             exir_ops.edge.aten.sub.Tensor,
             exir_ops.edge.aten.sum.dim_IntList,
+            exir_ops.edge.aten.tanh.default,
             exir_ops.edge.aten.view_copy.default,
             exir_ops.edge.aten.clone.default,
             exir_ops.edge.aten.mean.dim,
diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py
index 8fd4cbc6e8..e754eba351 100644
--- a/backends/arm/operators/__init__.py
+++ b/backends/arm/operators/__init__.py
@@ -34,6 +34,7 @@
     op_squeeze,
     op_sub,
     op_sum,
+    op_tanh,
     op_transpose,
     op_unsqueeze,
     op_view,
diff --git a/backends/arm/operators/op_tanh.py b/backends/arm/operators/op_tanh.py
new file mode 100644
index 0000000000..20f343a7f1
--- /dev/null
+++ b/backends/arm/operators/op_tanh.py
@@ -0,0 +1,86 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+from typing import List
+
+import numpy as np
+
+import serializer.tosa_serializer as ts
+from executorch.backends.arm.operators.node_visitor import (
+    NodeVisitor,
+    register_node_visitor,
+)
+from executorch.backends.arm.tosa_mapping import TosaArg
+
+from executorch.backends.arm.tosa_quant_utils import (
+    dequantize_value,
+    get_quant_node_args,
+    QuantArgs,
+    quantize_value,
+)
+from serializer.tosa_serializer import TosaOp
+from torch.fx import Node
+
+
+@register_node_visitor
+class TanhVisitor(NodeVisitor):
+    target = "aten.tanh.default"
+
+    def __init__(self, *args):
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: Node,
+        tosa_graph: ts.TosaSerializer,
+        inputs: List[TosaArg],
+        output: TosaArg,
+        is_quant_node: bool,
+    ) -> None:
+
+        assert len(node.all_input_nodes) == 1
+
+        if is_quant_node:
+            # Assume quantized input is 8 bit.
+            assert len(node.users) == 1
+
+            # Create attribute for 8 bit table lookup.
+            input_node = node.all_input_nodes[0]
+            in_quantargs = get_quant_node_args(input_node)
+            output_node = list(node.users)[0]
+            out_quantargs = get_quant_node_args(output_node)
+
+            table = tanh_table_8bit(in_quantargs, out_quantargs)
+            table_attr = ts.TosaSerializerAttribute()
+            table_attr.TableAttribute(table)
+
+            tosa_graph.addOperator(
+                TosaOp.Op().TABLE, [inputs[0].name], [output.name], table_attr
+            )
+        else:
+            tosa_graph.addOperator(TosaOp.Op().TANH, [inputs[0].name], [output.name])
+
+
+def tanh_table_8bit(in_quantargs: QuantArgs, out_quantargs: QuantArgs):
+    """
+    Returns a table mapping 256 entries to tanh([qmin,qmax])
+    Reference: https://www.mlplatform.org/tosa/tosa_spec.html#_tanh
+    """
+
+    def tanh(x):
+        # Convert quantized input to floating point tanh input space.
+        v = dequantize_value(x, in_quantargs)
+        # Compute tanh.
+        v = np.exp(-2.0 * v)
+        v = (1.0 - v) / (1.0 + v)
+
+        # Convert tanh output back to quantized space.
+        return quantize_value(v, out_quantargs)
+
+    return [
+        tanh(x)
+        for x in np.linspace(in_quantargs.qmin, in_quantargs.qmax, 256, dtype=np.int8)
+    ]
diff --git a/backends/arm/quantizer/quantization_annotation/one_to_one_annotator.py b/backends/arm/quantizer/quantization_annotation/one_to_one_annotator.py
index e2a1398018..544d4af9f4 100644
--- a/backends/arm/quantizer/quantization_annotation/one_to_one_annotator.py
+++ b/backends/arm/quantizer/quantization_annotation/one_to_one_annotator.py
@@ -41,6 +41,7 @@ def _annotate_one_to_one(
         torch.ops.aten.reciprocal.default,
         torch.ops.aten.rsqrt.default,
         torch.ops.aten.sigmoid.default,
+        torch.ops.aten.tanh.default,
     )
     for node in gm.graph.nodes:
         if node.op != "call_function" or node.target not in one_to_one_ops:
diff --git a/backends/arm/test/ops/test_tanh.py b/backends/arm/test/ops/test_tanh.py
new file mode 100644
index 0000000000..6f5cf17cf3
--- /dev/null
+++ b/backends/arm/test/ops/test_tanh.py
@@ -0,0 +1,134 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# Copyright 2024 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+from typing import Tuple
+
+import torch
+
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.arm_tester import ArmTester
+from executorch.exir.backend.compile_spec_schema import CompileSpec
+from parameterized import parameterized
+
+
+test_data_suite = [
+    # (test_name, test_data)
+    ("zeros", torch.zeros(10, 10, 10, 10)),
+    ("ones", torch.ones(10, 10, 10)),
+    ("rand", torch.rand(10, 10) - 0.5),
+    ("randn_pos", torch.randn(10) + 10),
+    ("randn_neg", torch.randn(10) - 10),
+    ("ramp", torch.arange(-16, 16, 0.2)),
+]
+
+
+class TestTanh(unittest.TestCase):
+    class Tanh(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.tanh = torch.nn.Tanh()
+
+        def forward(self, x):
+            return self.tanh(x)
+
+    def _test_tanh_tosa_MI_pipeline(
+        self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
+    ):
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=common.get_tosa_compile_spec(),
+            )
+            .export()
+            .check(["torch.ops.aten.tanh.default"])
+            .check_not(["torch.ops.quantized_decomposed"])
+            .to_edge()
+            .partition()
+            .check_not(["executorch_exir_dialects_edge__ops_aten_tanh_default"])
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+            .run_method_and_compare_outputs(inputs=test_data)
+        )
+
+    def _test_tanh_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tuple):
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=common.get_tosa_compile_spec(),
+            )
+            .quantize()
+            .export()
+            .check(["torch.ops.aten.tanh.default"])
+            .check(["torch.ops.quantized_decomposed"])
+            .to_edge()
+            .partition()
+            .check_not(["executorch_exir_dialects_edge__ops_aten_tanh_default"])
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+            .run_method_and_compare_outputs(inputs=test_data)
+        )
+
+    def _test_tanh_tosa_ethos_BI_pipeline(
+        self,
+        compile_spec: list[CompileSpec],
+        module: torch.nn.Module,
+        test_data: Tuple[torch.tensor],
+    ):
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=compile_spec,
+            )
+            .quantize()
+            .export()
+            .check_count({"torch.ops.aten.tanh.default": 1})
+            .check(["torch.ops.quantized_decomposed"])
+            .to_edge()
+            .partition()
+            .check_not(["executorch_exir_dialects_edge__ops_aten_tanh_default"])
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+        )
+
+    def _test_tanh_tosa_u55_BI_pipeline(
+        self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
+    ):
+        self._test_tanh_tosa_ethos_BI_pipeline(
+            common.get_u55_compile_spec(), module, test_data
+        )
+
+    def _test_tanh_tosa_u85_BI_pipeline(
+        self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
+    ):
+        self._test_tanh_tosa_ethos_BI_pipeline(
+            common.get_u85_compile_spec(), module, test_data
+        )
+
+    @parameterized.expand(test_data_suite)
+    def test_tanh_tosa_MI(
+        self,
+        test_name: str,
+        test_data: torch.Tensor,
+    ):
+        self._test_tanh_tosa_MI_pipeline(self.Tanh(), (test_data,))
+
+    @parameterized.expand(test_data_suite)
+    def test_tanh_tosa_BI(self, test_name: str, test_data: torch.Tensor):
+        self._test_tanh_tosa_BI_pipeline(self.Tanh(), (test_data,))
+
+    @parameterized.expand(test_data_suite)
+    def test_tanh_tosa_u55_BI(self, test_name: str, test_data: torch.Tensor):
+        self._test_tanh_tosa_u55_BI_pipeline(self.Tanh(), (test_data,))
+
+    @parameterized.expand(test_data_suite)
+    def test_tanh_tosa_u85_BI(self, test_name: str, test_data: torch.Tensor):
+        self._test_tanh_tosa_u85_BI_pipeline(self.Tanh(), (test_data,))