Skip to content

Commit

Permalink
Add full op for Arm backend (pytorch#4073)
Browse files Browse the repository at this point in the history
Summary:
Implements the full op, which creates a tensor
of a given shape filled with a given value.
The shape and value are set at compile time,
i.e. can't be set by a tensor input.

Refactors tosa_quant_utils.is_quant_node
to handle nodes with no inputs (or outputs).

Does not add a quantizer annotator for full; the op
needs to be quantized by a SharedQuantizationSpec.

Change-Id: I1cebd1da1af5b9aa726a363431ffc30d8259a0ff

Pull Request resolved: pytorch#4073

Reviewed By: mergennachin

Differential Revision: D59259731

Pulled By: digantdesai

fbshipit-source-id: 621fec994bc2ebc4ad7abd51d9dbf1a5a4deed43
  • Loading branch information
Erik-Lundell authored and facebook-github-bot committed Jul 23, 2024
1 parent 908b5a5 commit 6556991
Show file tree
Hide file tree
Showing 6 changed files with 247 additions and 19 deletions.
1 change: 1 addition & 0 deletions backends/arm/arm_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
exir_ops.edge.aten.hardtanh.default,
exir_ops.edge.aten.convolution.default,
exir_ops.edge.aten.div.Tensor,
exir_ops.edge.aten.full.default,
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
exir_ops.edge.aten.avg_pool2d.default,
exir_ops.edge.aten._softmax.default,
Expand Down
1 change: 1 addition & 0 deletions backends/arm/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
op_conv2d,
op_dequant,
op_div,
op_full,
op_get_item,
op_hardtanh,
op_mean_dim,
Expand Down
54 changes: 54 additions & 0 deletions backends/arm/operators/op_full.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import List

import numpy as np

import serializer.tosa_serializer as ts
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_quant_utils import get_quant_node_args
from executorch.backends.arm.tosa_utils import tosa_shape
from torch.fx import Node


@register_node_visitor
class FullVisitor(NodeVisitor):
    """Node visitor that lowers ``aten.full.default`` to TOSA.

    The filled tensor is materialised as a constant when the graph is
    serialized and forwarded to the node's output through an IDENTITY
    operator. Both the shape and the fill value are read from the node's
    arguments, so they are fixed at compile time.
    """

    target = "aten.full.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Serialize ``node`` as a constant tensor followed by an IDENTITY op.

        Args:
            node: The ``aten.full.default`` node being lowered.
            tosa_graph: Serializer that receives the constant and operator.
            inputs: ``inputs[0].special`` is used as the requested shape and
                ``inputs[1].number`` as the fill value.
            output: Output tensor description; provides the dim order, the
                dtype (checked in the unquantized case) and the output name.
            is_quant_node: When true, the fill value is quantized to INT8
                using the quantization arguments of the node's first user.

        Raises:
            AssertionError: If the node is not quantized and the output
                dtype is not FP32.
        """
        shape = tosa_shape(inputs[0].special, output.dim_order)

        value = inputs[1].number
        if is_quant_node:
            # The full node has no quantize op of its own; the quantization
            # arguments are taken from its first user.
            qargs = get_quant_node_args(list(node.users)[0])
            qvalue = np.clip(
                np.round(value / qargs.scale) + qargs.zp, qargs.qmin, qargs.qmax
            )
            dtype = ts.DType.INT8
            data = np.full(shape, qvalue, dtype=np.int8)
        else:
            assert (
                output.dtype == ts.DType.FP32
            ), "'Full' currently only supports FP32 for unquantized models."
            dtype = ts.DType.FP32
            data = np.full(shape, value, dtype=np.float32)

        # Derive the constant's name from the node so that several full ops
        # in the same graph do not all register their data under one name.
        const_name = f"{node.name}-full-const"
        tosa_graph.addConst(shape, dtype, data, const_name)
        tosa_graph.addOperator(ts.TosaOp.Op.IDENTITY, [const_name], [output.name])
160 changes: 160 additions & 0 deletions backends/arm/test/ops/test_full.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

#
# Tests the full op which creates a tensor of a given shape filled with a given value.
# The shape and value are set at compile time, i.e. can't be set by a tensor input.
#

import unittest
from typing import Tuple

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from parameterized import parameterized


class TestFull(unittest.TestCase):
    """Tests lowering of ``torch.full`` through the Arm TOSA and U55 pipelines."""

    class Full(torch.nn.Module):
        # A single full op
        def forward(self):
            return torch.full((3, 3), 4.5)

    class AddConstFull(torch.nn.Module):
        # Input + a full with constant value.
        def forward(self, x: torch.Tensor):
            return torch.full((2, 2, 3, 3), 4.5, dtype=torch.float32) + x

    class AddVariableFull(torch.nn.Module):
        # Input shapes of rank 1 to 4. Every entry is a tuple; note the
        # trailing comma in the rank-1 case.
        sizes = [
            (5,),
            (5, 5),
            (5, 5, 5),
            (1, 5, 5, 5),
        ]
        # One (tensor, fill_value) pair per shape, each wrapped in a 1-tuple
        # so that parameterized.expand passes the pair as a single argument.
        test_parameters = [((torch.randn(n) * 10 - 5, 3.2),) for n in sizes]

        def forward(self, x: torch.Tensor, y):
            # Input + a full with the shape from the input and a given value 'y'.
            return x + torch.full(x.shape, y)

    def _test_full_tosa_MI_pipeline(
        self,
        module: torch.nn.Module,
        example_data: Tuple,
        test_data: Tuple | None = None,
    ):
        """Run ``module`` through the unquantized TOSA pipeline and compare outputs.

        Args:
            module: Module under test; expected to contain exactly one full op.
            example_data: Inputs used to export the module.
            test_data: Inputs used for the output comparison. Defaults to
                ``example_data`` when omitted.
        """
        if test_data is None:
            test_data = example_data
        (
            ArmTester(
                module,
                example_inputs=example_data,
                compile_spec=common.get_tosa_compile_spec(),
            )
            .export()
            .check_count({"torch.ops.aten.full.default": 1})
            .to_edge()
            .partition()
            # After partitioning, the full op must no longer appear by name
            # and exactly one delegate call must be present.
            .check_not(["executorch_exir_dialects_edge__ops_aten_full_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_full_tosa_BI_pipeline(
        self,
        module: torch.nn.Module,
        test_data: Tuple,
        permute_memory_to_nhwc: bool,
    ):
        """Run ``module`` through the quantized TOSA pipeline and compare outputs.

        Args:
            module: Module under test; expected to contain exactly one full op.
            test_data: Inputs used both for export and for the comparison.
            permute_memory_to_nhwc: Forwarded to the TOSA compile spec.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec(
                    permute_memory_to_nhwc=permute_memory_to_nhwc
                ),
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.full.default": 1})
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_full_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_full_tosa_u55_pipeline(self, module: torch.nn.Module, test_data: Tuple):
        """Lower ``module`` with the U55 compile spec.

        Unlike the TOSA pipelines, this one stops after ``to_executorch()``
        and does not run the method or compare outputs.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_u55_compile_spec(),
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.full.default": 1})
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_full_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    # A module consisting only of a full op, i.e. a graph without inputs.
    def test_only_full_tosa_MI(self):
        self._test_full_tosa_MI_pipeline(self.Full(), ())

    def test_const_full_tosa_MI(self):
        _input = torch.rand((2, 2, 3, 3)) * 10
        self._test_full_tosa_MI_pipeline(self.AddConstFull(), (_input,))

    def test_const_full_nhwc_tosa_BI(self):
        _input = torch.rand((2, 2, 3, 3)) * 10
        self._test_full_tosa_BI_pipeline(self.AddConstFull(), (_input,), True)

    @parameterized.expand(AddVariableFull.test_parameters)
    def test_full_tosa_MI(self, test_tensor: Tuple):
        self._test_full_tosa_MI_pipeline(
            self.AddVariableFull(), example_data=test_tensor
        )

    @parameterized.expand(AddVariableFull.test_parameters)
    def test_full_tosa_BI(self, test_tensor: Tuple):
        self._test_full_tosa_BI_pipeline(self.AddVariableFull(), test_tensor, False)

    @parameterized.expand(AddVariableFull.test_parameters)
    def test_full_u55_BI(self, test_tensor: Tuple):
        self._test_full_tosa_u55_pipeline(
            self.AddVariableFull(),
            test_tensor,
        )

    # This fails since full outputs int64 by default if 'fill_value' is integer, which our backend doesn't support.
    @unittest.expectedFailure
    def test_integer_value(self):
        _input = torch.ones((2, 2))
        integer_fill_value = 1
        self._test_full_tosa_MI_pipeline(
            self.AddVariableFull(), example_data=(_input, integer_fill_value)
        )

    # This fails since the fill value in the full tensor is set at compile time by the example data (1.).
    # Test data tries to set it again at runtime (to 2.) but it doesn't do anything.
    # In eager mode, the fill value can be set at runtime, causing the outputs to not match.
    @unittest.expectedFailure
    def test_set_value_at_runtime(self):
        _input = torch.ones((2, 2))
        example_fill_value = 1.0
        test_fill_value = 2.0
        self._test_full_tosa_MI_pipeline(
            self.AddVariableFull(),
            example_data=(_input, example_fill_value),
            test_data=(_input, test_fill_value),
        )
14 changes: 11 additions & 3 deletions backends/arm/test/tester/arm_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,15 @@ def run_method_and_compare_outputs(
else:
test_input = reference_input

# Test parameters can include constants that are used in eager mode but are already set as attributes
# in TOSA. Therefore, only accept torch.Tensor inputs.
test_input = [
tensor for tensor in test_input if isinstance(tensor, torch.Tensor)
]

input_shapes = [
generated_input.shape for generated_input in reference_input
generated_input.shape if hasattr(generated_input, "shape") else (1,)
for generated_input in reference_input
]
print(f"Run {run_iteration} with input shapes: {input_shapes}")

Expand All @@ -274,7 +281,7 @@ def transpose_data_format(
dim_order = (0, 2, 3, 1)
inputs_transposed = list(data)
for i in range(len(data)):
if len(data[i].shape) == 4:
if hasattr(data[i], "shape") and len(data[i].shape) == 4:
inputs_transposed[i] = np.transpose(data[i], dim_order)
return tuple(inputs_transposed)

Expand All @@ -298,7 +305,8 @@ def _compare_outputs(
path_to_tosa_files = self.runner_util.intermediate_path

export_stage = self.stages.get(self.stage_name(tester.Export), None)
if export_stage is not None:
quantize_stage = self.stages.get(self.stage_name(tester.Quantize), None)
if export_stage is not None and quantize_stage is not None:
input_names = _get_input_names(export_stage.artifact)
output_node = _get_output_node(export_stage.artifact)
qp_input = _get_input_quantization_params(
Expand Down
36 changes: 20 additions & 16 deletions backends/arm/tosa_quant_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,26 @@ class QuantArgs(NamedTuple):


def is_quant_node(node: torch.fx.Node) -> bool:
    """Return whether ``node`` is part of a quantized section of the graph.

    A node is treated as quantized when any of the following holds:

    * its own target is one of ``dq_q_ops``;
    * its first user's target is ``q_op``;
    * its first input's target is one of ``dq_q_ops``.

    Nodes without users and/or without inputs are supported; the
    corresponding condition is simply false for them.

    Args:
        node: The FX node to inspect.

    Returns:
        True if the node is considered quantized, otherwise False.
    """
    users = list(node.users)
    if users:
        consumer_node = users[0]

        # For Rank > 2 Linear layers, the quant node is after the view_copy,
        # so the decision is made on the view_copy's first user instead.
        if (
            node.target == exir_ops.edge.aten.addmm.default
            and consumer_node.target == exir_ops.edge.aten.view_copy.default
        ):
            view_users = list(consumer_node.users)
            return bool(view_users) and view_users[0].target == q_op

        if consumer_node.target == q_op:
            return True

    if node.target in dq_q_ops:
        return True

    input_nodes = node.all_input_nodes
    return bool(input_nodes) and input_nodes[0].target in dq_q_ops


def get_quant_node_dtype(node: torch.fx.Node):
Expand Down

0 comments on commit 6556991

Please sign in to comment.