Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ConvInteger: fix parsing for x_zero_point and w_zero_point #3763

Merged
merged 24 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
2e23878
Fix memory_coloring pass when MIGRAPHX_NSTREAMS > 2
kahmed10 Jan 11, 2025
be58a75
update copyright and add test case
kahmed10 Jan 15, 2025
031d500
update test copyright
kahmed10 Jan 15, 2025
eb3d71e
Merge branch 'develop' of https://github.com/ROCm/AMDMIGraphX into me…
kahmed10 Jan 15, 2025
582b281
fix convinteger bias parsing and updated test
kahmed10 Jan 16, 2025
a017473
formatting
kahmed10 Jan 16, 2025
ac12de4
update onnx file
kahmed10 Jan 16, 2025
cd10d9a
Merge branch 'develop' of https://github.com/ROCm/AMDMIGraphX into pa…
kahmed10 Jan 16, 2025
276d9d4
cleanup qparam_broadcast_op function
kahmed10 Jan 16, 2025
9bf587b
fix test channels and fix simplify_algebra find_inner_broadcasts for …
kahmed10 Jan 16, 2025
c956ea7
update convinteger_bias_test
kahmed10 Jan 16, 2025
0d68d62
formatting
kahmed10 Jan 16, 2025
fd0931f
update license
kahmed10 Jan 16, 2025
a5c22fa
Merge branch 'develop' into parse_qconv_bias_fix
kahmed10 Jan 16, 2025
ada1625
reuse smaller test case for verify
kahmed10 Jan 17, 2025
722b571
Merge branch 'develop' of https://github.com/ROCm/AMDMIGraphX into pa…
kahmed10 Jan 17, 2025
d90aae0
update license year
kahmed10 Jan 17, 2025
f388ed2
add missing onnx file
kahmed10 Jan 17, 2025
4c3f92e
fix filepath
kahmed10 Jan 17, 2025
1ce1f95
Merge branch 'develop' of https://github.com/ROCm/AMDMIGraphX into pa…
kahmed10 Jan 17, 2025
3ab07bc
update verify_onnx tests
kahmed10 Jan 20, 2025
f5a5f09
formatting
kahmed10 Jan 20, 2025
23013ae
revert build script
kahmed10 Jan 20, 2025
314b6ba
fix licensing
kahmed10 Jan 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 21 additions & 14 deletions src/onnx/parse_convolution.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -141,17 +141,14 @@ struct parse_convolution : op_parser<parse_convolution>
return all_zeros;
}

static auto
static migraphx::operation
qparam_broadcast_op(instruction_ref qparam, std::vector<std::size_t> lens, std::size_t axis)
{
if(qparam->get_shape().scalar())
if(qparam->get_shape().elements() == 1)
lakhinderwalia marked this conversation as resolved.
Show resolved Hide resolved
{
return migraphx::make_op("multibroadcast", {{"out_lens", lens}});
}
else
{
return migraphx::make_op("broadcast", {{"out_lens", lens}, {"axis", axis}});
}
return migraphx::make_op("broadcast", {{"out_lens", lens}, {"axis", axis}});
}

static instruction_ref handle_quant_bias(const operation& op,
Expand All @@ -162,27 +159,37 @@ struct parse_convolution : op_parser<parse_convolution>
const instruction_ref& w_zp,
onnx_parser::node_info& info)
{
// to handle the bias, apply the following transformation:
// conv(x-x_zp,w-w_zp) = conv(x,w) - conv(x_zp,w) - conv(x,w_zp) + conv(x_zp,w_zp)
instruction_ref ret = input;

// multibroadcast (or broadcast) zero points according to spec
// x_zp should be a scalar or literal with one element
// w_zp can be either a single element or a 1d tensor with size out_channels
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment: Good to add the comment here!

migraphx::operation x_zp_bc =
migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}});
migraphx::operation w_zp_bc = qparam_broadcast_op(w_zp, weights->get_shape().lens(), 0);

if(not is_symmetric_zero_point(x_zp))
{
auto out_zp_1 = info.add_common_op(op.name(), x_zp, weights);
auto x_zp_mb = info.add_instruction(x_zp_bc, x_zp);
auto out_zp_1 = info.add_instruction(op, x_zp_mb, weights);
ret = info.add_common_op("sub", ret, out_zp_1);
}

if(not is_symmetric_zero_point(w_zp))
{
auto out_zp_2 = info.add_common_op(op.name(), x, w_zp);
auto w_zp_mb = info.add_instruction(w_zp_bc, w_zp);
auto out_zp_2 = info.add_instruction(op, x, w_zp_mb);
ret = info.add_common_op("sub", ret, out_zp_2);
}

