Allow partitioning quantized linear for FP32-only partition
Differential Revision: D67011716

Pull Request resolved: pytorch#7284
digantdesai authored Dec 18, 2024
1 parent 18142f7 commit 884d16d
Showing 2 changed files with 178 additions and 6 deletions.
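For context, a minimal usage sketch of what this change enables (not part of the commit; the PT2E quantization calls and the Linear dimensions below are illustrative assumptions, and exact import paths may differ across ExecuTorch versions): a statically quantized nn.Linear can now be lowered with only the FP32 XNNPACK partitioner enabled, so the linear itself is delegated as fp32 while its quantize/dequantize nodes stay in the graph.

# Illustrative sketch only; assumes the PT2E quantization flow and
# to_edge_transform_and_lower as used elsewhere in ExecuTorch tests.
import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
    XnnpackFloatingPointPartitioner,
)
from executorch.exir import to_edge_transform_and_lower
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

model = torch.nn.Linear(37, 17).eval()
example_inputs = (torch.randn(3, 37),)

# Statically quantize the model with the PT2E flow.
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
prepared = prepare_pt2e(
    torch.export.export_for_training(model, example_inputs).module(), quantizer
)
prepared(*example_inputs)  # calibrate
quantized = convert_pt2e(prepared)

# Lower with only the FP32 partitioner enabled: after this change the (now fp32)
# linear is still delegated to XNNPACK, while the q/dq ops remain in the graph.
edge = to_edge_transform_and_lower(
    torch.export.export(quantized, example_inputs),
    partitioner=[XnnpackFloatingPointPartitioner()],
)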
29 changes: 26 additions & 3 deletions backends/xnnpack/partition/config/gemm_configs.py
@@ -93,6 +93,22 @@ def _detect_precision(self, node: torch.fx.Node) -> ConfigPrecisionType:

        return ConfigPrecisionType.STATIC_QUANT

    def _overwrite_precision(self, node: torch.fx.Node):
        precision = self._detect_precision(node)
        if precision not in self.enabled_precision_types:
            # the detected precision is not enabled; let's try to partition it as fp32
            if self.enabled_precision_types == [ConfigPrecisionType.FP32]:
                # if only fp32 is enabled, then we can still partition fp32 gemms,
                # even within a quantized graph
                if precision in [
                    ConfigPrecisionType.STATIC_QUANT,
                    ConfigPrecisionType.DYNAMIC_QUANT,
                ]:
                    precision = ConfigPrecisionType.FP32
                    logging.info(f"Overwriting precision, partitioning {node} as FP32")
                    return True, precision
        return False, precision

    def get_deps(
        self,
        node: torch.fx.Node,
@@ -107,7 +123,7 @@ def get_deps(
        if precision not in self.supported_precision_types():
            # detected precision but it is either disabled or not supported
            return (False, [])

        _, precision = self._overwrite_precision(node)
        valid_bias, bias_deps = self._get_bias_deps(node, ep, precision)
        valid_weight, weight_deps = self._get_weight_deps(node, ep, precision)
        valid_act, act_deps = self._get_act_deps(node, ep, precision)
@@ -193,7 +209,7 @@ def _get_bias_deps(
        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
    ) -> Tuple[bool, List[torch.fx.Node]]:
        gemm_deps = []
        if len(node.all_input_nodes) > 2 and self.bias_idx:
        if len(node.all_input_nodes) > 2 and self.bias_idx is not None:
            bias_node = get_input_node(node, self.bias_idx)
            if bias_node:
                if not is_param_node(ep, bias_node):
@@ -266,7 +282,14 @@ def _get_weight_deps(
        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
    ) -> Tuple[bool, List[torch.fx.Node]]:
        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
            # if force fp32_dynamic_linear is on and we detected this as fp32, then we
            # if force fp32_dynamic_linear is enabled, then we
            # do not partition the weight node
            return (True, [])

        # Since we are in Linear, we may assume that the weights are indeed static.
        overwritten_linear_precision, new_precision = self._overwrite_precision(node)
        if new_precision == ConfigPrecisionType.FP32 and overwritten_linear_precision:
            # if overwriting quantized precision to fp32, then we
            # do not partition the weight node
            return (True, [])

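The net effect for an FP32-only partitioner (e.g. XnnpackFloatingPointPartitioner, which restricts XnnpackPartitioner to ConfigPrecisionType.FP32): a GEMM detected as statically or dynamically quantized is re-classified as FP32 instead of being rejected. A rough standalone restatement of that decision, for illustration only (the overwrite_precision free function below is hypothetical; the real logic is the _overwrite_precision method above):

from typing import List, Tuple

from executorch.backends.xnnpack.partition.config.xnnpack_config import (
    ConfigPrecisionType,
)


# Hypothetical free-function restatement of the decision made by
# _overwrite_precision above, written out for illustration.
def overwrite_precision(
    detected: ConfigPrecisionType,
    enabled: List[ConfigPrecisionType],
) -> Tuple[bool, ConfigPrecisionType]:
    if (
        detected not in enabled
        and enabled == [ConfigPrecisionType.FP32]
        and detected
        in (ConfigPrecisionType.STATIC_QUANT, ConfigPrecisionType.DYNAMIC_QUANT)
    ):
        # A quantized GEMM seen by an FP32-only partition: treat it as fp32 so the
        # float op can still be delegated; its q/dq nodes stay in the main graph.
        return True, ConfigPrecisionType.FP32
    # Otherwise keep the detected precision and do not overwrite.
    return False, detected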
155 changes: 152 additions & 3 deletions backends/xnnpack/test/ops/test_linear.py
@@ -8,14 +8,16 @@
import unittest

from itertools import product
from typing import Optional, Tuple
from typing import Callable, Dict, List, Optional, Tuple

import torch
from executorch.backends.xnnpack.partition.config.xnnpack_config import (
    ConfigPrecisionType,
)
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
    XnnpackFloatingPointPartitioner,
    XnnpackPartitioner,
)
from executorch.backends.xnnpack.test.tester import Quantize, Tester
from executorch.backends.xnnpack.test.tester.tester import (
    Partition,
@@ -672,3 +674,150 @@ def _test_groupwise_dq_linear(
            .serialize()
            .run_method_and_compare_outputs(atol=atol, rtol=rtol)
        )

    def _test_linear_overwrite_precision(
        self,
        make_module: Callable[[int, int], torch.nn.Module],
        uses_bias: bool,
        quant_type: str,
        quant_node_checks: List[Dict[str, int]],
        atol: float = 1e-03,
    ):
        """
        Test precision overwriting for the linear op: partition, lower, and run
        a quantized linear model as an fp32 linear op.
        In legacy_mode we instead verify that [add]mm is not partitioned, because:
        (1) we can't assume that the weights are always static (non-param), and
        (2) when lowering [add]mm to xnn::bmm we can't support bias;
            (2)(a) lowering only the bias-free [add]mm, which is only exposed on the
            legacy path, was deemed low ROI.
        """

        in_sizes = [3, 4, 4]
        input_sizes = [4, 37, 17]
        output_sizes = [4, 17, 37]

        assert quant_type in ["per_tensor", "per_channel", "per_channel_dynamic"]
        per_channel = "per_channel" in quant_type
        dynamic = "dynamic" in quant_type
        quant_config = get_symmetric_quantization_config(
            is_per_channel=per_channel,
            is_dynamic=dynamic,
        )
        # Using FP32 partitioner for this quantized graph
        partitioner = XnnpackFloatingPointPartitioner()

        def get_qnode_checks(quant_node_checks, dialect):
            d = {}
            assert dialect in ["aten", "edge"]
            if dialect == "aten":
                d = {
                    f"torch.ops.quantized_decomposed.{op}": count
                    for op, count in quant_node_checks.items()
                }
            elif dialect == "edge":
                d = {
                    f"executorch.exir.dialects.edge._ops.quantized_decomposed.{op}".replace(
                        ".", "_"
                    ): count
                    for op, count in quant_node_checks.items()
                }
            assert len(d) == len(quant_node_checks)
            return d

        for i, _ in enumerate(in_sizes):
            torch._dynamo.reset()
            in_size = int(in_sizes[i])
            input_size = int(input_sizes[i])
            output_size = int(output_sizes[i])
            input_shape = [in_size] + [input_size]
            module = make_module(input_size, output_size).eval()
            inputs = (torch.randn(input_shape),)

            addmm_op_str = (
                "executorch_exir_dialects_edge__ops_aten_addmm_default"
                if uses_bias
                else "executorch_exir_dialects_edge__ops_aten_mm_default"
            )
            linear_op_str = "executorch_exir_dialects_edge__ops_aten_linear_default"

            for legacy_mode in (True, False):
                tester = (
                    Tester(module, inputs)
                    .quantize(Quantize(quantization_config=quant_config))
                    .export()
                    .dump_artifact()
                    .check_count(get_qnode_checks(quant_node_checks, "aten"))
                )

                if legacy_mode:
                    tester.to_edge()
                    tester.partition(Partition(partitioner=partitioner))
                    # We don't expect [add]mm to be partitioned
                    tester.check([addmm_op_str])
                else:
                    tester.to_edge_transform_and_lower(
                        ToEdgeTransformAndLower(partitioners=[partitioner])
                    )
                    # We do expect linear to be partitioned
                    tester.check_not([linear_op_str])

                # In legacy mode, the fp32 permute_copy gets partitioned (just a side effect);
                # in the new mode, the fp32 linear gets partitioned.
                tester.check_count(
                    {"torch.ops.higher_order.executorch_call_delegate": 1}
                )

                # Typically we would not see any quantized ops in the graph,
                # but here we should not partition them.
                tester.check_count(get_qnode_checks(quant_node_checks, "edge"))

                # TODO: Need to figure out how to load quantized ops in pybindings.
                # tester.to_executorch()
                # tester.serialize()
                # tester.run_method_and_compare_outputs(
                #     qtol=bool(quant_config), atol=atol
                # )

    def test_qs8_as_fp32(self):
        for use_bias in (True, False):
            self._test_linear_overwrite_precision(
                lambda in_size, out_size: torch.nn.Linear(
                    in_size, out_size, bias=use_bias  # noqa
                ),
                use_bias,
                "per_tensor",
                quant_node_checks={
                    "quantize_per_tensor.default": 2,  # 1: act, 1: output
                    "dequantize_per_tensor.default": 3,  # 1: act, 1: weight, 1: output
                },
            )

    def test_qc8_as_fp32(self):
        for use_bias in (True, False):
            self._test_linear_overwrite_precision(
                lambda in_size, out_size: torch.nn.Linear(
                    in_size, out_size, bias=use_bias  # noqa
                ),
                use_bias,
                "per_channel",
                quant_node_checks={
                    "quantize_per_tensor.default": 2,  # 1: act, 1: output
                    "dequantize_per_tensor.default": 2,  # 1: act, 1: output
                    "dequantize_per_channel.default": 1,  # 1: weight
                },
            )

    def test_qd8_as_fp32(self):
        for use_bias in (True, False):
            self._test_linear_overwrite_precision(
                lambda in_size, out_size: torch.nn.Linear(
                    in_size, out_size, bias=use_bias  # noqa
                ),
                use_bias,
                "per_channel_dynamic",
                quant_node_checks={
                    "quantize_per_tensor.tensor": 1,  # 1: act
                    "dequantize_per_tensor.tensor": 1,  # 1: act
                    "dequantize_per_channel.default": 1,  # 1: weight
                },
            )
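
As a side note on the helper above, get_qnode_checks only prefixes the given op names with the dialect namespace; a small worked example of the mapping it produces (values reconstructed from the helper, not from a test run):

# Worked example of the name mangling done by get_qnode_checks above.
op = "quantize_per_tensor.default"

aten_key = f"torch.ops.quantized_decomposed.{op}"
# -> "torch.ops.quantized_decomposed.quantize_per_tensor.default"

edge_key = f"executorch.exir.dialects.edge._ops.quantized_decomposed.{op}".replace(
    ".", "_"
)
# -> "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default"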
