Skip to content

Commit

Permalink
Add sub operator for Arm backend (pytorch#4074)
Browse files Browse the repository at this point in the history
Summary:
Implemented node visitor, quantizer, and test.
TOSA MI and BI passes, U55 BI fails (expectedly) on Vela compilation.

Refactored code shared between sub and add.

Change-Id: Ifc9fc4ae083f3ed868ad763e4301e5fe87468a25

Pull Request resolved: pytorch#4074

Reviewed By: mergennachin

Differential Revision: D59259308

Pulled By: digantdesai

fbshipit-source-id: ce10e9b1a583e6374e5f1c5815dc11c0d0e7aa5b
  • Loading branch information
Erik-Lundell authored and facebook-github-bot committed Jul 24, 2024
1 parent 47d309a commit 11b2fcb
Show file tree
Hide file tree
Showing 13 changed files with 443 additions and 96 deletions.
1 change: 1 addition & 0 deletions backends/arm/arm_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
exir_ops.edge.aten.avg_pool2d.default,
exir_ops.edge.aten.sigmoid.default,
exir_ops.edge.aten._softmax.default,
exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.view_copy.default,
exir_ops.edge.aten.clone.default,
exir_ops.edge.aten.mean.dim,
Expand Down
1 change: 1 addition & 0 deletions backends/arm/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@
op_quant,
op_sigmoid,
op_softmax,
op_sub,
op_view,
)
79 changes: 23 additions & 56 deletions backends/arm/operators/op_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,17 @@

from typing import List

import executorch.backends.arm.tosa_quant_utils as tqutils
import executorch.backends.arm.tosa_utils as tutils

import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_quant_utils import (
build_rescale_from_int32,
build_rescale_to_int32,
)
from executorch.backends.arm.tosa_utils import broadcast_shapes, getNodeArgs, tosa_shape
from serializer.tosa_serializer import TosaOp
from torch.fx import Node


@register_node_visitor
Expand All @@ -29,75 +27,44 @@ def __init__(self, *args):

def define_node(
self,
node: torch.fx.Node,
node: Node,
tosa_graph: ts.TosaSerializer,
inputs: List[TosaArg],
output: TosaArg,
is_quant_node: bool,
) -> None:
if is_quant_node:
# Single input or not
if len(node.all_input_nodes) == 1:
input_node_A = node.all_input_nodes[0]
input_node_B = node.all_input_nodes[0]
else:
input_node_A, input_node_B = node.all_input_nodes

# Get input scale_factor and zero_points for A, B
input_A, input_A_scale, input_A_zp, _, _, _ = getNodeArgs(input_node_A)
input_B, input_B_scale, input_B_zp, _, _, _ = getNodeArgs(input_node_B)

# Scale the int8 quantized input to a common scale in the integer
# domain.
min_scale = min(input_A_scale.number, input_B_scale.number)
inputA_rescale_scale = input_A_scale.number / min_scale
inputB_rescale_scale = input_B_scale.number / min_scale

input_A.shape = tosa_shape(input_A.shape, input_A.dim_order)
input_B.shape = tosa_shape(input_B.shape, input_B.dim_order)
broadcasted_shape = broadcast_shapes(input_A.shape, input_B.shape)
input_nodes = tutils.get_two_inputs(node)

