Fix bugs blocking flipping the default layout constraint for custom ops (pytorch#135391)

Fixes two things:
- For regular PyTorch ops, the default layout constraint tag is now always
  flexible_layout. This was a bug introduced in pytorch#135238. (See the
  sketch after this list for one way to check the behavior.)
- Mark the new quantized _wrapped_linear_prepack ops as flexible_layout.
  The metas for these ops are incorrect; I didn't want to fix them here
  (and changing the default requires that the metas actually be correct).
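
For context (not part of this commit), a minimal sketch of how the restored
default can be checked from Python. It assumes the get_layout_constraint_tag
helper from torch/_inductor/lowering.py (changed below), uses aten.mm purely
as an example of a builtin op, and assumes "needs_fixed_stride_order" is an
accepted value for the inductor config knob:

    import torch
    from torch._inductor import config
    from torch._inductor.lowering import get_layout_constraint_tag

    # A builtin op with no special layout tag should fall back to
    # flexible_layout...
    mm = torch.ops.aten.mm.default
    assert get_layout_constraint_tag(mm) == torch._C.Tag.flexible_layout

    # ...even when the custom-op default is flipped to the stricter
    # constraint (the scenario the next PR in the stack enables).
    with config.patch(custom_op_default_layout_constraint="needs_fixed_stride_order"):
        assert get_layout_constraint_tag(mm) == torch._C.Tag.flexible_layout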

Test Plan:
- The next PR up in the stack. The PRs are split because the next one is
  riskier.

Pull Request resolved: pytorch#135391
Approved by: https://github.com/albanD
zou3519 authored and pytorchmergebot committed Sep 9, 2024
1 parent a13c118 commit 5f7d956
Showing 2 changed files with 4 additions and 2 deletions.
4 changes: 2 additions & 2 deletions aten/src/ATen/native/quantized/library.cpp
@@ -251,8 +251,8 @@ TORCH_LIBRARY(_quantized, m) {
m.def(TORCH_SELECTIVE_SCHEMA("_quantized::wrapped_fbgemm_pack_gemm_matrix_fp16(Tensor W) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("_quantized::wrapped_fbgemm_linear_fp16_weight(Tensor X, Tensor W, Tensor B, int out_channel) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("_quantized::wrapped_quantized_linear(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor W, Tensor W_scale, Tensor W_zero_point, Tensor B, Tensor output_scale, Tensor output_zero_point, int out_channel) -> Tensor Y"));
- m.def(TORCH_SELECTIVE_SCHEMA("_quantized::_wrapped_linear_prepack(Tensor W, Tensor W_scale, Tensor W_zero_point, Tensor B) -> Tensor"));
- m.def(TORCH_SELECTIVE_SCHEMA("_quantized::_wrapped_quantized_linear_prepacked(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor W_prepack, Tensor output_scale, Tensor output_zero_point, int out_channel) -> Tensor Y"));
+ m.def(TORCH_SELECTIVE_SCHEMA("_quantized::_wrapped_linear_prepack(Tensor W, Tensor W_scale, Tensor W_zero_point, Tensor B) -> Tensor"), {at::Tag::flexible_layout});
+ m.def(TORCH_SELECTIVE_SCHEMA("_quantized::_wrapped_quantized_linear_prepacked(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor W_prepack, Tensor output_scale, Tensor output_zero_point, int out_channel) -> Tensor Y"), {at::Tag::flexible_layout});
}

TORCH_LIBRARY(onednn, m) {
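
A quick sanity check of the new tags from Python; a sketch, assuming the
_quantized library is registered in the current build and that the op
overload exposes its registration tags via .tags:

    import torch

    # The schema-level tag added above should be visible on the overload.
    op = torch.ops._quantized._wrapped_linear_prepack.default
    assert torch._C.Tag.flexible_layout in op.tags
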
2 changes: 2 additions & 0 deletions torch/_inductor/lowering.py
@@ -124,6 +124,8 @@ def get_layout_constraint_tag(fn):
    for tag in tags_by_priority:
        if tag in fn.tags:
            return tag
+   if torch._library.utils.is_builtin(fn):
+       return torch._C.Tag.flexible_layout
    return getattr(torch._C.Tag, config.custom_op_default_layout_constraint)
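
By contrast, a custom op can still opt in to an explicit constraint, which the
tag loop above picks up before either fallback. A hedged sketch: mylib::my_op
is a hypothetical op, and needs_fixed_stride_order is assumed to be one of the
entries in tags_by_priority (defined just above this hunk):

    import torch
    from torch._inductor.lowering import get_layout_constraint_tag

    # Hypothetical custom op, tagged explicitly at definition time.
    torch.library.define(
        "mylib::my_op",
        "(Tensor x) -> Tensor",
        tags=(torch._C.Tag.needs_fixed_stride_order,),
    )

    op = torch.ops.mylib.my_op.default
    # The explicit tag wins over both the builtin and the config fallbacks.
    assert get_layout_constraint_tag(op) == torch._C.Tag.needs_fixed_stride_order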


