forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
U55 is restricted to round=True which may cause numerical differences between TOSA and PyTorch. Signed-off-by: Oscar Andersson <[email protected]> Change-Id: I280e0dd0573b31333f6386b48d20105023719eb7
- Loading branch information
1 parent
12ce0ce
commit fbee0c8
Showing
6 changed files
with
232 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# Copyright 2024 Arm Limited and/or its affiliates. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
|
||
import logging | ||
|
||
import torch.fx as fx | ||
from executorch.backends.arm.operator_support.tosa_supported_operators import ( | ||
register_tosa_support_check, | ||
SupportedTOSAOperatorCheck, | ||
) | ||
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification | ||
from executorch.exir.dialects._ops import ops as exir_ops | ||
|
||
logger = logging.getLogger(__name__) | ||
logger.setLevel(logging.WARNING) | ||
|
||
|
||
@register_tosa_support_check | ||
class RightShiftSupported(SupportedTOSAOperatorCheck): | ||
targets = [exir_ops.edge.aten.__rshift__.Scalar] | ||
|
||
tosa_specs = [ | ||
TosaSpecification.create_from_string("TOSA-0.80.0+BI"), | ||
TosaSpecification.create_from_string("TOSA-0.80.0+MI"), | ||
] | ||
|
||
def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification): | ||
|
||
# TODO MLETORCH-525 Remove warning | ||
if isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset: | ||
logging.warning(f"{node.target} may introduce one-off errors.") | ||
return True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,6 +27,7 @@ | |
op_reciprocal, | ||
op_relu, | ||
op_repeat, | ||
op_rshift, | ||
op_rsqrt, | ||
op_select, | ||
op_sigmoid, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
# Copyright 2024 Arm Limited and/or its affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from typing import List | ||
|
||
import serializer.tosa_serializer as ts | ||
import torch | ||
from executorch.backends.arm.operators.node_visitor import ( | ||
NodeVisitor, | ||
register_node_visitor, | ||
) | ||
from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg | ||
from executorch.backends.arm.tosa_specification import Tosa_0_80 | ||
from executorch.backends.arm.tosa_utils import tosa_shape | ||
from serializer.tosa_serializer import TosaOp | ||
|
||
|
||
@register_node_visitor | ||
class RshiftVisitor(NodeVisitor): | ||
target = "aten.__rshift__.Scalar" | ||
|
||
def define_node( | ||
self, | ||
node: torch.fx.Node, | ||
tosa_graph: ts.TosaSerializer, | ||
inputs: List[TosaArg], | ||
output: TosaArg, | ||
is_quant_node: bool, | ||
) -> None: | ||
input_shape = inputs[0].shape | ||
input_0_rank = len(input_shape) | ||
shift_expanded_shape = [1] * input_0_rank | ||
dtype = node.meta["val"].dtype | ||
attr = ts.TosaSerializerAttribute() | ||
cast_input = False | ||
cast_output = False | ||
round = False | ||
cast_type = dtype | ||
if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset: | ||
# U55 only supports INT32 and round == True | ||
# TODO MLETORCH-525 Emulate round == False with different decomposition | ||
if dtype != torch.int32: | ||
cast_input = True | ||
cast_output = True | ||
cast_type = torch.int32 | ||
round = True | ||
attr.ArithmeticRightShiftAttribute(round=round) | ||
|
||
if cast_input: | ||
# input needs to be casted to INT32 | ||
shift_input = tosa_graph.addIntermediate( | ||
shape=tosa_shape(input_shape, inputs[0].dim_order), | ||
dtype=map_dtype(cast_type), | ||
) | ||
tosa_graph.addOperator( | ||
TosaOp.Op().CAST, | ||
[inputs[0].name], | ||
[shift_input.name], | ||
None, | ||
) | ||
else: | ||
shift_input = inputs[0] | ||
if cast_output: | ||
# add intermediate tensor for right shift | ||
shift = tosa_graph.addIntermediate( | ||
shape=tosa_shape(input_shape, inputs[0].dim_order), | ||
dtype=map_dtype(cast_type), | ||
) | ||
else: | ||
shift = output | ||
# create tensor with same rank as inputs[0] | ||
data = torch.full( | ||
shift_expanded_shape, fill_value=inputs[1].number, dtype=dtype | ||
) | ||
shift_const_name = node.name + "-shift_const" | ||
tosa_graph.addConst( | ||
shift_expanded_shape, | ||
map_dtype(cast_type), | ||
data.detach().numpy(), | ||
shift_const_name, | ||
) | ||
# add right shift operator | ||
tosa_graph.addOperator( | ||
TosaOp.Op().ARITHMETIC_RIGHT_SHIFT, | ||
[shift_input.name, shift_const_name], | ||
[shift.name], | ||
attr, | ||
) | ||
if cast_output: | ||
# cast output to original output dtype | ||
tosa_graph.addOperator( | ||
TosaOp.Op().CAST, | ||
[shift.name], | ||
[output.name], | ||
None, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
# Copyright 2024 Arm Limited and/or its affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import unittest | ||
|
||
import torch | ||
from executorch.backends.arm.test import common | ||
from executorch.backends.arm.test.tester.arm_tester import ArmTester | ||
from parameterized import parameterized | ||
|
||
|
||
class TestRshift(unittest.TestCase): | ||
""" | ||
Tests arithmetic right shift | ||
""" | ||
|
||
class Rshift(torch.nn.Module): | ||
test_data = [ | ||
((torch.IntTensor(5, 5), 2),), | ||
((torch.IntTensor(1, 2, 3, 4), 3),), | ||
((torch.ShortTensor(1, 5, 3, 4), 5),), | ||
((torch.CharTensor(10, 12, 3, 4), 1),), | ||
] | ||
|
||
def forward(self, x: torch.Tensor, shift: int): | ||
return x >> shift | ||
|
||
def _test_rshift_tosa_MI(self, test_data): | ||
( | ||
ArmTester( | ||
self.Rshift(), | ||
example_inputs=test_data, | ||
compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), | ||
) | ||
.export() | ||
.to_edge_transform_and_lower() | ||
.to_executorch() | ||
.run_method_and_compare_outputs(inputs=test_data) | ||
) | ||
|
||
def _test_rshift_tosa_BI(self, test_data): | ||
( | ||
ArmTester( | ||
self.Rshift(), | ||
example_inputs=test_data, | ||
compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), | ||
) | ||
.quantize() | ||
.export() | ||
.to_edge_transform_and_lower() | ||
.to_executorch() | ||
# TODO MLETORCH-250 Increase flexibility of ArmTester to handle int IO | ||
# .run_method_and_compare_outputs(inputs=test_data) | ||
) | ||
|
||
def _test_rshift_ethosu_BI(self, test_data, compile_spec): | ||
return ( | ||
ArmTester( | ||
self.Rshift(), | ||
example_inputs=test_data, | ||
compile_spec=compile_spec, | ||
) | ||
.quantize() | ||
.export() | ||
.to_edge_transform_and_lower() | ||
.to_executorch() | ||
) | ||
|
||
@parameterized.expand(Rshift.test_data) | ||
def test_rshift_tosa_MI(self, test_data): | ||
self._test_rshift_tosa_MI(test_data) | ||
|
||
@parameterized.expand(Rshift.test_data) | ||
def test_rshift_tosa_BI(self, test_data): | ||
self._test_rshift_tosa_BI(test_data) | ||
|
||
# TODO Enable FVP testing | ||
@parameterized.expand(Rshift.test_data) | ||
def test_rshift_u55_BI(self, test_data): | ||
compile_spec = common.get_u55_compile_spec() | ||
self._test_rshift_ethosu_BI(test_data, compile_spec) | ||
|
||
# TODO Enable FVP testing | ||
@parameterized.expand(Rshift.test_data) | ||
def test_rshift_u85_BI(self, test_data): | ||
compile_spec = common.get_u85_compile_spec() | ||
self._test_rshift_ethosu_BI(test_data, compile_spec) |