Skip to content

Commit

Permalink
Include FuseDequantLinearPass() in vulkan_preprocess (pytorch#6168)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#6168

## Context

Include `FuseDequantLinearPass` as a part of `vulkan_preprocess`, so that fusing the quant/dequant nodes added by `VulkanQuantizer` can be done as part of the lowering process.
ghstack-source-id: 247613964
exported-using-ghexport

Reviewed By: jorgep31415

Differential Revision: D64249613

fbshipit-source-id: 6dbc88e0c062f8f9c41eeb09bf8ba02d9096d009
  • Loading branch information
SS-JIA authored and facebook-github-bot committed Oct 11, 2024
1 parent 10f83d6 commit 236e60d
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 0 deletions.
1 change: 1 addition & 0 deletions backends/vulkan/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ runtime.python_library(
"//executorch/backends/transforms:addmm_mm_to_linear",
"//executorch/backends/transforms:fuse_batch_norm_with_conv",
"//executorch/backends/transforms:fuse_conv_with_clamp",
"//executorch/backends/transforms:fuse_dequant_linear",
"//executorch/backends/transforms:fuse_view_copy",
"//executorch/backends/transforms:mean_to_sum_div",
"//executorch/backends/transforms:remove_clone_ops",
Expand Down
7 changes: 7 additions & 0 deletions backends/vulkan/partitioner/supported_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ def __contains__(self, op):

PRIM_OPS = [
operator.getitem,
# Quantization related ops will be fused via graph passes
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
]

SUPPORTS_DYNAMIC_SHAPE = [
Expand Down
2 changes: 2 additions & 0 deletions backends/vulkan/vulkan_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
FuseBatchNormWithConvPass,
)
from executorch.backends.transforms.fuse_conv_with_clamp import FuseClampPass
from executorch.backends.transforms.fuse_dequant_linear import FuseDequantLinearPass
from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
from executorch.backends.transforms.mean_to_sum_div import MeanToSumDiv
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
Expand Down Expand Up @@ -59,6 +60,7 @@ def preprocess( # noqa: C901
passes = [
RemoveCloneOpsTransform(),
AddmmToLinearTransform(),
FuseDequantLinearPass(),
FuseViewCopyTransform(),
FuseBatchNormWithConvPass(program),
FuseClampPass(),
Expand Down

0 comments on commit 236e60d

Please sign in to comment.