Run decompositions before the quantizer
Differential Revision: D66461406

Pull Request resolved: pytorch#7111
mcremon-meta authored Dec 2, 2024
1 parent 2326fff commit 0a12e33
Showing 1 changed file with 22 additions and 4 deletions.
26 changes: 22 additions & 4 deletions backends/cadence/aot/compiler.py
@@ -28,6 +28,7 @@
     to_edge,
 )
 from executorch.exir.pass_base import PassResult
+from torch._inductor.decomposition import remove_decompositions
 from torch.ao.quantization.pt2e.export_utils import model_is_exported
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

@@ -58,16 +59,33 @@ def convert_pt2(
     Returns a GraphModule with the converted model.
     """
 
+    # Get default decompositions
+    decomp_table = torch.export.default_decompositions()
+    # Select ops to keep
+    ops_to_keep = [
+        torch.ops.aten.conv1d.default,
+        torch.ops.aten.conv2d.default,
+        torch.ops.aten.layer_norm.default,
+        torch.ops.aten.linear.default,
+        torch.ops.aten.matmul.default,
+    ]
+    # Remove decompositions for the ops we want to keep
+    # pyre-fixme[6]: For 1st argument expected `Dict[typing.Callable[..., typing.Any
+    remove_decompositions(decomp_table, ops_to_keep)
     # Export with dynamo
-    model_gm = torch.export.export_for_training(model, inputs).module()
+    model_gm = (
+        torch.export.export_for_training(model, inputs)
+        .run_decompositions(decomp_table)
+        .module()
+    )
 
-    if model_gm_has_SDPA(model_gm):  # pyre-fixme[6]
+    if model_gm_has_SDPA(model_gm):
         # Decompose SDPA
-        DecomposeScaledDotProductAttention(False)(model_gm)  # pyre-fixme[6]
+        DecomposeScaledDotProductAttention(False)(model_gm)
 
         # Swap _safe_softmax with _softmax (see https://github.com/pytorch/pytorch/pull/133882
         # for details).
-        result = ReplaceSafeSoftmaxWithSoftmax()(model_gm)  # pyre-fixme[6]
+        result = ReplaceSafeSoftmaxWithSoftmax()(model_gm)
         assert result is not None
         model_gm = result.graph_module
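
For reference, a minimal, self-contained sketch (not part of this commit) of the pattern the change introduces: start from the default decomposition table, drop the entries for ops that should survive export intact, then run the remaining decompositions on the exported program before any quantizer sees the graph. The TinyModel class and its inputs below are hypothetical stand-ins; the torch.export and remove_decompositions calls mirror the ones in the diff above.

import torch
from torch._inductor.decomposition import remove_decompositions


class TinyModel(torch.nn.Module):  # hypothetical toy model for illustration
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


model = TinyModel()
inputs = (torch.randn(2, 4),)

# Start from the default decomposition table and remove the entries for the
# ops that should stay intact (here only aten.linear, as an example).
decomp_table = torch.export.default_decompositions()
remove_decompositions(decomp_table, [torch.ops.aten.linear.default])

# Run the remaining decompositions at export time, before quantization.
model_gm = (
    torch.export.export_for_training(model, inputs)
    .run_decompositions(decomp_table)
    .module()
)
print(model_gm.graph)  # aten.linear should appear un-decomposed in the graph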

