diff --git a/backends/arm/arm_partitioner.py b/backends/arm/arm_partitioner.py index bd83ad1974..ea41f8381a 100644 --- a/backends/arm/arm_partitioner.py +++ b/backends/arm/arm_partitioner.py @@ -42,6 +42,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: exir_ops.edge.aten.hardtanh.default, exir_ops.edge.aten.convolution.default, exir_ops.edge.aten.div.Tensor, + exir_ops.edge.aten.full.default, exir_ops.edge.aten._native_batch_norm_legit_no_training.default, exir_ops.edge.aten.avg_pool2d.default, exir_ops.edge.aten._softmax.default, diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index 91db0af5d4..4b783903f5 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -12,6 +12,7 @@ op_conv2d, op_dequant, op_div, + op_full, op_get_item, op_hardtanh, op_mean_dim, diff --git a/backends/arm/operators/op_full.py b/backends/arm/operators/op_full.py new file mode 100644 index 0000000000..f929b02ee6 --- /dev/null +++ b/backends/arm/operators/op_full.py @@ -0,0 +1,54 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import List + +import numpy as np + +import serializer.tosa_serializer as ts +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_quant_utils import get_quant_node_args +from executorch.backends.arm.tosa_utils import tosa_shape +from torch.fx import Node + + +@register_node_visitor +class FullVisitor(NodeVisitor): + target = "aten.full.default" + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + is_quant_node: bool, + ) -> None: + + shape = tosa_shape(inputs[0].special, output.dim_order) + + value = inputs[1].number + if is_quant_node: + qargs = get_quant_node_args(list(node.users)[0]) + qvalue = np.clip( + np.round(value / qargs.scale) + qargs.zp, qargs.qmin, qargs.qmax + ) + dtype = ts.DType.INT8 + data = np.full(shape, qvalue, dtype=np.int8) + else: + assert ( + output.dtype == ts.DType.FP32 + ), "'Full' currently only supports FP32 for unquantized models." + dtype = ts.DType.FP32 + data = np.full(shape, value, dtype=np.float32) + + tosa_graph.addConst(shape, dtype, data, "full-const") + tosa_graph.addOperator(ts.TosaOp.Op.IDENTITY, ["full-const"], [output.name]) diff --git a/backends/arm/test/ops/test_full.py b/backends/arm/test/ops/test_full.py new file mode 100644 index 0000000000..4f01b1c8f9 --- /dev/null +++ b/backends/arm/test/ops/test_full.py @@ -0,0 +1,160 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# +# Tests the full op which creates a tensor of a given shape filled with a given value. +# The shape and value are set at compile time, i.e. can't be set by a tensor input. +# + +import unittest +from typing import Tuple + +import torch +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester +from parameterized import parameterized + + +class TestFull(unittest.TestCase): + class Full(torch.nn.Module): + # A single full op + def forward(self): + return torch.full((3, 3), 4.5) + + class AddConstFull(torch.nn.Module): + # Input + a full with constant value. + def forward(self, x: torch.Tensor): + return torch.full((2, 2, 3, 3), 4.5, dtype=torch.float32) + x + + class AddVariableFull(torch.nn.Module): + sizes = [ + (5), + (5, 5), + (5, 5, 5), + (1, 5, 5, 5), + ] + test_parameters = [((torch.randn(n) * 10 - 5, 3.2),) for n in sizes] + + def forward(self, x: torch.Tensor, y): + # Input + a full with the shape from the input and a given value 'y'. + return x + torch.full(x.shape, y) + + def _test_full_tosa_MI_pipeline( + self, + module: torch.nn.Module, + example_data: Tuple, + test_data: Tuple | None = None, + ): + if test_data is None: + test_data = example_data + ( + ArmTester( + module, + example_inputs=example_data, + compile_spec=common.get_tosa_compile_spec(), + ) + .export() + .check_count({"torch.ops.aten.full.default": 1}) + .to_edge() + .partition() + .check_not(["executorch_exir_dialects_edge__ops_aten_full_default"]) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(inputs=test_data) + ) + + def _test_full_tosa_BI_pipeline( + self, + module: torch.nn.Module, + test_data: Tuple, + permute_memory_to_nhwc: bool, + ): + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec( + permute_memory_to_nhwc=permute_memory_to_nhwc + ), + ) + .quantize() + .export() + .check_count({"torch.ops.aten.full.default": 1}) + .to_edge() + .partition() + .check_not(["executorch_exir_dialects_edge__ops_aten_full_default"]) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(inputs=test_data) + ) + + def _test_full_tosa_u55_pipeline(self, module: torch.nn.Module, test_data: Tuple): + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_u55_compile_spec(), + ) + .quantize() + .export() + .check_count({"torch.ops.aten.full.default": 1}) + .to_edge() + .partition() + .check_not(["executorch_exir_dialects_edge__ops_aten_full_default"]) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + ) + + def test_only_full_tosa_MI(self): + self._test_full_tosa_MI_pipeline(self.Full(), ()) + + def test_const_full_tosa_MI(self): + _input = torch.rand((2, 2, 3, 3)) * 10 + self._test_full_tosa_MI_pipeline(self.AddConstFull(), (_input,)) + + def test_const_full_nhwc_tosa_BI(self): + _input = torch.rand((2, 2, 3, 3)) * 10 + self._test_full_tosa_BI_pipeline(self.AddConstFull(), (_input,), True) + + @parameterized.expand(AddVariableFull.test_parameters) + def test_full_tosa_MI(self, test_tensor: Tuple): + self._test_full_tosa_MI_pipeline( + self.AddVariableFull(), example_data=test_tensor + ) + + @parameterized.expand(AddVariableFull.test_parameters) + def test_full_tosa_BI(self, test_tensor: Tuple): + self._test_full_tosa_BI_pipeline(self.AddVariableFull(), test_tensor, False) + + @parameterized.expand(AddVariableFull.test_parameters) + def test_full_u55_BI(self, test_tensor: Tuple): + self._test_full_tosa_u55_pipeline( + self.AddVariableFull(), + test_tensor, + ) + + # This fails since full outputs int64 by default if 'fill_value' is integer, which our backend doesn't support. + @unittest.expectedFailure + def test_integer_value(self): + _input = torch.ones((2, 2)) + integer_fill_value = 1 + self._test_full_tosa_MI_pipeline( + self.AddVariableFull(), example_data=(_input, integer_fill_value) + ) + + # This fails since the fill value in the full tensor is set at compile time by the example data (1.). + # Test data tries to set it again at runtime (to 2.) but it doesn't do anything. + # In eager mode, the fill value can be set at runtime, causing the outputs to not match. + @unittest.expectedFailure + def test_set_value_at_runtime(self): + _input = torch.ones((2, 2)) + example_fill_value = 1.0 + test_fill_value = 2.0 + self._test_full_tosa_MI_pipeline( + self.AddVariableFull(), + example_data=(_input, example_fill_value), + test_data=(_input, test_fill_value), + ) diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index c38633ee23..97ab67b3d1 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -249,8 +249,15 @@ def run_method_and_compare_outputs( else: test_input = reference_input + # Test parameters can include constants that are used in eager mode but are already set as attributes + # in TOSA. Therefore, only accept torch.Tensor inputs. + test_input = [ + tensor for tensor in test_input if isinstance(tensor, torch.Tensor) + ] + input_shapes = [ - generated_input.shape for generated_input in reference_input + generated_input.shape if hasattr(generated_input, "shape") else (1,) + for generated_input in reference_input ] print(f"Run {run_iteration} with input shapes: {input_shapes}") @@ -274,7 +281,7 @@ def transpose_data_format( dim_order = (0, 2, 3, 1) inputs_transposed = list(data) for i in range(len(data)): - if len(data[i].shape) == 4: + if hasattr(data[i], "shape") and len(data[i].shape) == 4: inputs_transposed[i] = np.transpose(data[i], dim_order) return tuple(inputs_transposed) @@ -298,7 +305,8 @@ def _compare_outputs( path_to_tosa_files = self.runner_util.intermediate_path export_stage = self.stages.get(self.stage_name(tester.Export), None) - if export_stage is not None: + quantize_stage = self.stages.get(self.stage_name(tester.Quantize), None) + if export_stage is not None and quantize_stage is not None: input_names = _get_input_names(export_stage.artifact) output_node = _get_output_node(export_stage.artifact) qp_input = _get_input_quantization_params( diff --git a/backends/arm/tosa_quant_utils.py b/backends/arm/tosa_quant_utils.py index d467e4c00e..0379780ed2 100644 --- a/backends/arm/tosa_quant_utils.py +++ b/backends/arm/tosa_quant_utils.py @@ -27,22 +27,26 @@ class QuantArgs(NamedTuple): def is_quant_node(node: torch.fx.Node): - consumer_node = list(node.users)[0] - input = node.all_input_nodes[0] - - # For Rank > 2 Linear layers, the quant node is after the view_copy - if ( - node.target == exir_ops.edge.aten.addmm.default - and consumer_node.target == exir_ops.edge.aten.view_copy.default - ): - consumer_consumer_node = list(consumer_node.users)[0] - return True if consumer_consumer_node.target == q_op else False - - return ( - consumer_node.target == q_op - or node.target in dq_q_ops - or input.target in dq_q_ops - ) + + consumer_node_condition = False + if len(list(node.users)) > 0: + consumer_node = list(node.users)[0] + + # For Rank > 2 Linear layers, the quant node is after the view_copy + if ( + node.target == exir_ops.edge.aten.addmm.default + and consumer_node.target == exir_ops.edge.aten.view_copy.default + ): + consumer_consumer_node = list(consumer_node.users)[0] + return True if consumer_consumer_node.target == q_op else False + consumer_node_condition = consumer_node.target == q_op + + input_node_condition = False + if len(node.all_input_nodes) > 0: + input = node.all_input_nodes[0] + input_node_condition = input.target in dq_q_ops + + return node.target in dq_q_ops or consumer_node_condition or input_node_condition def get_quant_node_dtype(node: torch.fx.Node):