forked from cad-audio/executorch
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Differential Revision: D64427390 Pull Request resolved: pytorch#6226
- Loading branch information
1 parent
8c96805
commit cb0f53e
Showing
5 changed files
with
223 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,6 +34,7 @@ | |
op_squeeze, | ||
op_sub, | ||
op_sum, | ||
op_tanh, | ||
op_transpose, | ||
op_unsqueeze, | ||
op_view, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
# Copyright 2024 Arm Limited and/or its affiliates. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-unsafe | ||
from typing import List | ||
|
||
import numpy as np | ||
|
||
import serializer.tosa_serializer as ts | ||
from executorch.backends.arm.operators.node_visitor import ( | ||
NodeVisitor, | ||
register_node_visitor, | ||
) | ||
from executorch.backends.arm.tosa_mapping import TosaArg | ||
|
||
from executorch.backends.arm.tosa_quant_utils import ( | ||
dequantize_value, | ||
get_quant_node_args, | ||
QuantArgs, | ||
quantize_value, | ||
) | ||
from serializer.tosa_serializer import TosaOp | ||
from torch.fx import Node | ||
|
||
|
||
@register_node_visitor | ||
class TanhVisitor(NodeVisitor): | ||
target = "aten.tanh.default" | ||
|
||
def __init__(self, *args): | ||
super().__init__(*args) | ||
|
||
def define_node( | ||
self, | ||
node: Node, | ||
tosa_graph: ts.TosaSerializer, | ||
inputs: List[TosaArg], | ||
output: TosaArg, | ||
is_quant_node: bool, | ||
) -> None: | ||
|
||
assert len(node.all_input_nodes) == 1 | ||
|
||
if is_quant_node: | ||
# Assume quantized input is 8 bit. | ||
assert len(node.users) == 1 | ||
|
||
# Create attribute for 8 bit table lookup. | ||
input_node = node.all_input_nodes[0] | ||
in_quantargs = get_quant_node_args(input_node) | ||
output_node = list(node.users)[0] | ||
out_quantargs = get_quant_node_args(output_node) | ||
|
||
table = tanh_table_8bit(in_quantargs, out_quantargs) | ||
table_attr = ts.TosaSerializerAttribute() | ||
table_attr.TableAttribute(table) | ||
|
||
tosa_graph.addOperator( | ||
TosaOp.Op().TABLE, [inputs[0].name], [output.name], table_attr | ||
) | ||
else: | ||
tosa_graph.addOperator(TosaOp.Op().TANH, [inputs[0].name], [output.name]) | ||
|
||
|
||
def tanh_table_8bit(in_quantargs: QuantArgs, out_quantargs: QuantArgs): | ||
""" | ||
Returns a table mapping 256 entries to tanh([qmin,qmax]) | ||
Reference: https://www.mlplatform.org/tosa/tosa_spec.html#_tanh | ||
""" | ||
|
||
def tanh(x): | ||
# Convert quantized input to floating point tanh input space. | ||
v = dequantize_value(x, in_quantargs) | ||
# Compute tanh. | ||
v = np.exp(-2.0 * v) | ||
v = (1.0 - v) / (1.0 + v) | ||
|
||
# Convert tanh output back to quantized space. | ||
return quantize_value(v, out_quantargs) | ||
|
||
return [ | ||
tanh(x) | ||
for x in np.linspace(in_quantargs.qmin, in_quantargs.qmax, 256, dtype=np.int8) | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# Copyright 2024 Arm Limited and/or its affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import unittest | ||
|
||
from typing import Tuple | ||
|
||
import torch | ||
|
||
from executorch.backends.arm.test import common | ||
from executorch.backends.arm.test.tester.arm_tester import ArmTester | ||
from executorch.exir.backend.compile_spec_schema import CompileSpec | ||
from parameterized import parameterized | ||
|
||
|
||
test_data_suite = [ | ||
# (test_name, test_data) | ||
("zeros", torch.zeros(10, 10, 10, 10)), | ||
("ones", torch.ones(10, 10, 10)), | ||
("rand", torch.rand(10, 10) - 0.5), | ||
("randn_pos", torch.randn(10) + 10), | ||
("randn_neg", torch.randn(10) - 10), | ||
("ramp", torch.arange(-16, 16, 0.2)), | ||
] | ||
|
||
|
||
class TestTanh(unittest.TestCase): | ||
class Tanh(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.tanh = torch.nn.Tanh() | ||
|
||
def forward(self, x): | ||
return self.tanh(x) | ||
|
||
def _test_tanh_tosa_MI_pipeline( | ||
self, module: torch.nn.Module, test_data: Tuple[torch.tensor] | ||
): | ||
( | ||
ArmTester( | ||
module, | ||
example_inputs=test_data, | ||
compile_spec=common.get_tosa_compile_spec(), | ||
) | ||
.export() | ||
.check(["torch.ops.aten.tanh.default"]) | ||
.check_not(["torch.ops.quantized_decomposed"]) | ||
.to_edge() | ||
.partition() | ||
.check_not(["executorch_exir_dialects_edge__ops_aten_tanh_default"]) | ||
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) | ||
.to_executorch() | ||
.run_method_and_compare_outputs(inputs=test_data) | ||
) | ||
|
||
def _test_tanh_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tuple): | ||
( | ||
ArmTester( | ||
module, | ||
example_inputs=test_data, | ||
compile_spec=common.get_tosa_compile_spec(), | ||
) | ||
.quantize() | ||
.export() | ||
.check(["torch.ops.aten.tanh.default"]) | ||
.check(["torch.ops.quantized_decomposed"]) | ||
.to_edge() | ||
.partition() | ||
.check_not(["executorch_exir_dialects_edge__ops_aten_tanh_default"]) | ||
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) | ||
.to_executorch() | ||
.run_method_and_compare_outputs(inputs=test_data) | ||
) | ||
|
||
def _test_tanh_tosa_ethos_BI_pipeline( | ||
self, | ||
compile_spec: list[CompileSpec], | ||
module: torch.nn.Module, | ||
test_data: Tuple[torch.tensor], | ||
): | ||
( | ||
ArmTester( | ||
module, | ||
example_inputs=test_data, | ||
compile_spec=compile_spec, | ||
) | ||
.quantize() | ||
.export() | ||
.check_count({"torch.ops.aten.tanh.default": 1}) | ||
.check(["torch.ops.quantized_decomposed"]) | ||
.to_edge() | ||
.partition() | ||
.check_not(["executorch_exir_dialects_edge__ops_aten_tanh_default"]) | ||
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) | ||
.to_executorch() | ||
) | ||
|
||
def _test_tanh_tosa_u55_BI_pipeline( | ||
self, module: torch.nn.Module, test_data: Tuple[torch.tensor] | ||
): | ||
self._test_tanh_tosa_ethos_BI_pipeline( | ||
common.get_u55_compile_spec(), module, test_data | ||
) | ||
|
||
def _test_tanh_tosa_u85_BI_pipeline( | ||
self, module: torch.nn.Module, test_data: Tuple[torch.tensor] | ||
): | ||
self._test_tanh_tosa_ethos_BI_pipeline( | ||
common.get_u85_compile_spec(), module, test_data | ||
) | ||
|
||
@parameterized.expand(test_data_suite) | ||
def test_tanh_tosa_MI( | ||
self, | ||
test_name: str, | ||
test_data: torch.Tensor, | ||
): | ||
self._test_tanh_tosa_MI_pipeline(self.Tanh(), (test_data,)) | ||
|
||
@parameterized.expand(test_data_suite) | ||
def test_tanh_tosa_BI(self, test_name: str, test_data: torch.Tensor): | ||
self._test_tanh_tosa_BI_pipeline(self.Tanh(), (test_data,)) | ||
|
||
@parameterized.expand(test_data_suite) | ||
def test_tanh_tosa_u55_BI(self, test_name: str, test_data: torch.Tensor): | ||
self._test_tanh_tosa_u55_BI_pipeline(self.Tanh(), (test_data,)) | ||
|
||
@parameterized.expand(test_data_suite) | ||
def test_tanh_tosa_u85_BI(self, test_name: str, test_data: torch.Tensor): | ||
self._test_tanh_tosa_u85_BI_pipeline(self.Tanh(), (test_data,)) |