From 0a12e33d22a3d44d1aa2af5f0d0673d45b962553 Mon Sep 17 00:00:00 2001 From: mcremon-meta <134334895+mcremon-meta@users.noreply.github.com> Date: Mon, 2 Dec 2024 12:38:29 -0800 Subject: [PATCH] Run decompositions before the quantizer Differential Revision: D66461406 Pull Request resolved: https://github.com/pytorch/executorch/pull/7111 --- backends/cadence/aot/compiler.py | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py index 937e3e39bc..6b3a023181 100644 --- a/backends/cadence/aot/compiler.py +++ b/backends/cadence/aot/compiler.py @@ -28,6 +28,7 @@ to_edge, ) from executorch.exir.pass_base import PassResult +from torch._inductor.decomposition import remove_decompositions from torch.ao.quantization.pt2e.export_utils import model_is_exported from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e @@ -58,16 +59,33 @@ def convert_pt2( Returns a GraphModule with the converted model. """ + # Get default decompositions + decomp_table = torch.export.default_decompositions() + # Select ops to keep + ops_to_keep = [ + torch.ops.aten.conv1d.default, + torch.ops.aten.conv2d.default, + torch.ops.aten.layer_norm.default, + torch.ops.aten.linear.default, + torch.ops.aten.matmul.default, + ] + # Remove decompositions for the ops we want to keep + # pyre-fixme[6]: For 1st argument expected `Dict[typing.Callable[..., typing.Any + remove_decompositions(decomp_table, ops_to_keep) # Export with dynamo - model_gm = torch.export.export_for_training(model, inputs).module() + model_gm = ( + torch.export.export_for_training(model, inputs) + .run_decompositions(decomp_table) + .module() + ) - if model_gm_has_SDPA(model_gm): # pyre-fixme[6] + if model_gm_has_SDPA(model_gm): # Decompose SDPA - DecomposeScaledDotProductAttention(False)(model_gm) # pyre-fixme[6] + DecomposeScaledDotProductAttention(False)(model_gm) # Swap _safe_softmax with _softmax (see https://github.com/pytorch/pytorch/pull/133882 # for details). - result = ReplaceSafeSoftmaxWithSoftmax()(model_gm) # pyre-fixme[6] + result = ReplaceSafeSoftmaxWithSoftmax()(model_gm) assert result is not None model_gm = result.graph_module