if(not(is_symmetric_zero_point(x_zp)) and not(is_symmetric_zero_point(w_zp)))
{
auto x_zp_bc =
info.add_instruction(qparam_broadcast_op(x_zp, x->get_shape().lens(), 0), x_zp);
auto w_zp_bc = info.add_instruction(
qparam_broadcast_op(w_zp, weights->get_shape().lens(), 0), w_zp);
auto x_zp_mb = info.add_instruction(x_zp_bc, x_zp);
auto w_zp_mb = info.add_instruction(w_zp_bc, w_zp);

auto out_zp_3 = info.add_instruction(op, x_zp_bc, w_zp_bc);
auto out_zp_3 = info.add_instruction(op, x_zp_mb, w_zp_mb);

ret = info.add_common_op("add", ret, out_zp_3);
}
Expand Down
20 changes: 11 additions & 9 deletions test/onnx/convinteger_dual_bias_test.onnx
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,18 @@ B
strides@@ convinteger_dual_bias_testZ
0





Z



Z
1





Z

Z
2


Expand All @@ -28,7 +30,7 @@ B
b
4





B

B
6 changes: 3 additions & 3 deletions test/onnx/gen_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1686,11 +1686,11 @@ def convinteger_bias_test():

@onnx_test()
def convinteger_dual_bias_test():
x = helper.make_tensor_value_info('0', TensorProto.INT8, [1, 3, 5, 5])
y = helper.make_tensor_value_info('1', TensorProto.INT8, [1, 3, 2, 2])
x = helper.make_tensor_value_info('0', TensorProto.INT8, [2, 3, 10, 10])
y = helper.make_tensor_value_info('1', TensorProto.INT8, [4, 2, 3, 3])
z = helper.make_tensor_value_info('2', TensorProto.INT8, [1])
w = helper.make_tensor_value_info('3', TensorProto.INT8, [1])
out = helper.make_tensor_value_info('4', TensorProto.INT32, [1, 1, 4, 4])
out = helper.make_tensor_value_info('4', TensorProto.INT32, [2, 4, 8, 8])

node = onnx.helper.make_node('ConvInteger',
inputs=['0', '1', '2', '3'],
Expand Down
32 changes: 14 additions & 18 deletions test/onnx/parse/convinteger_dual_bias_test.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand All @@ -28,41 +28,37 @@ TEST_CASE(convinteger_dual_bias_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto data = mm->add_parameter("0", {migraphx::shape::int8_type, {1, 3, 5, 5}});
auto weight = mm->add_parameter("1", {migraphx::shape::int8_type, {1, 3, 2, 2}});
auto data = mm->add_parameter("0", {migraphx::shape::int8_type, {2, 3, 10, 10}});
auto weight = mm->add_parameter("1", {migraphx::shape::int8_type, {4, 2, 3, 3}});
auto data_bias = mm->add_parameter("2", {migraphx::shape::int8_type, {1}, {1}});
auto weight_bias = mm->add_parameter("3", {migraphx::shape::int8_type, {1}, {1}});

auto quant = mm->add_instruction(migraphx::make_op("quant_convolution"), data, weight);

auto mbcast_data_bias = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", weight->get_shape().lens()}}), data_bias);
migraphx::make_op("multibroadcast", {{"out_lens", data->get_shape().lens()}}), data_bias);

auto quant_db_w =
auto quant_mb_w =
mm->add_instruction(migraphx::make_op("quant_convolution"), mbcast_data_bias, weight);

auto quant_mb_w = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}), quant_db_w);

quant = mm->add_instruction(migraphx::make_op("sub"), quant, quant_mb_w);

auto mbcast_weight_bias = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", data->get_shape().lens()}}), weight_bias);
migraphx::make_op("multibroadcast", {{"out_lens", weight->get_shape().lens()}}),
weight_bias);

auto quant_d_wb =
auto quant_md_wb =
mm->add_instruction(migraphx::make_op("quant_convolution"), data, mbcast_weight_bias);

auto quant_md_wb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}), quant_d_wb);

quant = mm->add_instruction(migraphx::make_op("sub"), quant, quant_md_wb);

auto bcast_data_bias = mm->add_instruction(
migraphx::make_op("broadcast", {{"out_lens", data->get_shape().lens()}}), data_bias);
auto bcast_weight_bias = mm->add_instruction(
migraphx::make_op("broadcast", {{"out_lens", weight->get_shape().lens()}}), weight_bias);
mbcast_data_bias = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", data->get_shape().lens()}}), data_bias);
mbcast_weight_bias = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", weight->get_shape().lens()}}),
weight_bias);
auto bias_quant = mm->add_instruction(
migraphx::make_op("quant_convolution"), bcast_data_bias, bcast_weight_bias);
migraphx::make_op("quant_convolution"), mbcast_data_bias, mbcast_weight_bias);

mm->add_instruction(migraphx::make_op("add"), quant, bias_quant);

Expand Down
Loading