Convert more NodeVisitors to folding DQ/Q pass usage
Signed-off-by: Per Åstrand <[email protected]>
Change-Id: I9201d8bafd543204b697c7276d6929ad3aa09f25
per authored and freddan80 committed Dec 16, 2024
1 parent eae61f7 commit e24d503
Showing 8 changed files with 143 additions and 90 deletions.
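
The common thread across these conversions is that a NodeVisitor no longer walks to the neighbouring quantize/dequantize nodes (get_quant_arg_upstream / get_quant_arg_downstream) but instead reads the quantization parameters that the fold_qdq_with_annotated_qparams_pass has annotated onto the node itself, and is registered per TOSA specification. Below is a minimal sketch of the converted pattern, assuming get_input_qparams/get_output_qparams return per-index qparams exposing .scale and .zp as they are used in this diff; the visitor class name and its target are hypothetical and only illustrate the shape of the code.

from typing import List

import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
    get_input_qparams,
    get_output_qparams,
)
from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification


@register_node_visitor
class ExampleVisitor_0_80_BI(NodeVisitor):  # hypothetical name, for illustration only
    target = "aten.example.default"  # hypothetical target

    # The visitor is registered per TOSA specification (quantized BI profile here).
    tosa_specs = [
        TosaSpecification.create_from_string("TOSA-0.80.0+BI"),
    ]

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        # Quantization parameters come from the annotations folded onto the node,
        # not from the neighbouring Q/DQ nodes.
        input_zp = get_input_qparams(node)[0].zp
        output_zp = get_output_qparams(node)[0].zp
        ...  # emit the TOSA operator using input_zp / output_zp
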
2 changes: 2 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
@@ -90,6 +90,8 @@ def transform_to_backend_pipeline(
exir_ops.edge.aten.minimum.default,
exir_ops.edge.aten.maximum.default,
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.avg_pool2d.default,
exir_ops.edge.aten.convolution.default,
]
)
)
100 changes: 87 additions & 13 deletions backends/arm/operators/op_avg_pool2d.py
@@ -8,44 +8,118 @@

import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
get_input_qparams,
get_output_qparams,
)
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_utils import build_avg_pool_2d_common
from executorch.backends.arm.tosa_specification import TosaSpecification


@register_node_visitor
class AvgPool2dVisitor(NodeVisitor):
class AvgPool2dVisitor_0_80_BI(NodeVisitor):
target = "aten.avg_pool2d.default"

tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80.0+BI"),
]

def __init__(self, *args):
super().__init__(*args)

def define_node(
def _build_generic_avgpool2d(
self,
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
inputs: List[TosaArg],
output: TosaArg,
is_quant_node: bool,
input_zp: int,
output_zp: int,
accumulator_type,
) -> None:
input_tensor = inputs[0]

kernel_size_list = inputs[1].special
stride_size_list = inputs[2].special
try:
pad_size_list = inputs[3].special
except IndexError:
pad_size_list = [0, 0, 0, 0]

build_avg_pool_2d_common(
node,
tosa_graph,
input_tensor,
kernel_size_list,
stride_size_list,
pad_size_list,
is_quant_node,
output,
attr = ts.TosaSerializerAttribute()
attr.PoolAttribute(
kernel=kernel_size_list,
stride=stride_size_list,
pad=pad_size_list,
input_zp=input_zp,
output_zp=output_zp,
accum_dtype=accumulator_type,
)

tosa_graph.addOperator(
ts.TosaOp.Op().AVG_POOL2D,
[input_tensor.name],
[output.name],
attr,
)

def define_node(
self,
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
inputs: List[TosaArg],
output: TosaArg,
is_quant_node: bool,
) -> None:
input_tensor = inputs[0]
assert input_tensor.dtype == ts.DType.INT8

accumulator_type = ts.DType.INT32

input_qargs = get_input_qparams(node)
input_zp = input_qargs[0].zp

output_qargs = get_output_qparams(node)
output_zp = output_qargs[0].zp

self._build_generic_avgpool2d(
node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type
)


@register_node_visitor
class AvgPool2dVisitor_0_80_MI(AvgPool2dVisitor_0_80_BI):
# inheriting 'target' from BI class

tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
]

def define_node(
self,
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
inputs: List[TosaArg],
output: TosaArg,
is_quant_node: bool,
) -> None:
assert (
inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.FP32
), "Only FP32 and INT8 supported"

if inputs[0].dtype == ts.DType.INT8:
super().define_node(node, tosa_graph, inputs, output, is_quant_node)

if inputs[0].dtype == ts.DType.FP32:
accumulator_type = ts.DType.FP32
# Initialize zero point to zero.
input_zp = 0
output_zp = 0

self._build_generic_avgpool2d(
node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type
)
5 changes: 5 additions & 0 deletions backends/arm/operators/op_batch_norm.py
@@ -13,6 +13,7 @@
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.backends.arm.tosa_utils import promote_shape, tosa_shape
from serializer.tosa_serializer import TosaOp

@@ -21,6 +22,10 @@
class BatchNormVisitor(NodeVisitor):
target = "aten._native_batch_norm_legit_no_training.default"

tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
]

def __init__(self, *args):
super().__init__(*args)

48 changes: 26 additions & 22 deletions backends/arm/operators/op_conv2d.py
@@ -8,16 +8,16 @@

import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
get_input_qparams,
get_output_qparams,
)
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_quant_utils import (
build_rescale_conv_output,
get_quant_arg_downstream,
get_quant_arg_upstream,
)
from executorch.backends.arm.tosa_quant_utils import build_rescale_conv_output
from executorch.backends.arm.tosa_utils import build_reshape, tosa_shape

from serializer.tosa_serializer import TosaOp
@@ -57,9 +57,6 @@ def define_node(
) -> None:
input, weight, bias, stride, pad, dilation, _, _, group = inputs

# Currently only int8 is supported in quantized types.
actual_out_type = ts.DType.INT8 if is_quant_node else output.dtype

# Get the attributes of convolution.
attr = ts.TosaSerializerAttribute()
pad_attr = [val for val in pad.special for _ in (0, 1)]
@@ -82,9 +79,11 @@
dilation_attr[1],
)

input_zp = (
get_quant_arg_upstream(node.all_input_nodes[0]).zp if is_quant_node else 0
)
input_zp = 0
if inputs[0].dtype == ts.DType.INT8:
# int8 input requires quantization information
input_qparams = get_input_qparams(node)
input_zp = input_qparams[0].zp

attr.ConvAttribute(
pad=pad_attr,
@@ -100,16 +99,22 @@
# Create a zero bias tensor if not present
out_channels = weight.shape[0]
bias_name = "bias" + node.name.split("default", 1)[1]
bias_type = output.dtype
if output.dtype == ts.DType.INT8:
# Conv is quantized to int8, but the TOSA operator has
# output type int32, and the bias must be the same type
# as the TOSA output type
bias_type = ts.DType.INT32
bias = tosa_graph.addConst(
[out_channels],
ts.DType.INT32 if is_quant_node else output.dtype,
bias_type,
[0] * out_channels,
name=bias_name,
)

# The output type is int32 when input type is int8.
conv2d_output_name = output.name
if is_quant_node:
if output.dtype == ts.DType.INT8:
conv2d_res = tosa_graph.addIntermediate(
tosa_shape(output.shape, output.dim_order), ts.DType.INT32
)
@@ -132,7 +137,7 @@

weight_reshaped = tosa_graph.addIntermediate(
weight_post_shape,
ts.DType.INT8 if is_quant_node else weight.dtype,
weight.dtype,
)
build_reshape(
tosa_graph, weight.name, weight_post_shape, weight_reshaped.name
@@ -157,20 +162,19 @@

# For quantized convolution, rescale the output value back to the same
# integer value domain of the next op. Otherwise return float32 output.
if is_quant_node:
if inputs[0].dtype == ts.DType.INT8:
# Get scale_factor from input, weight, and output.
input_scale = get_quant_arg_upstream(node.all_input_nodes[0]).scale
weight_scale = get_quant_arg_upstream(node.all_input_nodes[1]).scale
output_qargs = get_quant_arg_downstream(list(node.users)[0])

input_scale = input_qparams[0].scale
weight_scale = input_qparams[1].scale
output_qargs = get_output_qparams(node)
build_rescale_conv_output(
tosa_graph,
# pyre-fixme[61]: Uninitialized local [61]: Local variable `conv2d_res` is undefined, or not always defined.
conv2d_res,
output.name,
actual_out_type,
output.dtype,
input_scale,
weight_scale,
output_qargs.scale,
output_qargs.zp,
output_qargs[0].scale,
output_qargs[0].zp,
)
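
For the quantized path above, build_rescale_conv_output maps the INT32 accumulator back into the INT8 output domain using the input, weight, and output scales. Conceptually it performs the usual affine requantization; the sketch below shows that arithmetic only and is not the exact helper implementation (the function and variable names are illustrative).

def requantize_conv_accumulator(
    acc_i32: int,
    input_scale: float,
    weight_scale: float,
    output_scale: float,
    output_zp: int,
) -> int:
    # The real value represented by the accumulator is acc_i32 * input_scale * weight_scale;
    # expressing it in the output quantization q = round(x / output_scale) + zp folds the
    # three scales into a single multiplier applied to the accumulator.
    multiplier = (input_scale * weight_scale) / output_scale
    q_out = round(acc_i32 * multiplier) + output_zp
    # Clamp to the INT8 range of the output tensor.
    return max(-128, min(127, q_out))
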
6 changes: 6 additions & 0 deletions backends/arm/operators/op_div.py
@@ -13,6 +13,7 @@
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.backends.arm.tosa_utils import tosa_shape
from serializer.tosa_serializer import TosaOp

@@ -21,6 +22,11 @@
class DivVisitor(NodeVisitor):
target = "aten.div.Tensor"

# Only supported for MI
tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
]

def __init__(self, *args):
super().__init__(*args)

2 changes: 1 addition & 1 deletion backends/arm/operators/op_max_pool2d.py
@@ -13,7 +13,7 @@
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_utils import (
from executorch.backends.arm.tosa_quant_utils import (
get_quant_arg_downstream,
get_quant_arg_upstream,
)
10 changes: 7 additions & 3 deletions backends/arm/process_node.py
@@ -11,10 +11,12 @@
import serializer.tosa_serializer as ts
import torch
import torch.fx
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
get_input_qparams,
)
from executorch.backends.arm.operators.node_visitor import NodeVisitor
from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
from executorch.backends.arm.tosa_quant_utils import (
get_quant_arg_upstream,
get_quantized_node_output_dtype,
is_node_quantized,
)
@@ -110,8 +112,10 @@ def process_quantized_bias(
_,
) = consumer_node.all_input_nodes

input_node_scale = get_quant_arg_upstream(input_node).scale
weight_node_scale = get_quant_arg_upstream(weight_node).scale
input_qargs = get_input_qparams(consumer_node)

input_node_scale = input_qargs[0].scale
weight_node_scale = input_qargs[1].scale
bias_values_quantized = (
(parameter_values / (input_node_scale * weight_node_scale))
.round()