diff --git a/backends/xnnpack/partition/config/gemm_configs.py b/backends/xnnpack/partition/config/gemm_configs.py
index cbcb14899d..07876a4a25 100644
--- a/backends/xnnpack/partition/config/gemm_configs.py
+++ b/backends/xnnpack/partition/config/gemm_configs.py
@@ -337,6 +337,17 @@ def __init__(self, **kwargs):
         self.src_partitions = None
         self.linear_modules = [torch.nn.functional.linear, torch.nn.Linear]
 
+    def _get_weight_deps(
+        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
+    ) -> Tuple[bool, List[torch.fx.Node]]:
+        # TODO(maxren, T210537195):
+        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
+            # if force fp32_dynamic_linear is on and we detected this as fp32, then we
+            # do not partition the weight node
+            return (True, [])
+
+        return super()._get_weight_deps(node, ep, precision)
+
     def get_deps(
         self,
         node: torch.fx.Node,
@@ -436,6 +447,16 @@ def __init__(self, **kwargs):
         self.weight_idx = 1
         self.act_idx = 0
 
+    def _get_weight_deps(
+        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
+    ) -> Tuple[bool, List[torch.fx.Node]]:
+        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
+            # if force fp32_dynamic_linear is on and we detected this as fp32, then we
+            # do not partition the weight node
+            return (True, [])
+
+        return super()._get_weight_deps(node, ep, precision)
+
     def supported_precision_types(self):
         return [
             ConfigPrecisionType.FP32,
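
Not part of the patch: a minimal sketch of how the flag checked by these overrides might be exercised end to end. It assumes `force_fp32_dynamic_linear` is forwarded from `XnnpackPartitioner`'s kwargs into the GEMM configs' `__init__(self, **kwargs)` shown in the hunk context above; the model, shapes, and variable names are illustrative only.

```python
import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge


class LinearModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(32, 16)

    def forward(self, x):
        return self.linear(x)


# Export the eager module and lower it to the edge dialect.
ep = torch.export.export(LinearModel(), (torch.randn(1, 32),))
edge = to_edge(ep)

# With force_fp32_dynamic_linear=True and an fp32 linear, the overridden
# _get_weight_deps returns (True, []), so the weight node is deliberately
# left out of the XNNPACK partition instead of being consumed with the op.
edge = edge.to_backend(XnnpackPartitioner(force_fp32_dynamic_linear=True))
```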