diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py index 937e3e39bc..6b3a023181 100644 --- a/backends/cadence/aot/compiler.py +++ b/backends/cadence/aot/compiler.py @@ -28,6 +28,7 @@ to_edge, ) from executorch.exir.pass_base import PassResult +from torch._inductor.decomposition import remove_decompositions from torch.ao.quantization.pt2e.export_utils import model_is_exported from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e @@ -58,16 +59,33 @@ def convert_pt2( Returns a GraphModule with the converted model. """ + # Get default decompositions + decomp_table = torch.export.default_decompositions() + # Select ops to keep + ops_to_keep = [ + torch.ops.aten.conv1d.default, + torch.ops.aten.conv2d.default, + torch.ops.aten.layer_norm.default, + torch.ops.aten.linear.default, + torch.ops.aten.matmul.default, + ] + # Remove decompositions for the ops we want to keep + # pyre-fixme[6]: For 1st argument expected `Dict[typing.Callable[..., typing.Any + remove_decompositions(decomp_table, ops_to_keep) # Export with dynamo - model_gm = torch.export.export_for_training(model, inputs).module() + model_gm = ( + torch.export.export_for_training(model, inputs) + .run_decompositions(decomp_table) + .module() + ) - if model_gm_has_SDPA(model_gm): # pyre-fixme[6] + if model_gm_has_SDPA(model_gm): # Decompose SDPA - DecomposeScaledDotProductAttention(False)(model_gm) # pyre-fixme[6] + DecomposeScaledDotProductAttention(False)(model_gm) # Swap _safe_softmax with _softmax (see https://github.com/pytorch/pytorch/pull/133882 # for details). - result = ReplaceSafeSoftmaxWithSoftmax()(model_gm) # pyre-fixme[6] + result = ReplaceSafeSoftmaxWithSoftmax()(model_gm) assert result is not None model_gm = result.graph_module