[ts_converter][reland] Add support for LinearOpContext and Conv2dOpContext in quantization pass (pytorch#133622)

Summary: Reland of D60871242

Test Plan: CI

Differential Revision: D61352600

Pull Request resolved: pytorch#133622
Approved by: https://github.com/SherlockNoMad
angelayi authored and pytorchmergebot committed Aug 16, 2024
1 parent 1653f77 commit a1a869f
Showing 3 changed files with 72 additions and 41 deletions.
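As context for the diffs below, a hedged sketch of the workflow this reland enables: converting a scripted model that carries a prepacked XNNPACK op context. TS2EPConverter is the internal TorchScript-to-ExportedProgram converter exercised by the new test; treat the exact entry point and signature as assumptions of this sketch, not guaranteed API.

import torch
from torch._export.converter import TS2EPConverter

class M(torch.nn.Module):
    def __init__(self, linear_op):
        super().__init__()
        self.linear_op = linear_op

    def forward(self, x):
        # The quantization pass rewrites this call into aten.linear
        # using parameters unpacked from the op context.
        return torch.ops.prepacked.linear_clamp_run(x, self.linear_op)

linear_op = torch.ops.prepacked.linear_clamp_prepack(
    torch.randn(10, 10), torch.randn(10)
)
ts_model = torch.jit.script(M(linear_op))
ep = TS2EPConverter(ts_model, (torch.randn(1, 10),)).convert()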
48 changes: 25 additions & 23 deletions aten/src/ATen/native/xnnpack/RegisterOpContextClass.cpp
@@ -22,11 +22,12 @@ TORCH_LIBRARY(xnnpack, m) {
          [](SerializationTypeLinearPrePack state)
              -> c10::intrusive_ptr<LinearOpContext> { // __setstate__
            return createLinearClampPrePackOpContext(
-                std::move(std::get<0>(state)),
-                std::move(std::get<1>(state)),
-                std::move(std::get<2>(state)),
-                std::move(std::get<3>(state)));
-          });
+                std::get<0>(state),
+                std::get<1>(state),
+                std::get<2>(state),
+                std::get<3>(state));
+          })
+      .def("unpack", &LinearOpContext::unpack);

  m.class_<Conv2dOpContext>(TORCH_SELECTIVE_CLASS("Conv2dOpContext"))
      .def_pickle(
@@ -37,15 +38,16 @@ TORCH_LIBRARY(xnnpack, m) {
          [](SerializationTypeConv2dPrePack state)
              -> c10::intrusive_ptr<Conv2dOpContext> { // __setstate__
            return createConv2dClampPrePackOpContext(
-                std::move(std::get<0>(state)),
-                std::move(std::get<1>(state)),
-                std::move(std::get<2>(state)),
-                std::move(std::get<3>(state)),
-                std::move(std::get<4>(state)),
-                std::move(std::get<5>(state)),
-                std::move(std::get<6>(state)),
-                std::move(std::get<7>(state)));
-          });
+                std::get<0>(state),
+                std::get<1>(state),
+                std::get<2>(state),
+                std::get<3>(state),
+                std::get<4>(state),
+                std::get<5>(state),
+                std::get<6>(state),
+                std::get<7>(state));
+          })
+      .def("unpack", &Conv2dOpContext::unpack);

  m.class_<TransposeConv2dOpContext>(TORCH_SELECTIVE_CLASS("TransposeConv2dOpContext"))
      .def_pickle(
@@ -56,15 +58,15 @@ TORCH_LIBRARY(xnnpack, m) {
          [](SerializationTypeTransposeConv2dPrePack state)
              -> c10::intrusive_ptr<TransposeConv2dOpContext> { // __setstate__
            return createConv2dTransposeClampPrePackOpContext(
-                std::move(std::get<0>(state)),
-                std::move(std::get<1>(state)),
-                std::move(std::get<2>(state)),
-                std::move(std::get<3>(state)),
-                std::move(std::get<4>(state)),
-                std::move(std::get<5>(state)),
-                std::move(std::get<6>(state)),
-                std::move(std::get<7>(state)),
-                std::move(std::get<8>(state)));
+                std::get<0>(state),
+                std::get<1>(state),
+                std::get<2>(state),
+                std::get<3>(state),
+                std::get<4>(state),
+                std::get<5>(state),
+                std::get<6>(state),
+                std::get<7>(state),
+                std::get<8>(state));
          });

}
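The unpack bindings registered above are what make the packed parameters reachable from Python. A minimal sketch, assuming unpack() returns a tuple whose first two elements are the original weight and bias (the converter pass relies on exactly that via so.unpack()[:2]):

import torch

w, b = torch.randn(4, 8), torch.randn(4)
ctx = torch.ops.prepacked.linear_clamp_prepack(w, b)

# unpack() is the method bound on LinearOpContext above; the assert
# assumes the context stores the original tensors unmodified.
weight, bias = ctx.unpack()[:2]
assert torch.equal(weight, w) and torch.equal(bias, b)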
17 changes: 17 additions & 0 deletions test/export/test_converter.py
@@ -1410,6 +1410,23 @@ def fuse_model(self):
        ep_out, _ = pytree.tree_flatten(ep.module()(*inp))
        self._check_tensor_list_equal(orig_out, ep_out)

+    def test_ts2ep_convert_quantized_model_with_opcontext(self):
+        class M(torch.nn.Module):
+            def __init__(self, linear_op):
+                super().__init__()
+                self.linear_op = linear_op
+
+            def forward(self, x):
+                x = torch.ops.prepacked.linear_clamp_run(x, self.linear_op)
+                return x
+
+        linear_op = torch.ops.prepacked.linear_clamp_prepack(
+            torch.randn(10, 10), torch.randn(10)
+        )
+        m = M(linear_op)
+        inp = (torch.randn(1, 10),)
+        self._check_equal_ts_ep_converter(m, inp, ["script"])


if __name__ == "__main__":
run_tests()
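The test drives the full converter through a test helper; the identity the new pass depends on can also be checked directly in eager mode. A sketch, assuming prepacking without output_min/output_max applies no clamping, with loose tolerances since XNNPACK and ATen use different kernels:

import torch

x = torch.randn(1, 10)
w, b = torch.randn(10, 10), torch.randn(10)

ctx = torch.ops.prepacked.linear_clamp_prepack(w, b)
y_prepacked = torch.ops.prepacked.linear_clamp_run(x, ctx)

# What the pass emits in place of linear_clamp_run:
weight, bias = ctx.unpack()[:2]
y_standard = torch.ops.aten.linear(x, weight, bias)

torch.testing.assert_close(y_prepacked, y_standard, rtol=1e-4, atol=1e-4)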
48 changes: 30 additions & 18 deletions
@@ -425,22 +425,27 @@ def _transform_prepacked_op(gm: torch.fx.GraphModule, node: torch.fx.Node):
    Transformation for functions under the prepacked namespace, which share
    the same handling logic: the [...]OpContext contains all parameters.
    """
    # TODO: Expose weights and bias from [...]OpContext.
    # Reference: https://github.com/pytorch/pytorch/blob/eca0cb0fbe84bb0a34fa94afe261bceecd52c436/aten/src/ATen/native/xnnpack/RegisterOpContextClass.cpp#L31  # noqa: B950
-    # to_standard_op = {
-    #     "conv2d_clamp_run": torch.ops.aten.conv2d,
-    #     "linear_clamp_run": torch.ops.aten.linear,
-    # }
+    assert isinstance(node.target, torch._ops.OpOverload)
+    opname, args = node.target._opname, node.args
+    op_f = None
+    if opname == "conv2d_clamp_run":
+        op_f = torch.ops.aten.conv2d
+    elif opname == "linear_clamp_run":
+        op_f = torch.ops.aten.linear
+    else:
+        raise RuntimeError(f"Invalid operator {opname}")

    assert isinstance(args[1], torch.fx.Node)
    so = get_script_object(gm, args[1])

-    # to_param_unpack_op = {
-    #     "conv2d_clamp_run": torch.ops.prepacked.unpack_prepacked_sizes_conv2d,
-    #     "linear_clamp_run": torch.ops.prepacked.unpack_prepacked_sizes_linear,
-    # }
+    func_args = []
+    func_args += [args[0]]
+    func_args += so.unpack()[:2]  # type: ignore[attr-defined]
+    if opname == "conv2d_clamp_run":
+        func_args += torch.ops.prepacked.unpack_prepacked_sizes_conv2d(so)[2:]

-    # op_res_node = gm.graph.call_function(
-    #     to_standard_op[opname], new_args
-    # )
-    # return op_res_node, _SCALE, _ZERO_POINT
+    op_res_node = gm.graph.call_function(op_f, tuple(func_args))
+    return op_res_node


def _transform_batch_norm(gm: torch.fx.GraphModule, node: torch.fx.Node):
@@ -481,8 +486,6 @@ def fx_transform_quantized_op_to_standard_op(
"cat.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point,
"hardswish.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point,
"batch_norm2d.default": _transform_batch_norm,
# "conv2d_clamp_run.default": _transform_prepacked_op,
# "linear_clamp_run.default": _transform_prepacked_op,
"mul.Scalar": _transform_scalar_arithmetic,
"add.Scalar": _transform_scalar_arithmetic,
}
@@ -551,6 +554,9 @@ def replace_quantized_ops_with_standard_ops(gm: torch.fx.GraphModule):
    quantized::conv2d or linear is from conv_prepack or linear_prepack. If it is, we then inline those parameters
    into the operator by converting them to a getattr fx.node.

+    For prepacked::conv2d_clamp_run and prepacked::linear_clamp_run, we directly convert them to aten.conv2d
+    and aten.linear without any de/quantization.

    Three global variables are defined: _INPUT_Q_DTYPE, _SCALE, _ZERO_POINT. _INPUT_Q_DTYPE determines the de/quantization
    data type, which is the same across the entire program but only shows up in the very first quantization
    call. _SCALE and _ZERO_POINT are used only when operators do not have those specified, e.g., mul.Scalar.
@@ -565,14 +571,19 @@ def replace_quantized_ops_with_standard_ops(gm: torch.fx.GraphModule):
        if isinstance(node.target, OpOverload):
            with gm.graph.inserting_before(node):
                namespace, opname = node.target.namespace, node.target._opname
-                if namespace in ["quantized", "prepacked"] and opname not in [
+                if namespace == "quantized" and opname not in [
                    "conv_prepack",
                    "linear_prepack",
                ]:
                    quantized = True
                    fx_node = fx_transform_quantized_op_to_standard_op(gm, node)
                    node.replace_all_uses_with(fx_node)
                    last_quantized_node = fx_node
+                elif namespace == "prepacked":
+                    quantized = True
+                    fx_node = _transform_prepacked_op(gm, node)
+                    node.replace_all_uses_with(fx_node)
+                    last_quantized_node = fx_node
                elif namespace == "aten" and opname == "quantize_per_tensor":
                    inp_node, scale_node, zero_point_node, dtype_node = node.args
                    dtype_node = fx_enum_to_dtype(gm, dtype_node)
@@ -608,7 +619,8 @@ def replace_quantized_ops_with_standard_ops(gm: torch.fx.GraphModule):
                    last_quantized_node = node

    # Post-processing again to remove legacy ScriptObjects and quantized tensors
-    # stored as attributes or in the buffer.
+    # stored as attributes or in the buffer. This is used to clean up the GraphModule
+    # so that tracing does not hit errors such as missing __obj_flatten__ functions.
    def _clean_attr(mod: torch.nn.Module):
        for submod in mod.modules():
            attr_names_to_clean = set()
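For the conv2d path, _transform_prepacked_op recombines unpack() with unpack_prepacked_sizes_conv2d, taking elements [2:] of the latter as (stride, padding, dilation, groups). An eager-mode sketch of the same recomposition; the prepack signature (weight, bias, stride, padding, dilation, groups) and the [2:] layout are assumptions inferred from the pass above:

import torch

x = torch.randn(1, 3, 8, 8)
w, b = torch.randn(4, 3, 3, 3), torch.randn(4)

# Assumed argument order: weight, bias, stride, padding, dilation, groups.
ctx = torch.ops.prepacked.conv2d_clamp_prepack(w, b, [1, 1], [1, 1], [1, 1], 1)
y_prepacked = torch.ops.prepacked.conv2d_clamp_run(x, ctx)

weight, bias = ctx.unpack()[:2]
stride, padding, dilation, groups = torch.ops.prepacked.unpack_prepacked_sizes_conv2d(ctx)[2:]
y_standard = torch.ops.aten.conv2d(x, weight, bias, stride, padding, dilation, groups)

torch.testing.assert_close(y_prepacked, y_standard, rtol=1e-4, atol=1e-4)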
