Skip to content

Commit

Permalink
Arm backend: Add tanh operator
Browse files Browse the repository at this point in the history
Differential Revision: D64427390

Pull Request resolved: pytorch#6226
  • Loading branch information
SaoirseARM authored Oct 22, 2024
1 parent 8c96805 commit cb0f53e
Show file tree
Hide file tree
Showing 5 changed files with 223 additions and 0 deletions.
1 change: 1 addition & 0 deletions backends/arm/arm_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
exir_ops.edge.aten.slice_copy.Tensor,
exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.sum.dim_IntList,
exir_ops.edge.aten.tanh.default,
exir_ops.edge.aten.view_copy.default,
exir_ops.edge.aten.clone.default,
exir_ops.edge.aten.mean.dim,
Expand Down
1 change: 1 addition & 0 deletions backends/arm/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
op_squeeze,
op_sub,
op_sum,
op_tanh,
op_transpose,
op_unsqueeze,
op_view,
Expand Down
86 changes: 86 additions & 0 deletions backends/arm/operators/op_tanh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import List

import numpy as np

import serializer.tosa_serializer as ts
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg

from executorch.backends.arm.tosa_quant_utils import (
dequantize_value,
get_quant_node_args,
QuantArgs,
quantize_value,
)
from serializer.tosa_serializer import TosaOp
from torch.fx import Node


@register_node_visitor
class TanhVisitor(NodeVisitor):
target = "aten.tanh.default"

def __init__(self, *args):
super().__init__(*args)

def define_node(
self,
node: Node,
tosa_graph: ts.TosaSerializer,
inputs: List[TosaArg],
output: TosaArg,
is_quant_node: bool,
) -> None:

assert len(node.all_input_nodes) == 1

if is_quant_node:
# Assume quantized input is 8 bit.
assert len(node.users) == 1

# Create attribute for 8 bit table lookup.
input_node = node.all_input_nodes[0]
in_quantargs = get_quant_node_args(input_node)
output_node = list(node.users)[0]
out_quantargs = get_quant_node_args(output_node)

table = tanh_table_8bit(in_quantargs, out_quantargs)
table_attr = ts.TosaSerializerAttribute()
table_attr.TableAttribute(table)

tosa_graph.addOperator(
TosaOp.Op().TABLE, [inputs[0].name], [output.name], table_attr
)
else:
tosa_graph.addOperator(TosaOp.Op().TANH, [inputs[0].name], [output.name])


def tanh_table_8bit(in_quantargs: QuantArgs, out_quantargs: QuantArgs):
"""
Returns a table mapping 256 entries to tanh([qmin,qmax])
Reference: https://www.mlplatform.org/tosa/tosa_spec.html#_tanh
"""

def tanh(x):
# Convert quantized input to floating point tanh input space.
v = dequantize_value(x, in_quantargs)
# Compute tanh.
v = np.exp(-2.0 * v)
v = (1.0 - v) / (1.0 + v)

# Convert tanh output back to quantized space.
return quantize_value(v, out_quantargs)

return [
tanh(x)
for x in np.linspace(in_quantargs.qmin, in_quantargs.qmax, 256, dtype=np.int8)
]
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def _annotate_one_to_one(
torch.ops.aten.reciprocal.default,
torch.ops.aten.rsqrt.default,
torch.ops.aten.sigmoid.default,
torch.ops.aten.tanh.default,
)
for node in gm.graph.nodes:
if node.op != "call_function" or node.target not in one_to_one_ops:
Expand Down
134 changes: 134 additions & 0 deletions backends/arm/test/ops/test_tanh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from typing import Tuple

import torch

from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.exir.backend.compile_spec_schema import CompileSpec
from parameterized import parameterized


test_data_suite = [
# (test_name, test_data)
("zeros", torch.zeros(10, 10, 10, 10)),
("ones", torch.ones(10, 10, 10)),
("rand", torch.rand(10, 10) - 0.5),
("randn_pos", torch.randn(10) + 10),
("randn_neg", torch.randn(10) - 10),
("ramp", torch.arange(-16, 16, 0.2)),
]


class TestTanh(unittest.TestCase):
class Tanh(torch.nn.Module):
def __init__(self):
super().__init__()
self.tanh = torch.nn.Tanh()

def forward(self, x):
return self.tanh(x)

def _test_tanh_tosa_MI_pipeline(
self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
):
(
ArmTester(
module,
example_inputs=test_data,
compile_spec=common.get_tosa_compile_spec(),
)
.export()
.check(["torch.ops.aten.tanh.default"])
.check_not(["torch.ops.quantized_decomposed"])
.to_edge()
.partition()
.check_not(["executorch_exir_dialects_edge__ops_aten_tanh_default"])
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
.run_method_and_compare_outputs(inputs=test_data)
)

def _test_tanh_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tuple):
(
ArmTester(
module,
example_inputs=test_data,
compile_spec=common.get_tosa_compile_spec(),
)
.quantize()
.export()
.check(["torch.ops.aten.tanh.default"])
.check(["torch.ops.quantized_decomposed"])
.to_edge()
.partition()
.check_not(["executorch_exir_dialects_edge__ops_aten_tanh_default"])
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
.run_method_and_compare_outputs(inputs=test_data)
)

def _test_tanh_tosa_ethos_BI_pipeline(
self,
compile_spec: list[CompileSpec],
module: torch.nn.Module,
test_data: Tuple[torch.tensor],
):
(
ArmTester(
module,
example_inputs=test_data,
compile_spec=compile_spec,
)
.quantize()
.export()
.check_count({"torch.ops.aten.tanh.default": 1})
.check(["torch.ops.quantized_decomposed"])
.to_edge()
.partition()
.check_not(["executorch_exir_dialects_edge__ops_aten_tanh_default"])
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
)

def _test_tanh_tosa_u55_BI_pipeline(
self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
):
self._test_tanh_tosa_ethos_BI_pipeline(
common.get_u55_compile_spec(), module, test_data
)

def _test_tanh_tosa_u85_BI_pipeline(
self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
):
self._test_tanh_tosa_ethos_BI_pipeline(
common.get_u85_compile_spec(), module, test_data
)

@parameterized.expand(test_data_suite)
def test_tanh_tosa_MI(
self,
test_name: str,
test_data: torch.Tensor,
):
self._test_tanh_tosa_MI_pipeline(self.Tanh(), (test_data,))

@parameterized.expand(test_data_suite)
def test_tanh_tosa_BI(self, test_name: str, test_data: torch.Tensor):
self._test_tanh_tosa_BI_pipeline(self.Tanh(), (test_data,))

@parameterized.expand(test_data_suite)
def test_tanh_tosa_u55_BI(self, test_name: str, test_data: torch.Tensor):
self._test_tanh_tosa_u55_BI_pipeline(self.Tanh(), (test_data,))

@parameterized.expand(test_data_suite)
def test_tanh_tosa_u85_BI(self, test_name: str, test_data: torch.Tensor):
self._test_tanh_tosa_u85_BI_pipeline(self.Tanh(), (test_data,))

0 comments on commit cb0f53e

Please sign in to comment.