From 7924942b5fb3fb3f461905e71565ab47e159a070 Mon Sep 17 00:00:00 2001
From: Jacob Szwejbka
Date: Fri, 13 Dec 2024 14:40:41 -0800
Subject: [PATCH] Allow addmm and mm to call dynamic fp32 kernels Xnnpack

Differential Revision: D66898281

Pull Request resolved: https://github.com/pytorch/executorch/pull/7232
---
 .../xnnpack/partition/config/gemm_configs.py | 21 +++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/backends/xnnpack/partition/config/gemm_configs.py b/backends/xnnpack/partition/config/gemm_configs.py
index cbcb14899d..07876a4a25 100644
--- a/backends/xnnpack/partition/config/gemm_configs.py
+++ b/backends/xnnpack/partition/config/gemm_configs.py
@@ -337,6 +337,17 @@ def __init__(self, **kwargs):
         self.src_partitions = None
         self.linear_modules = [torch.nn.functional.linear, torch.nn.Linear]
 
+    def _get_weight_deps(
+        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
+    ) -> Tuple[bool, List[torch.fx.Node]]:
+        # TODO(maxren, T210537195):
+        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
+            # if force fp32_dynamic_linear is on and we detected this as fp32, then we
+            # do not partition the weight node
+            return (True, [])
+
+        return super()._get_weight_deps(node, ep, precision)
+
     def get_deps(
         self,
         node: torch.fx.Node,
@@ -436,6 +447,16 @@ def __init__(self, **kwargs):
         self.weight_idx = 1
         self.act_idx = 0
 
+    def _get_weight_deps(
+        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
+    ) -> Tuple[bool, List[torch.fx.Node]]:
+        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
+            # if force fp32_dynamic_linear is on and we detected this as fp32, then we
+            # do not partition the weight node
+            return (True, [])
+
+        return super()._get_weight_deps(node, ep, precision)
+
     def supported_precision_types(self):
         return [
             ConfigPrecisionType.FP32,
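
For context, a minimal usage sketch of the flag this patch checks (not part of the change itself). It assumes that XnnpackPartitioner forwards extra keyword arguments such as force_fp32_dynamic_linear to its partitioner configs, and that to_edge_transform_and_lower from executorch.exir accepts a list of partitioners; the TinyAddmm module is hypothetical and exists only to produce an aten.addmm/aten.mm node.

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower


class TinyAddmm(torch.nn.Module):
    # Hypothetical module: torch.addmm on a parameter lowers to aten.addmm/aten.mm,
    # the ops this patch now allows to reach the dynamic fp32 XNNPACK kernels.
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(16, 32))
        self.bias = torch.nn.Parameter(torch.randn(16))

    def forward(self, x):
        return torch.addmm(self.bias, x, self.weight.t())


ep = torch.export.export(TinyAddmm(), (torch.randn(4, 32),))

# Assumption: XnnpackPartitioner passes unknown kwargs through to its configs,
# where the GEMM configs read force_fp32_dynamic_linear (as in this diff).
edge = to_edge_transform_and_lower(
    ep,
    partitioner=[XnnpackPartitioner(force_fp32_dynamic_linear=True)],
)
executorch_program = edge.to_executorch()

As the in-diff comments note, returning (True, []) from _get_weight_deps keeps the node partitionable while declining to claim the weight as a static constant, which is what lets the addmm/mm nodes fall back to XNNPACK's dynamic fp32 GEMM path when force_fp32_dynamic_linear is set.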