input_A_rescaled_to_int32 = build_rescale_to_int32(
tosa_graph,
input_A,
input_A_zp.number,
inputA_rescale_scale,
# Rescale inputs to 32 bit
rescaled_inputs, scale = tqutils.rescale_nodes_to_int32(
input_nodes, tosa_graph
)

input_B_rescaled_to_int32 = build_rescale_to_int32(
tosa_graph,
input_B,
input_B_zp.number,
inputB_rescale_scale,
# Preapre sub output tensor
broadcasted_shape = tutils.broadcast_shapes(
rescaled_inputs[0].shape, rescaled_inputs[0].shape
)
add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)

## Do the INT32 Add
add_res = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
# Do the INT32 Add
tosa_graph.addOperator(
TosaOp.Op().ADD,
[
input_A_rescaled_to_int32.name,
input_B_rescaled_to_int32.name,
rescaled_inputs[0].name,
rescaled_inputs[1].name,
],
[add_res.name],
[add_output.name],
None,
)

# Output
output_node = list(node.users)[0]
_, output_scale, output_zp, _, _, _ = getNodeArgs(output_node)
output_rescale_scale = min_scale / output_scale.number

# Rescale Back to INT8
build_rescale_from_int32(
tosa_graph,
add_res.name,
output.name,
output_zp.number,
output_rescale_scale,
)
# Scale output back to 8 bit
tqutils.rescale_node_back_to_int8(node, add_output, scale, tosa_graph)
else:
# FP32 Add lowering
tosa_graph.addOperator(
TosaOp.Op().ADD, [inputs[0].name, inputs[1].name], [output.name], None
TosaOp.Op().ADD,
[inputs[0].name, inputs[1].name],
[output.name],
None,
)
69 changes: 69 additions & 0 deletions backends/arm/operators/op_sub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List

import executorch.backends.arm.tosa_quant_utils as tqutils
import executorch.backends.arm.tosa_utils as tutils

import serializer.tosa_serializer as ts
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from serializer.tosa_serializer import TosaOp
from torch.fx import Node


@register_node_visitor
class SubVisitor(NodeVisitor):
target = "aten.sub.Tensor"

def __init__(self, *args):
super().__init__(*args)

def define_node(
self,
node: Node,
tosa_graph: ts.TosaSerializer,
inputs: List[TosaArg],
output: TosaArg,
is_quant_node: bool,
) -> None:
if is_quant_node:
input_nodes = tutils.get_two_inputs(node)

# Rescale inputs to 32 bit
rescaled_inputs, scale = tqutils.rescale_nodes_to_int32(
input_nodes, tosa_graph
)

# Preapre sub output tensor
broadcasted_shape = tutils.broadcast_shapes(
rescaled_inputs[0].shape, rescaled_inputs[0].shape
)
sub_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)

# Do the INT32 Sub
tosa_graph.addOperator(
TosaOp.Op().SUB,
[
rescaled_inputs[0].name,
rescaled_inputs[1].name,
],
[sub_output.name],
)

# Scale output back to 8 bit
tqutils.rescale_node_back_to_int8(node, sub_output, scale, tosa_graph)
else:
# FP32 Sub lowering
tosa_graph.addOperator(
TosaOp.Op().SUB,
[inputs[0].name, inputs[1].name],
[output.name],
None,
)
2 changes: 2 additions & 0 deletions backends/arm/quantizer/arm_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPattern
[torch.nn.AdaptiveAvgPool2d],
[F.adaptive_avg_pool2d],
],
"sub": [[torch.sub]],
}
return copy.deepcopy(supported_operators)

Expand Down Expand Up @@ -254,6 +255,7 @@ class ArmQuantizer(Quantizer):
"adaptive_avg_pool2d",
"max_pool2d",
"add",
"sub",
"mul",
"sigmoid",
]
Expand Down
50 changes: 49 additions & 1 deletion backends/arm/quantizer/arm_quantizer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import Callable, cast, List

import torch
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from torch._subclasses import FakeTensor
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix

Expand Down Expand Up @@ -54,7 +55,53 @@ def mark_nodes_as_annotated(nodes: List[Node]) -> None:
node.meta["quantization_annotation"]._annotated = True


def is_input_large_scalar(node: Node, gm: GraphModule) -> bool:
def get_shared_qspec(
node: Node, gm: GraphModule, quantization_config: QuantizationConfig
):
"""Returns a Quantization constallation with a SharedQuantizationSpec for the inputs
and output to the parameter 'node'.
Parameters:
node: a node with two inputs that should share Quantization parameters.
gm: The GraphModule containing the node. Used to inspect global graph features.
quantization_config : a QuantizationConfig with the input QuantizationSpec to share
Returns:
input_qspec_map: a dict[node, QuantizationSpec] that maps the inputs to 'node' to
the correct QuantizationSpec.
shared_with_input0_spec: The SharedQuantizationSpec to be used as output QuantizationSpec.
Both outputs are None if one of the inputs is a node that can't be quantized.
"""
input_act0 = node.args[0]
input_act1 = node.args[1]

input_act_qspec = quantization_config.get_input_act_qspec()
shared_with_input0_qspec = SharedQuantizationSpec((input_act0, node))

input_qspec_map = {}
if isinstance(input_act0, Node):
if not is_input_ok_for_quantization(input_act0, gm):
return None, None
input_qspec_map[input_act0] = input_act_qspec

