forked from cad-audio/executorch
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add full op for Arm backend (pytorch#4073)
Summary: Implements the full op, which creates a tensor of a given shape filled with a given value. The shape and value are set at compile time, i.e. they can't be set by a tensor input. Refactors tosa_quant_utils.is_quant_node to handle nodes with no inputs (or outputs). Does not add a full quantizer annotator; the op needs to be quantized by a SharedQuantizationSpec. Change-Id: I1cebd1da1af5b9aa726a363431ffc30d8259a0ff Pull Request resolved: pytorch#4073 Reviewed By: mergennachin Differential Revision: D59259731 Pulled By: digantdesai fbshipit-source-id: 621fec994bc2ebc4ad7abd51d9dbf1a5a4deed43
- Loading branch information
1 parent
908b5a5
commit 6556991
Showing
6 changed files
with
247 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ | |
op_conv2d, | ||
op_dequant, | ||
op_div, | ||
op_full, | ||
op_get_item, | ||
op_hardtanh, | ||
op_mean_dim, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# Copyright 2024 Arm Limited and/or its affiliates. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
from typing import List | ||
|
||
import numpy as np | ||
|
||
import serializer.tosa_serializer as ts | ||
from executorch.backends.arm.operators.node_visitor import ( | ||
NodeVisitor, | ||
register_node_visitor, | ||
) | ||
from executorch.backends.arm.tosa_mapping import TosaArg | ||
from executorch.backends.arm.tosa_quant_utils import get_quant_node_args | ||
from executorch.backends.arm.tosa_utils import tosa_shape | ||
from torch.fx import Node | ||
|
||
|
||
@register_node_visitor
class FullVisitor(NodeVisitor):
    """Lowers aten.full.default to TOSA.

    The result is emitted as a compile-time constant tensor (the shape and
    fill value are both baked in at compile time) followed by an IDENTITY op
    that forwards the constant to the node's output.
    """

    target = "aten.full.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        # First arg carries the requested shape; permute it into the
        # output's dim order (e.g. NHWC) before materializing the constant.
        out_shape = tosa_shape(inputs[0].special, output.dim_order)
        fill_value = inputs[1].number

        if is_quant_node:
            # full has no tensor inputs, so the quantization parameters are
            # taken from the (first) consumer of this node.
            quant_args = get_quant_node_args(list(node.users)[0])
            quantized_value = np.clip(
                np.round(fill_value / quant_args.scale) + quant_args.zp,
                quant_args.qmin,
                quant_args.qmax,
            )
            const_dtype = ts.DType.INT8
            const_data = np.full(out_shape, quantized_value, dtype=np.int8)
        else:
            assert (
                output.dtype == ts.DType.FP32
            ), "'Full' currently only supports FP32 for unquantized models."
            const_dtype = ts.DType.FP32
            const_data = np.full(out_shape, fill_value, dtype=np.float32)

        tosa_graph.addConst(out_shape, const_dtype, const_data, "full-const")
        tosa_graph.addOperator(ts.TosaOp.Op.IDENTITY, ["full-const"], [output.name])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
# Copyright 2024 Arm Limited and/or its affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# | ||
# Tests the full op which creates a tensor of a given shape filled with a given value. | ||
# The shape and value are set at compile time, i.e. can't be set by a tensor input. | ||
# | ||
|
||
import unittest | ||
from typing import Tuple | ||
|
||
import torch | ||
from executorch.backends.arm.test import common | ||
from executorch.backends.arm.test.tester.arm_tester import ArmTester | ||
from parameterized import parameterized | ||
|
||
|
||
class TestFull(unittest.TestCase):
    """Tests the full op, which creates a tensor of a given shape filled with
    a given value.

    Both the shape and the value are fixed at compile time, i.e. they cannot
    be supplied through a tensor input at runtime (see the expected-failure
    tests at the bottom).
    """

    class Full(torch.nn.Module):
        # A single full op.
        def forward(self):
            return torch.full((3, 3), 4.5)

    class AddConstFull(torch.nn.Module):
        # Input + a full with a constant value.
        def forward(self, x: torch.Tensor):
            return torch.full((2, 2, 3, 3), 4.5, dtype=torch.float32) + x

    class AddVariableFull(torch.nn.Module):
        # Fix: the first entry was written as (5), which is just the int 5,
        # not a 1-tuple. Use (5,) so every entry genuinely is a shape tuple;
        # torch.randn accepts both forms, so behavior is unchanged.
        sizes = [
            (5,),
            (5, 5),
            (5, 5, 5),
            (1, 5, 5, 5),
        ]
        # One parameterized case per shape: (random input tensor, fill value).
        test_parameters = [((torch.randn(n) * 10 - 5, 3.2),) for n in sizes]

        def forward(self, x: torch.Tensor, y):
            # Input + a full with the shape from the input and a given value 'y'.
            return x + torch.full(x.shape, y)

    def _test_full_tosa_MI_pipeline(
        self,
        module: torch.nn.Module,
        example_data: Tuple,
        test_data: Tuple | None = None,
    ):
        """Runs 'module' through the float (MI) TOSA pipeline.

        'example_data' is used to trace/export the module; 'test_data'
        (defaulting to the example data) is used when comparing the delegated
        outputs against eager mode.
        """
        if test_data is None:
            test_data = example_data
        (
            ArmTester(
                module,
                example_inputs=example_data,
                compile_spec=common.get_tosa_compile_spec(),
            )
            .export()
            .check_count({"torch.ops.aten.full.default": 1})
            .to_edge()
            .partition()
            # The full op must have been fully consumed by the delegate.
            .check_not(["executorch_exir_dialects_edge__ops_aten_full_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_full_tosa_BI_pipeline(
        self,
        module: torch.nn.Module,
        test_data: Tuple,
        permute_memory_to_nhwc: bool,
    ):
        """Runs 'module' through the quantized (BI) TOSA pipeline, optionally
        with NHWC memory permutation, and compares outputs against eager."""
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec(
                    permute_memory_to_nhwc=permute_memory_to_nhwc
                ),
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.full.default": 1})
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_full_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_full_tosa_u55_pipeline(self, module: torch.nn.Module, test_data: Tuple):
        """Compiles 'module' for Ethos-U55 (quantized); no output comparison,
        since the U55 target cannot be executed in this test environment."""
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_u55_compile_spec(),
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.full.default": 1})
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_full_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    def test_only_full_tosa_MI(self):
        # A graph consisting of nothing but a full op (no tensor inputs).
        self._test_full_tosa_MI_pipeline(self.Full(), ())

    def test_const_full_tosa_MI(self):
        _input = torch.rand((2, 2, 3, 3)) * 10
        self._test_full_tosa_MI_pipeline(self.AddConstFull(), (_input,))

    def test_const_full_nhwc_tosa_BI(self):
        _input = torch.rand((2, 2, 3, 3)) * 10
        self._test_full_tosa_BI_pipeline(self.AddConstFull(), (_input,), True)

    @parameterized.expand(AddVariableFull.test_parameters)
    def test_full_tosa_MI(self, test_tensor: Tuple):
        self._test_full_tosa_MI_pipeline(
            self.AddVariableFull(), example_data=test_tensor
        )

    @parameterized.expand(AddVariableFull.test_parameters)
    def test_full_tosa_BI(self, test_tensor: Tuple):
        self._test_full_tosa_BI_pipeline(self.AddVariableFull(), test_tensor, False)

    @parameterized.expand(AddVariableFull.test_parameters)
    def test_full_u55_BI(self, test_tensor: Tuple):
        self._test_full_tosa_u55_pipeline(
            self.AddVariableFull(),
            test_tensor,
        )

    # This fails since full outputs int64 by default if 'fill_value' is integer, which our backend doesn't support.
    @unittest.expectedFailure
    def test_integer_value(self):
        _input = torch.ones((2, 2))
        integer_fill_value = 1
        self._test_full_tosa_MI_pipeline(
            self.AddVariableFull(), example_data=(_input, integer_fill_value)
        )

    # This fails since the fill value in the full tensor is set at compile time by the example data (1.).
    # Test data tries to set it again at runtime (to 2.) but it doesn't do anything.
    # In eager mode, the fill value can be set at runtime, causing the outputs to not match.
    @unittest.expectedFailure
    def test_set_value_at_runtime(self):
        _input = torch.ones((2, 2))
        example_fill_value = 1.0
        test_fill_value = 2.0
        self._test_full_tosa_MI_pipeline(
            self.AddVariableFull(),
            example_data=(_input, example_fill_value),
            test_data=(_input, test_fill_value),
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters