diff --git a/src/onnx/parse_convolution.cpp b/src/onnx/parse_convolution.cpp index d8ba3d498b5..190c69a0ffe 100644 --- a/src/onnx/parse_convolution.cpp +++ b/src/onnx/parse_convolution.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -141,17 +141,14 @@ struct parse_convolution : op_parser return all_zeros; } - static auto + static migraphx::operation qparam_broadcast_op(instruction_ref qparam, std::vector lens, std::size_t axis) { - if(qparam->get_shape().scalar()) + if(qparam->get_shape().elements() == 1) { return migraphx::make_op("multibroadcast", {{"out_lens", lens}}); } - else - { - return migraphx::make_op("broadcast", {{"out_lens", lens}, {"axis", axis}}); - } + return migraphx::make_op("broadcast", {{"out_lens", lens}, {"axis", axis}}); } static instruction_ref handle_quant_bias(const operation& op, @@ -162,27 +159,37 @@ struct parse_convolution : op_parser const instruction_ref& w_zp, onnx_parser::node_info& info) { + // to handle the bias, apply the following transformation: + // conv(x-x_zp,w-w_zp) = conv(x,w) - conv(x_zp,w) - conv(x,w_zp) + conv(x_zp,w_zp) instruction_ref ret = input; + + // multibroadcast (or broadcast) zero points according to spec + // x_zp should be a scalar or literal with one element + // w_zp can be either a single element or a 1d tensor with size out_channels + migraphx::operation x_zp_bc = + migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}); + migraphx::operation w_zp_bc = qparam_broadcast_op(w_zp, weights->get_shape().lens(), 0); + if(not is_symmetric_zero_point(x_zp)) { - auto out_zp_1 = info.add_common_op(op.name(), x_zp, weights); + auto x_zp_mb = info.add_instruction(x_zp_bc, x_zp); + auto out_zp_1 = info.add_instruction(op, x_zp_mb, weights); ret = info.add_common_op("sub", ret, out_zp_1); } if(not is_symmetric_zero_point(w_zp)) { - auto out_zp_2 = info.add_common_op(op.name(), x, w_zp); + auto w_zp_mb = info.add_instruction(w_zp_bc, w_zp); + auto out_zp_2 = info.add_instruction(op, x, w_zp_mb); ret = info.add_common_op("sub", ret, out_zp_2); } if(not(is_symmetric_zero_point(x_zp)) and not(is_symmetric_zero_point(w_zp))) { - auto x_zp_bc = - info.add_instruction(qparam_broadcast_op(x_zp, x->get_shape().lens(), 0), x_zp); - auto w_zp_bc = info.add_instruction( - qparam_broadcast_op(w_zp, weights->get_shape().lens(), 0), w_zp); + auto x_zp_mb = info.add_instruction(x_zp_bc, x_zp); + auto w_zp_mb = info.add_instruction(w_zp_bc, w_zp); - auto out_zp_3 = info.add_instruction(op, x_zp_bc, w_zp_bc); + auto out_zp_3 = info.add_instruction(op, x_zp_mb, w_zp_mb); ret = info.add_common_op("add", ret, out_zp_3); } diff --git a/src/simplify_algebra.cpp b/src/simplify_algebra.cpp index 4d3d59a0364..469d1dd50d6 100644 --- a/src/simplify_algebra.cpp +++ b/src/simplify_algebra.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -761,6 +761,8 @@ struct find_inner_broadcast void apply(module& m, const match::matcher_result& r) const { auto ins = r.result; + if(ins->get_operator().name() == "layout") + return; const auto& broadcasts = ins->inputs(); if(broadcasts.empty()) return; diff --git a/test/onnx/convinteger_bias_test.onnx b/test/onnx/convinteger_bias_test.onnx index e7d66aef429..a1dcc4b676a 100644 --- a/test/onnx/convinteger_bias_test.onnx +++ b/test/onnx/convinteger_bias_test.onnx @@ -7,13 +7,13 @@ strides@@ convinteger_bias_testZ 0  - +    Z 1  - +   Z @@ -23,7 +23,7 @@ b 3  -  +  B \ No newline at end of file diff --git a/test/onnx/convinteger_dual_bias_simple_test.onnx b/test/onnx/convinteger_dual_bias_simple_test.onnx new file mode 100644 index 00000000000..0b7b6329af2 --- /dev/null +++ b/test/onnx/convinteger_dual_bias_simple_test.onnx @@ -0,0 +1,34 @@ + !convinteger_dual_bias_simple_test:à +B +0 +1 +2 +34" ConvInteger* + dilations@@ * +strides@@ !convinteger_dual_bias_simple_testZ +0 + + + + +Z +1 + + + + +Z +2 + + +Z +3 + + +b +4 + + + + +B \ No newline at end of file diff --git a/test/onnx/convinteger_dual_bias_test.onnx b/test/onnx/convinteger_dual_bias_test.onnx index 4b166872b7f..4ad76aa8ef5 100644 --- a/test/onnx/convinteger_dual_bias_test.onnx +++ b/test/onnx/convinteger_dual_bias_test.onnx @@ -8,16 +8,18 @@ B strides@@ convinteger_dual_bias_testZ 0  - +  - -Z + + + +Z 1  - +  - -Z + +Z 2  @@ -28,7 +30,7 @@ B b 4  - - +  -B \ No newline at end of file + +B \ No newline at end of file diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index b1158c3cd08..880aac4fed0 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -1670,10 +1670,10 @@ def convinteger_no_bias_uint8_test(): @onnx_test() def convinteger_bias_test(): - x = helper.make_tensor_value_info('0', TensorProto.INT8, [1, 3, 32, 32]) - y = helper.make_tensor_value_info('1', TensorProto.INT8, [1, 3, 5, 5]) + x = helper.make_tensor_value_info('0', TensorProto.INT8, [2, 3, 32, 32]) + y = helper.make_tensor_value_info('1', TensorProto.INT8, [4, 3, 5, 5]) z = helper.make_tensor_value_info('2', TensorProto.INT8, [1]) - out = helper.make_tensor_value_info('3', TensorProto.INT32, [1, 2, 28, 28]) + out = helper.make_tensor_value_info('3', TensorProto.INT32, [2, 4, 28, 28]) node = onnx.helper.make_node('ConvInteger', inputs=['0', '1', '2'], @@ -1686,6 +1686,23 @@ def convinteger_bias_test(): @onnx_test() def convinteger_dual_bias_test(): + x = helper.make_tensor_value_info('0', TensorProto.INT8, [2, 3, 10, 10]) + y = helper.make_tensor_value_info('1', TensorProto.INT8, [4, 3, 3, 3]) + z = helper.make_tensor_value_info('2', TensorProto.INT8, [1]) + w = helper.make_tensor_value_info('3', TensorProto.INT8, [1]) + out = helper.make_tensor_value_info('4', TensorProto.INT32, [2, 4, 8, 8]) + + node = onnx.helper.make_node('ConvInteger', + inputs=['0', '1', '2', '3'], + outputs=['4'], + dilations=[1, 1], + strides=[1, 1]) + + return ([node], [x, y, z, w], [out]) + + +@onnx_test() +def convinteger_dual_bias_simple_test(): x = helper.make_tensor_value_info('0', TensorProto.INT8, [1, 3, 5, 5]) y = helper.make_tensor_value_info('1', TensorProto.INT8, [1, 3, 2, 2]) z = helper.make_tensor_value_info('2', TensorProto.INT8, [1]) diff --git a/test/onnx/parse/convinteger_bias_test.cpp b/test/onnx/parse/convinteger_bias_test.cpp index 38d9c42412f..22534478b36 100644 --- a/test/onnx/parse/convinteger_bias_test.cpp +++ b/test/onnx/parse/convinteger_bias_test.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -28,24 +28,20 @@ TEST_CASE(convinteger_bias_test) { migraphx::program p; auto* mm = p.get_main_module(); - auto data = mm->add_parameter("0", {migraphx::shape::int8_type, {1, 3, 32, 32}}); - auto weights = mm->add_parameter("1", {migraphx::shape::int8_type, {1, 3, 5, 5}}); + auto data = mm->add_parameter("0", {migraphx::shape::int8_type, {2, 3, 32, 32}}); + auto weights = mm->add_parameter("1", {migraphx::shape::int8_type, {4, 3, 5, 5}}); auto data_bias = mm->add_parameter("2", {migraphx::shape::int8_type, {1}, {1}}); mm->add_literal(migraphx::literal{migraphx::shape{data->get_shape().type(), {1}, {0}}, {0}}); auto quant = mm->add_instruction(migraphx::make_op("quant_convolution"), data, weights); auto bcast_data_bias = mm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", weights->get_shape().lens()}}), - data_bias); + migraphx::make_op("multibroadcast", {{"out_lens", data->get_shape().lens()}}), data_bias); auto quant2 = mm->add_instruction(migraphx::make_op("quant_convolution"), bcast_data_bias, weights); - auto bcast_quant2 = mm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}), quant2); - - mm->add_instruction(migraphx::make_op("sub"), quant, bcast_quant2); + mm->add_instruction(migraphx::make_op("sub"), quant, quant2); auto prog = optimize_onnx("convinteger_bias_test.onnx"); EXPECT(p == prog); diff --git a/test/onnx/parse/convinteger_dual_bias_test.cpp b/test/onnx/parse/convinteger_dual_bias_test.cpp index 97445ff6ab9..fa554356093 100644 --- a/test/onnx/parse/convinteger_dual_bias_test.cpp +++ b/test/onnx/parse/convinteger_dual_bias_test.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -28,41 +28,37 @@ TEST_CASE(convinteger_dual_bias_test) { migraphx::program p; auto* mm = p.get_main_module(); - auto data = mm->add_parameter("0", {migraphx::shape::int8_type, {1, 3, 5, 5}}); - auto weight = mm->add_parameter("1", {migraphx::shape::int8_type, {1, 3, 2, 2}}); + auto data = mm->add_parameter("0", {migraphx::shape::int8_type, {2, 3, 10, 10}}); + auto weight = mm->add_parameter("1", {migraphx::shape::int8_type, {4, 3, 3, 3}}); auto data_bias = mm->add_parameter("2", {migraphx::shape::int8_type, {1}, {1}}); auto weight_bias = mm->add_parameter("3", {migraphx::shape::int8_type, {1}, {1}}); auto quant = mm->add_instruction(migraphx::make_op("quant_convolution"), data, weight); auto mbcast_data_bias = mm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", weight->get_shape().lens()}}), data_bias); + migraphx::make_op("multibroadcast", {{"out_lens", data->get_shape().lens()}}), data_bias); - auto quant_db_w = + auto quant_mb_w = mm->add_instruction(migraphx::make_op("quant_convolution"), mbcast_data_bias, weight); - auto quant_mb_w = mm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}), quant_db_w); - quant = mm->add_instruction(migraphx::make_op("sub"), quant, quant_mb_w); auto mbcast_weight_bias = mm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", data->get_shape().lens()}}), weight_bias); + migraphx::make_op("multibroadcast", {{"out_lens", weight->get_shape().lens()}}), + weight_bias); - auto quant_d_wb = + auto quant_md_wb = mm->add_instruction(migraphx::make_op("quant_convolution"), data, mbcast_weight_bias); - auto quant_md_wb = mm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}), quant_d_wb); - quant = mm->add_instruction(migraphx::make_op("sub"), quant, quant_md_wb); - auto bcast_data_bias = mm->add_instruction( - migraphx::make_op("broadcast", {{"out_lens", data->get_shape().lens()}}), data_bias); - auto bcast_weight_bias = mm->add_instruction( - migraphx::make_op("broadcast", {{"out_lens", weight->get_shape().lens()}}), weight_bias); + mbcast_data_bias = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", data->get_shape().lens()}}), data_bias); + mbcast_weight_bias = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", weight->get_shape().lens()}}), + weight_bias); auto bias_quant = mm->add_instruction( - migraphx::make_op("quant_convolution"), bcast_data_bias, bcast_weight_bias); + migraphx::make_op("quant_convolution"), mbcast_data_bias, mbcast_weight_bias); mm->add_instruction(migraphx::make_op("add"), quant, bias_quant); diff --git a/test/onnx/verify/quant_convolution_dual_bias_test.cpp b/test/onnx/verify/quant_convolution_dual_bias_test.cpp index 27a9518c4c1..6fa2f409241 100644 --- a/test/onnx/verify/quant_convolution_dual_bias_test.cpp +++ b/test/onnx/verify/quant_convolution_dual_bias_test.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -24,11 +24,15 @@ #include #include +#include +#include +#include #include TEST_CASE(quant_convolution_dual_zero_bias_test) { - migraphx::program p = read_onnx("convinteger_dual_bias_test.onnx"); + // TODO: use other dual_bias test, verify with other framework once convinteger supported + migraphx::program p = read_onnx("convinteger_dual_bias_simple_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape a{migraphx::shape::int8_type, {1, 3, 5, 5}}; @@ -82,7 +86,7 @@ TEST_CASE(quant_convolution_dual_zero_bias_test) TEST_CASE(quant_convolution_dual_non_zero_bias_test) { // github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearMul - migraphx::program p = read_onnx("convinteger_dual_bias_test.onnx"); + migraphx::program p = read_onnx("convinteger_dual_bias_simple_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape a{migraphx::shape::int8_type, {1, 3, 5, 5}}; @@ -113,22 +117,40 @@ TEST_CASE(quant_convolution_dual_non_zero_bias_test) std::vector result_vector; result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); - std::vector gold = {-6088, - 6248, - -6472, - 6632, - 6664, - -8264, - 8520, - -8713, - -3788, - -1446, - 1488, - -1586, - -712, - 745, - -914, - 1019}; + // create the following program to compare: + // conv(x-x_bias,w-w_bias) + // where datatypes for x,w,x_bias,w_bias are int32 + migraphx::program p2; + migraphx::module* mm = p2.get_main_module(); - EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); + migraphx::shape a_i32{migraphx::shape::int32_type, {1, 3, 5, 5}}; + migraphx::shape b_i32{migraphx::shape::int32_type, {1, 3, 2, 2}}; + + migraphx::shape bias_i32{migraphx::shape::int32_type, {1}, {1}}; + auto x = mm->add_parameter("0", a_i32); + auto weights = mm->add_parameter("1", b_i32); + auto x_bias = mm->add_parameter("2", bias_i32); + auto weights_bias = mm->add_parameter("3", bias_i32); + + auto sub_input = add_common_op(*mm, migraphx::make_op("sub"), {x, x_bias}); + auto sub_weights = add_common_op(*mm, migraphx::make_op("sub"), {weights, weights_bias}); + mm->add_instruction(migraphx::make_op("convolution"), sub_input, sub_weights); + + std::vector data_a_i32(data_a.begin(), data_a.end()); + std::vector data_b_i32(data_b.begin(), data_b.end()); + std::vector data_a_bias_i32 = {10}; + std::vector data_b_bias_i32 = {-2}; + + migraphx::parameter_map pp2; + pp2["0"] = migraphx::argument(a_i32, data_a_i32.data()); + pp2["1"] = migraphx::argument(b_i32, data_b_i32.data()); + pp2["2"] = migraphx::argument(bias_i32, data_a_bias_i32.data()); + pp2["3"] = migraphx::argument(bias_i32, data_b_bias_i32.data()); + + auto result2 = p2.eval(pp2).back(); + + std::vector result_vector_i32; + result2.visit([&](auto output) { result_vector_i32.assign(output.begin(), output.end()); }); + + EXPECT(migraphx::verify::verify_rms_range(result_vector, result_vector_i32)); } diff --git a/test/onnx/verify/quant_convolution_mismatched_input_dual_bias_test.cpp b/test/onnx/verify/quant_convolution_mismatched_input_dual_bias_test.cpp index 127286001ad..5876fd0d1d4 100644 --- a/test/onnx/verify/quant_convolution_mismatched_input_dual_bias_test.cpp +++ b/test/onnx/verify/quant_convolution_mismatched_input_dual_bias_test.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -112,22 +112,40 @@ TEST_CASE(quant_convolution_mismatched_inputs_dual_non_zero_bias_test) std::vector result_vector; result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); - std::vector gold = {-6088, - 6248, - -6472, - 6632, - 6664, - -8264, - 8520, - -8713, - -3788, - -1446, - 1488, - -1586, - -712, - 745, - -914, - 1019}; + // create the following program to compare: + // conv(x-x_bias,w-w_bias) + // where datatypes for x,w,x_bias,w_bias are int32 + migraphx::program p2; + migraphx::module* mm = p2.get_main_module(); - EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); + migraphx::shape a_i32{migraphx::shape::int32_type, {1, 3, 5, 5}}; + migraphx::shape b_i32{migraphx::shape::int32_type, {1, 3, 2, 2}}; + + migraphx::shape bias_i32{migraphx::shape::int32_type, {1}, {1}}; + auto x = mm->add_parameter("0", a_i32); + auto weights = mm->add_parameter("1", b_i32); + auto x_bias = mm->add_parameter("2", bias_i32); + auto weights_bias = mm->add_parameter("3", bias_i32); + + auto sub_input = add_common_op(*mm, migraphx::make_op("sub"), {x, x_bias}); + auto sub_weights = add_common_op(*mm, migraphx::make_op("sub"), {weights, weights_bias}); + mm->add_instruction(migraphx::make_op("convolution"), sub_input, sub_weights); + + std::vector data_a_i32(data_a.begin(), data_a.end()); + std::vector data_b_i32(data_b.begin(), data_b.end()); + std::vector data_a_bias_i32 = {138}; + std::vector data_b_bias_i32 = {-2}; + + migraphx::parameter_map pp2; + pp2["0"] = migraphx::argument(a_i32, data_a_i32.data()); + pp2["1"] = migraphx::argument(b_i32, data_b_i32.data()); + pp2["2"] = migraphx::argument(bias_i32, data_a_bias_i32.data()); + pp2["3"] = migraphx::argument(bias_i32, data_b_bias_i32.data()); + + auto result2 = p2.eval(pp2).back(); + + std::vector result_vector_i32; + result2.visit([&](auto output) { result_vector_i32.assign(output.begin(), output.end()); }); + + EXPECT(migraphx::verify::verify_rms_range(result_vector, result_vector_i32)); }