diff --git a/backends/xnnpack/partition/config/gemm_configs.py b/backends/xnnpack/partition/config/gemm_configs.py
index 07876a4a25..e19a102ee4 100644
--- a/backends/xnnpack/partition/config/gemm_configs.py
+++ b/backends/xnnpack/partition/config/gemm_configs.py
@@ -93,6 +93,22 @@ def _detect_precision(self, node: torch.fx.Node) -> ConfigPrecisionType:
 
         return ConfigPrecisionType.STATIC_QUANT
 
+    def _overwrite_precision(self, node: torch.fx.Node):
+        precision = self._detect_precision(node)
+        if precision not in self.enabled_precision_types:
+            # the detected precision is not enabled, so let's try to partition it as fp32
+            if self.enabled_precision_types == [ConfigPrecisionType.FP32]:
+                # if only fp32 is enabled, we can still partition fp32 gemms
+                # even within a quantized graph
+                if precision in [
+                    ConfigPrecisionType.STATIC_QUANT,
+                    ConfigPrecisionType.DYNAMIC_QUANT,
+                ]:
+                    precision = ConfigPrecisionType.FP32
+                    logging.info(f"Overwriting precision, partitioning {node} as FP32")
+                    return True, precision
+        return False, precision
+
     def get_deps(
         self,
         node: torch.fx.Node,
@@ -107,7 +123,7 @@ def get_deps(
         if precision not in self.supported_precision_types():
             # detected precision but it is either disabled or not supported
             return (False, [])
-
+        _, precision = self._overwrite_precision(node)
         valid_bias, bias_deps = self._get_bias_deps(node, ep, precision)
         valid_weight, weight_deps = self._get_weight_deps(node, ep, precision)
         valid_act, act_deps = self._get_act_deps(node, ep, precision)
@@ -193,7 +209,7 @@ def _get_bias_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
         gemm_deps = []
-        if len(node.all_input_nodes) > 2 and self.bias_idx:
+        if len(node.all_input_nodes) > 2 and self.bias_idx is not None:
             bias_node = get_input_node(node, self.bias_idx)
             if bias_node:
                 if not is_param_node(ep, bias_node):
@@ -266,7 +282,14 @@ def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
    ) -> Tuple[bool, List[torch.fx.Node]]:
         if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is on and we detected this as fp32, then we
+            # if force_fp32_dynamic_linear is enabled, then we
+            # do not partition the weight node
+            return (True, [])
+
+        # Since we are in Linear, we may assume that the weights are static.
+        overwritten_linear_precision, new_precision = self._overwrite_precision(node)
+        if new_precision == ConfigPrecisionType.FP32 and overwritten_linear_precision:
+            # if we overwrote the quantized precision to fp32, then we
             # do not partition the weight node
             return (True, [])
 
diff --git a/backends/xnnpack/test/ops/test_linear.py b/backends/xnnpack/test/ops/test_linear.py
index 348e36bd0c..7d522b0053 100644
--- a/backends/xnnpack/test/ops/test_linear.py
+++ b/backends/xnnpack/test/ops/test_linear.py
@@ -8,14 +8,16 @@
 
 import unittest
 from itertools import product
-from typing import Optional, Tuple
+from typing import Callable, Dict, Optional, Tuple
 
 import torch
 from executorch.backends.xnnpack.partition.config.xnnpack_config import (
     ConfigPrecisionType,
 )
-from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
-
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
+    XnnpackFloatingPointPartitioner,
+    XnnpackPartitioner,
+)
 from executorch.backends.xnnpack.test.tester import Quantize, Tester
 from executorch.backends.xnnpack.test.tester.tester import (
     Partition,
@@ -672,3 +674,150 @@ def _test_groupwise_dq_linear(
             .serialize()
             .run_method_and_compare_outputs(atol=atol, rtol=rtol)
         )
+
+    def _test_linear_overwrite_precision(
+        self,
+        make_module: Callable[[int, int], torch.nn.Module],
+        uses_bias: bool,
+        quant_type: str,
+        quant_node_checks: Dict[str, int],
+        atol: float = 1e-03,
+    ):
+        """
+        Test that a quantized linear op can have its precision overwritten:
+        we partition, lower, and run the quantized linear model as an fp32 linear op.
+        When using legacy_mode, we verify that [add]mm is NOT partitioned, because
+        (1) we can't assume that the weights are always static (non-param), and
+        (2) when lowering [add]mm to xnn::bmm we can't support bias; lowering only
+            the bias-free [add]mm, exposed only on the legacy path, was deemed low ROI.
+ """ + + in_sizes = [3, 4, 4] + input_sizes = [4, 37, 17] + output_sizes = [4, 17, 37] + + assert quant_type in ["per_tensor", "per_channel", "per_channel_dynamic"] + per_channel = "per_channel" in quant_type + dynamic = "dynamic" in quant_type + quant_config = get_symmetric_quantization_config( + is_per_channel=per_channel, + is_dynamic=dynamic, + ) + # Using FP32 partitioner for this quantized graph + partitioner = XnnpackFloatingPointPartitioner() + + def get_qnode_checks(quant_node_checks, dialect): + d = {} + assert dialect in ["aten", "edge"] + if dialect == "aten": + d = { + f"torch.ops.quantized_decomposed.{op}": count + for op, count in quant_node_checks.items() + } + elif dialect == "edge": + d = { + f"executorch.exir.dialects.edge._ops.quantized_decomposed.{op}".replace( + ".", "_" + ): count + for op, count in quant_node_checks.items() + } + assert len(d) == len(quant_node_checks) + return d + + for i, _ in enumerate(in_sizes): + torch._dynamo.reset() + in_size = int(in_sizes[i]) + input_size = int(input_sizes[i]) + output_size = int(output_sizes[i]) + input_shape = [in_size] + [input_size] + module = make_module(input_size, output_size).eval() + inputs = (torch.randn(input_shape),) + + addmm_op_str = ( + "executorch_exir_dialects_edge__ops_aten_addmm_default" + if uses_bias + else "executorch_exir_dialects_edge__ops_aten_mm_default" + ) + linear_op_str = "executorch_exir_dialects_edge__ops_aten_linear_default" + + for legacy_mode in (True, False): + tester = ( + Tester(module, inputs) + .quantize(Quantize(quantization_config=quant_config)) + .export() + .dump_artifact() + .check_count(get_qnode_checks(quant_node_checks, "aten")) + ) + + if legacy_mode: + tester.to_edge() + tester.partition(Partition(partitioner=partitioner)) + # We don't expect [add]mm to be partitioned + tester.check([addmm_op_str]) + else: + tester.to_edge_transform_and_lower( + ToEdgeTransformAndLower(partitioners=[partitioner]) + ) + # We do expect linear to be partitioned + tester.check_not([linear_op_str]) + + # For legacy mode, fp32 permute_copy gets partitioned. (just a side effect) + # For new mode, fp32 linear gets partitioned. + tester.check_count( + {"torch.ops.higher_order.executorch_call_delegate": 1} + ) + + # Typically, we would not see any quantized ops in the graph. + # But here we shouldn't partition these. + tester.check_count(get_qnode_checks(quant_node_checks, "edge")) + + # TODO: Need to figure out how to load quantized ops in pybindings. 
+                # tester.to_executorch()
+                # tester.serialize()
+                # tester.run_method_and_compare_outputs(
+                #     qtol=bool(quant_config), atol=atol
+                # )
+
+    def test_qs8_as_fp32(self):
+        for use_bias in (True, False):
+            self._test_linear_overwrite_precision(
+                lambda in_size, out_size: torch.nn.Linear(
+                    in_size, out_size, bias=use_bias  # noqa
+                ),
+                use_bias,
+                "per_tensor",
+                quant_node_checks={
+                    "quantize_per_tensor.default": 2,  # 1: act, 1: output
+                    "dequantize_per_tensor.default": 3,  # 1: act, 1: weight, 1: output
+                },
+            )
+
+    def test_qc8_as_fp32(self):
+        for use_bias in (True, False):
+            self._test_linear_overwrite_precision(
+                lambda in_size, out_size: torch.nn.Linear(
+                    in_size, out_size, bias=use_bias  # noqa
+                ),
+                use_bias,
+                "per_channel",
+                quant_node_checks={
+                    "quantize_per_tensor.default": 2,  # 1: act, 1: output
+                    "dequantize_per_tensor.default": 2,  # 1: act, 1: output
+                    "dequantize_per_channel.default": 1,  # 1: weight
+                },
+            )
+
+    def test_qd8_as_fp32(self):
+        for use_bias in (True, False):
+            self._test_linear_overwrite_precision(
+                lambda in_size, out_size: torch.nn.Linear(
+                    in_size, out_size, bias=use_bias  # noqa
+                ),
+                use_bias,
+                "per_channel_dynamic",
+                quant_node_checks={
+                    "quantize_per_tensor.tensor": 1,  # 1: act
+                    "dequantize_per_tensor.tensor": 1,  # 1: act
+                    "dequantize_per_channel.default": 1,  # 1: weight
+                },
+            )
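
Note for reviewers: the decision rule in `_overwrite_precision` is small but easy to misread in diff form. Below is a minimal, self-contained sketch of the same rule; the standalone function and the trimmed-down enum are illustrative only, not part of this change.

```python
# Standalone sketch of the _overwrite_precision decision rule (illustrative only;
# the real method lives on the GEMM config classes in gemm_configs.py).
from enum import Enum, auto
from typing import List, Tuple


class ConfigPrecisionType(Enum):  # mirrors the real enum's member names
    FP32 = auto()
    STATIC_QUANT = auto()
    DYNAMIC_QUANT = auto()


def overwrite_precision(
    detected: ConfigPrecisionType,
    enabled: List[ConfigPrecisionType],
) -> Tuple[bool, ConfigPrecisionType]:
    """Return (overwritten?, precision to use for partitioning)."""
    if detected not in enabled:
        # Only when FP32 is the sole enabled precision do we fall back to
        # partitioning a quantized gemm as an fp32 gemm.
        if enabled == [ConfigPrecisionType.FP32] and detected in (
            ConfigPrecisionType.STATIC_QUANT,
            ConfigPrecisionType.DYNAMIC_QUANT,
        ):
            return True, ConfigPrecisionType.FP32
    return False, detected


# A quantized linear seen by an FP32-only partitioner is overwritten to FP32 ...
assert overwrite_precision(
    ConfigPrecisionType.STATIC_QUANT, [ConfigPrecisionType.FP32]
) == (True, ConfigPrecisionType.FP32)
# ... but not when static quant is itself an enabled precision.
assert overwrite_precision(
    ConfigPrecisionType.STATIC_QUANT,
    [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT],
) == (False, ConfigPrecisionType.STATIC_QUANT)
```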
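The `_get_bias_deps` fix swaps a truthiness check for `is not None`. The snippet below is only an illustration of why that matters when a gemm config uses a bias index of 0 (as with addmm, whose bias is its first argument):

```python
# Why `self.bias_idx is not None` instead of `self.bias_idx`:
# a valid bias index of 0 is falsy, so the old truthiness check
# silently skipped collecting the bias dependencies.
bias_idx = 0

if bias_idx:                 # old check: False for index 0 -> bias deps skipped
    print("old check: bias handled")

if bias_idx is not None:     # new check: True for index 0 -> bias deps collected
    print("new check: bias handled")
```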
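In the new test, `get_qnode_checks` only prefixes the op names in `quant_node_checks` with the dialect-specific namespace. As a worked example, using the counts from `test_qs8_as_fp32`, the helper expands to roughly the following (the inline result comments are what the comprehensions in the diff produce):

```python
# Expansion of quant_node_checks into fully qualified node names, per dialect.
quant_node_checks = {
    "quantize_per_tensor.default": 2,    # 1: act, 1: output
    "dequantize_per_tensor.default": 3,  # 1: act, 1: weight, 1: output
}

aten_checks = {
    f"torch.ops.quantized_decomposed.{op}": count
    for op, count in quant_node_checks.items()
}
# {'torch.ops.quantized_decomposed.quantize_per_tensor.default': 2,
#  'torch.ops.quantized_decomposed.dequantize_per_tensor.default': 3}

edge_checks = {
    f"executorch.exir.dialects.edge._ops.quantized_decomposed.{op}".replace(".", "_"): count
    for op, count in quant_node_checks.items()
}
# {'executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default': 2,
#  'executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default': 3}
```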
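For context, outside the `Tester` harness the behavior this change enables looks roughly like the sketch below: quantize a single linear with the PT2E flow, then lower with only the floating-point partitioner enabled. This is a hedged sketch, not taken from this PR; the quantizer import path, the capture entry point, and the exact `to_edge_transform_and_lower` signature vary across torch/executorch versions and are assumptions here.

```python
# Rough sketch (APIs assumed, version-dependent): a statically quantized linear
# lowered with only the FP32 partitioner; with this change, the linear itself is
# delegated as fp32 while the q/dq ops stay in the graph.
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (  # path varies by version
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
    XnnpackFloatingPointPartitioner,
)
from executorch.exir import to_edge_transform_and_lower

model = torch.nn.Linear(17, 37).eval()
example_inputs = (torch.randn(3, 17),)

quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
# Capture entry point differs by version (export_for_training / capture_pre_autograd_graph).
captured = torch.export.export_for_training(model, example_inputs).module()
prepared = prepare_pt2e(captured, quantizer)
prepared(*example_inputs)  # calibration pass
converted = convert_pt2e(prepared)

edge = to_edge_transform_and_lower(
    torch.export.export(converted, example_inputs),
    partitioner=[XnnpackFloatingPointPartitioner()],
)
```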