Skip to content

Commit

Permalink
Add initial support for rshift
Browse files Browse the repository at this point in the history
U55 is restricted to round=True which may cause numerical differences
between TOSA and PyTorch.

Signed-off-by: Oscar Andersson <[email protected]>
Change-Id: I280e0dd0573b31333f6386b48d20105023719eb7
  • Loading branch information
oscarandersson8218 authored and freddan80 committed Nov 25, 2024
1 parent 12ce0ce commit fbee0c8
Show file tree
Hide file tree
Showing 6 changed files with 232 additions and 2 deletions.
2 changes: 1 addition & 1 deletion backends/arm/arm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def ethosu_compile_spec(
self.compiler_flags.append(extra_flags)

base_tosa_version = "TOSA-0.80.0+BI"
if "U55" in config:
if "u55" in config:
# Add the Ethos-U55 extension marker
base_tosa_version += "+u55"
self.tosa_version = TosaSpecification.create_from_string(base_tosa_version)
Expand Down
7 changes: 6 additions & 1 deletion backends/arm/operator_support/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,9 @@

# pyre-unsafe

from . import mean_dim_support, tosa_supported_operators, var_correction_support # noqa
from . import ( # noqa
mean_dim_support,
right_shift_support,
tosa_supported_operators,
var_correction_support,
)
35 changes: 35 additions & 0 deletions backends/arm/operator_support/right_shift_support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import logging

import torch.fx as fx
from executorch.backends.arm.operator_support.tosa_supported_operators import (
register_tosa_support_check,
SupportedTOSAOperatorCheck,
)
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
from executorch.exir.dialects._ops import ops as exir_ops

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)


@register_tosa_support_check
class RightShiftSupported(SupportedTOSAOperatorCheck):
targets = [exir_ops.edge.aten.__rshift__.Scalar]

tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80.0+BI"),
TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
]

def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification):

# TODO MLETORCH-525 Remove warning
if isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset:
logging.warning(f"{node.target} may introduce one-off errors.")
return True
1 change: 1 addition & 0 deletions backends/arm/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
op_reciprocal,
op_relu,
op_repeat,
op_rshift,
op_rsqrt,
op_select,
op_sigmoid,
Expand Down
99 changes: 99 additions & 0 deletions backends/arm/operators/op_rshift.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List

import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
from executorch.backends.arm.tosa_specification import Tosa_0_80
from executorch.backends.arm.tosa_utils import tosa_shape
from serializer.tosa_serializer import TosaOp


@register_node_visitor
class RshiftVisitor(NodeVisitor):
target = "aten.__rshift__.Scalar"

def define_node(
self,
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
inputs: List[TosaArg],
output: TosaArg,
is_quant_node: bool,
) -> None:
input_shape = inputs[0].shape
input_0_rank = len(input_shape)
shift_expanded_shape = [1] * input_0_rank
dtype = node.meta["val"].dtype
attr = ts.TosaSerializerAttribute()
cast_input = False
cast_output = False
round = False
cast_type = dtype
if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
# U55 only supports INT32 and round == True
# TODO MLETORCH-525 Emulate round == False with different decomposition
if dtype != torch.int32:
cast_input = True
cast_output = True
cast_type = torch.int32
round = True
attr.ArithmeticRightShiftAttribute(round=round)

if cast_input:
# input needs to be casted to INT32
shift_input = tosa_graph.addIntermediate(
shape=tosa_shape(input_shape, inputs[0].dim_order),
dtype=map_dtype(cast_type),
)
tosa_graph.addOperator(
TosaOp.Op().CAST,
[inputs[0].name],
[shift_input.name],
None,
)
else:
shift_input = inputs[0]
if cast_output:
# add intermediate tensor for right shift
shift = tosa_graph.addIntermediate(
shape=tosa_shape(input_shape, inputs[0].dim_order),
dtype=map_dtype(cast_type),
)
else:
shift = output
# create tensor with same rank as inputs[0]
data = torch.full(
shift_expanded_shape, fill_value=inputs[1].number, dtype=dtype
)
shift_const_name = node.name + "-shift_const"
tosa_graph.addConst(
shift_expanded_shape,
map_dtype(cast_type),
data.detach().numpy(),
shift_const_name,
)
# add right shift operator
tosa_graph.addOperator(
TosaOp.Op().ARITHMETIC_RIGHT_SHIFT,
[shift_input.name, shift_const_name],
[shift.name],
attr,
)
if cast_output:
# cast output to original output dtype
tosa_graph.addOperator(
TosaOp.Op().CAST,
[shift.name],
[output.name],
None,
)
90 changes: 90 additions & 0 deletions backends/arm/test/ops/test_rshift.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from parameterized import parameterized


class TestRshift(unittest.TestCase):
"""
Tests arithmetic right shift
"""

class Rshift(torch.nn.Module):
test_data = [
((torch.IntTensor(5, 5), 2),),
((torch.IntTensor(1, 2, 3, 4), 3),),
((torch.ShortTensor(1, 5, 3, 4), 5),),
((torch.CharTensor(10, 12, 3, 4), 1),),
]

def forward(self, x: torch.Tensor, shift: int):
return x >> shift

def _test_rshift_tosa_MI(self, test_data):
(
ArmTester(
self.Rshift(),
example_inputs=test_data,
compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
)
.export()
.to_edge_transform_and_lower()
.to_executorch()
.run_method_and_compare_outputs(inputs=test_data)
)

def _test_rshift_tosa_BI(self, test_data):
(
ArmTester(
self.Rshift(),
example_inputs=test_data,
compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
)
.quantize()
.export()
.to_edge_transform_and_lower()
.to_executorch()
# TODO MLETORCH-250 Increase flexibility of ArmTester to handle int IO
# .run_method_and_compare_outputs(inputs=test_data)
)

def _test_rshift_ethosu_BI(self, test_data, compile_spec):
return (
ArmTester(
self.Rshift(),
example_inputs=test_data,
compile_spec=compile_spec,
)
.quantize()
.export()
.to_edge_transform_and_lower()
.to_executorch()
)

@parameterized.expand(Rshift.test_data)
def test_rshift_tosa_MI(self, test_data):
self._test_rshift_tosa_MI(test_data)

@parameterized.expand(Rshift.test_data)
def test_rshift_tosa_BI(self, test_data):
self._test_rshift_tosa_BI(test_data)

# TODO Enable FVP testing
@parameterized.expand(Rshift.test_data)
def test_rshift_u55_BI(self, test_data):
compile_spec = common.get_u55_compile_spec()
self._test_rshift_ethosu_BI(test_data, compile_spec)

# TODO Enable FVP testing
@parameterized.expand(Rshift.test_data)
def test_rshift_u85_BI(self, test_data):
compile_spec = common.get_u85_compile_spec()
self._test_rshift_ethosu_BI(test_data, compile_spec)

0 comments on commit fbee0c8

Please sign in to comment.