if isinstance(input_act1, Node):
if not is_input_ok_for_quantization(input_act1, gm):
return None, None
if input_act0 is not input_act1:
input_qspec_map[input_act1] = shared_with_input0_qspec
return input_qspec_map, shared_with_input0_qspec


def is_input_ok_for_quantization(input_act: Node, gm: GraphModule):
"""Check if an input can be quantized. The input can not be quantized if:
- The node does not output a float tensor or,
- The node outputs a large scalar.
"""
return not (
is_input_non_float_tensor(input_act) or is_input_large_scalar(input_act, gm)
)


def is_input_large_scalar(node: Node, gm: GraphModule):
"""Check if input is a large scalar value. So that we can skip quantization for the node
since histc op (in HistogramObserver) only works for values up to certain upper bound
"""
Expand Down Expand Up @@ -142,6 +189,7 @@ def convert_scalars_to_attrs(model: GraphModule) -> GraphModule:
"""
targeted_ops = [
torch.ops.aten.add.Tensor,
torch.ops.aten.sub.Tensor,
torch.ops.aten.mul.Tensor,
]
for n in model.graph.nodes:
Expand Down
1 change: 1 addition & 0 deletions backends/arm/quantizer/quantization_annotation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,5 @@ def decorator(annotator: AnnotatorType):
max_pool2d_annotator,
mul_annotator,
sigmoid_annotator,
sub_annotator,
)
40 changes: 9 additions & 31 deletions backends/arm/quantizer/quantization_annotation/add_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,7 @@
from executorch.backends.arm.quantizer import arm_quantizer_utils
from executorch.backends.arm.quantizer.quantization_annotation import register_annotator
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from torch.ao.quantization.quantizer import (
QuantizationAnnotation,
SharedQuantizationSpec,
)
from torch.ao.quantization.quantizer import QuantizationAnnotation
from torch.fx import Node
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions

Expand All @@ -37,32 +34,13 @@ def _annotate_add(
if arm_quantizer_utils.is_annotated(add_node):
continue

input_act0 = add_node.args[0]
input_act_qspec = quantization_config.get_input_act_qspec()
shared_with_input0_qspec = SharedQuantizationSpec((input_act0, add_node))

input_qspec_map = {}
if isinstance(input_act0, Node):
if arm_quantizer_utils.is_input_large_scalar(input_act0, gm):
continue
if arm_quantizer_utils.is_input_non_float_tensor(input_act0):
continue
input_qspec_map[input_act0] = input_act_qspec

input_act1 = add_node.args[1]
if isinstance(input_act1, Node):
if arm_quantizer_utils.is_input_large_scalar(input_act1, gm):
continue
if arm_quantizer_utils.is_input_non_float_tensor(input_act1):
continue
if input_act0 is not input_act1:
input_qspec_map[input_act1] = shared_with_input0_qspec
else:
input_qspec_map[input_act1] = input_act_qspec

add_node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=shared_with_input0_qspec,
_annotated=True,
input_qspec_map, output_qspec = arm_quantizer_utils.get_shared_qspec(
add_node, gm, quantization_config
)
if input_qspec_map is not None:
add_node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=output_qspec,
_annotated=True,
)
return annotated_partitions
46 changes: 46 additions & 0 deletions backends/arm/quantizer/quantization_annotation/sub_annotator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import operator
from typing import Callable, List, Optional

import torch
from executorch.backends.arm.quantizer import arm_quantizer_utils
from executorch.backends.arm.quantizer.quantization_annotation import register_annotator
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from torch.ao.quantization.quantizer import QuantizationAnnotation
from torch.fx import GraphModule, Node
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


@register_annotator("sub")
def _annotate_sub(
gm: GraphModule,
quantization_config: QuantizationConfig,
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
sub_partitions = get_source_partitions(
gm.graph, [operator.sub, torch.sub, operator.isub], filter_fn
)
sub_partitions = list(itertools.chain.from_iterable(sub_partitions.values()))
annotated_partitions = []
for sub_partition in sub_partitions:
annotated_partitions.append(sub_partition.nodes)
sub_node = sub_partition.output_nodes[0]
if arm_quantizer_utils.is_annotated(sub_node):
continue

input_qspec_map, output_qspec = arm_quantizer_utils.get_shared_qspec(
sub_node, gm, quantization_config
)
if input_qspec_map is not None:
sub_node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=output_qspec,
_annotated=True,
)
return annotated_partitions
Loading

0 comments on commit 11b2fcb

Please sign in to comment.