From 9c27c6f0cad469e208377507ebfaf2be861e834b Mon Sep 17 00:00:00 2001 From: Alexandra Sidorova Date: Wed, 21 Aug 2024 15:27:09 +0400 Subject: [PATCH] [Snippets][CPU] Added test --- .../snippets/src/pass/mha_tokenization.cpp | 1 - .../shared_tests_instances/snippets/mha.cpp | 60 +++++++++++++++++- .../plugin/shared/include/snippets/mha.hpp | 2 + .../plugin/shared/src/snippets/mha.cpp | 28 +++------ .../include/subgraph_mha.hpp | 10 +-- .../ov_snippets_models/src/subgraph_mha.cpp | 63 ++++++++++++------- 6 files changed, 115 insertions(+), 49 deletions(-) diff --git a/src/common/snippets/src/pass/mha_tokenization.cpp b/src/common/snippets/src/pass/mha_tokenization.cpp index c07adc2ea55502..c42eb08b82bd4a 100644 --- a/src/common/snippets/src/pass/mha_tokenization.cpp +++ b/src/common/snippets/src/pass/mha_tokenization.cpp @@ -592,7 +592,6 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets(const SnippetsToken // mark the Subgraph as Completed to not allow Snippets to include any nodes into the MHA Subgraph in common Tokenization SetSnippetsSubgraphType(subgraph, SnippetsSubgraphType::Completed); - std::cout << "tokenized\n"; return true; diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp index 19f1b230cd2c58..d676fd93ff711b 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp @@ -55,6 +55,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA_4D, ::testing::ValuesIn(precision_f32(4)), ::testing::Values(ov::element::f32), ::testing::ValuesIn({false, true}), + ::testing::Values(true), ::testing::Values(MHA::default_thread_count), ::testing::Values(1), ::testing::Values(1), @@ -82,7 +83,40 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_DynMHA_4D, ::testing::Combine(::testing::ValuesIn(inputShapes_4D_dynamic), ::testing::ValuesIn(precision_f32(4)), ::testing::Values(ov::element::f32), - ::testing::ValuesIn({false}), + ::testing::Values(false), + ::testing::Values(true), + ::testing::Values(MHA::default_thread_count), + ::testing::Values(1), + ::testing::Values(1), + ::testing::Values(ov::test::utils::DEVICE_CPU), + ::testing::Values(CPUTestUtils::empty_plugin_config)), + MHA::getTestCaseName); + +std::vector> inputShapes_4D_dynamic_with_mul{ + { + {PartialShape{-1, -1, -1, -1}, {{1, 128, 3, 64}, {1, 70, 3, 19}, {1, 128, 3, 64}, {1, 68, 6, 87}}}, + {PartialShape{-1, -1, -1, -1}, {{1, 128, 1, 64}, {2, 49, 1, 19}, {1, 128, 1, 64}, {2, 13, 6, 87}}}, + {PartialShape{1}, {{1}, {1}, {1}, {1} }}, + {PartialShape{-1, -1, -1, -1}, {{2, 1, 128, 128}, {1, 1, 70, 49}, {2, 1, 128, 128}, {1, 1, 68, 13}}}, + {PartialShape{-1, -1, -1, -1}, {{1, 128, 3, 64}, {1, 49, 3, 19}, {1, 128, 3, 64}, {2, 13, 6, 87}}}, + }, + { + {PartialShape{-1, -1, 12, 64}, {{1, 70, 12, 64}, {1, 20, 12, 64}, {1, 20, 12, 64}, {1, 20, 12, 64}, {1, 70, 12, 64}}}, + {PartialShape{-1, -1, 12, 64}, {{1, 35, 12, 64}, {2, 10, 12, 64}, {2, 1, 12, 64}, {2, 10, 12, 64}, {1, 35, 12, 64}}}, + {PartialShape{1}, {{1}, {1}, {1}, {1}, {1}}}, + {PartialShape{-1, 12, -1, -1}, {{2, 12, 70, 35}, {1, 12, 20, 10}, {1, 12, 20, 10}, {1, 12, 20, 1}, {2, 12, 70, 35}}}, + {PartialShape{-1, -1, 12, 64}, {{1, 35, 12, 64}, {1, 10, 12, 64}, {1, 10, 12, 64}, {1, 10, 12, 64}, {1, 35, 12, 64}}}, + } +}; + + +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_DynMHA_4D_Wil_Dynamic_Mul, + MHA, + ::testing::Combine(::testing::ValuesIn(inputShapes_4D_dynamic_with_mul), + ::testing::ValuesIn(precision_f32(5)), + ::testing::Values(ov::element::f32), + ::testing::Values(true), + ::testing::Values(false), ::testing::Values(MHA::default_thread_count), ::testing::Values(1), ::testing::Values(1), @@ -97,6 +131,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA_3D, ::testing::ValuesIn(precision_f32(4)), ::testing::Values(ov::element::f32), ::testing::ValuesIn({false, true}), + ::testing::Values(true), ::testing::Values(MHA::default_thread_count), ::testing::Values(5), // [122706]: Subgraph + 4 Transpose ::testing::Values(2), // decomposed Transpose + MHA @@ -113,6 +148,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(precision_f32(4)), ::testing::Values(ov::element::f32), ::testing::Values(true), + ::testing::Values(true), ::testing::Values(4), // 4 Threads ::testing::Values(6), // Subgraph + 4 Reshapes on inputs and 1 Reshape on output ::testing::Values(1), @@ -128,6 +164,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(precision_f32(4)), ::testing::Values(ov::element::f32), ::testing::Values(true), + ::testing::Values(true), ::testing::Values(4), // 4 Threads ::testing::Values(10), // Subgraph + 4 Reshapes on inputs and 1 Reshape on output + 4 Transposes ::testing::Values(1), // MHA @@ -169,6 +206,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(precision_f32(4)), ::testing::Values(ov::element::f32), ::testing::Values(false), + ::testing::Values(true), ::testing::Values(4), // 4 Threads ::testing::Values(1), ::testing::Values(1), @@ -198,6 +236,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(precision_f32(4)), ::testing::Values(ov::element::f32), ::testing::Values(false), + ::testing::Values(true), ::testing::Values(4), // 4 Threads ::testing::Values(5), // Subgraph + 4 Transpose ::testing::Values(2), // MHA + one of the transposes is executed via Subgraph (because callback is disabled) @@ -211,6 +250,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHABF16_4D, ::testing::ValuesIn(precision_bf16(4)), ::testing::Values(ov::element::f32), ::testing::ValuesIn({false, true}), + ::testing::Values(true), ::testing::Values(MHA::default_thread_count), ::testing::Values(7), // MHA + 5 Converts + 1 Transpose on output ::testing::Values(6), // MHA + 5 Converts on inputs and output @@ -224,6 +264,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceBF16, ::testing::ValuesIn(precision_f32(4)), ::testing::Values(ov::element::bf16), ::testing::ValuesIn({false}), + ::testing::Values(true), ::testing::Values(MHA::default_thread_count), ::testing::Values(7), ::testing::Values(6), @@ -239,6 +280,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(precision_f32(3)), ::testing::Values(ov::element::f32), ::testing::ValuesIn({false}), // Need to support True for graph builder in tests + ::testing::Values(true), ::testing::Values(MHA::default_thread_count), ::testing::Values(1), ::testing::Values(1), @@ -262,6 +304,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(precision_f32(6)), ::testing::Values(ov::element::f32), ::testing::Values(false), // Need to support True for graph builder in tests + ::testing::Values(true), ::testing::Values(MHA::default_thread_count), ::testing::Values(2), // Less + MHA ::testing::Values(2), @@ -283,6 +326,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(std::vector{}), ::testing::Values(ov::element::f32), ::testing::Values(true), // Need to support False for graph builder in tests + ::testing::Values(true), ::testing::Values(MHA::default_thread_count), ::testing::Values(1), ::testing::Values(1), @@ -297,6 +341,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(precision_f32(3)), ::testing::Values(ov::element::f32), ::testing::ValuesIn({true}), // Need to support False for graph builder in tests + ::testing::Values(true), ::testing::Values(MHA::default_thread_count), ::testing::Values(1), ::testing::Values(1), @@ -311,6 +356,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(precision_f32(3)), ::testing::Values(ov::element::f32), ::testing::ValuesIn({true}), // Need to support False for graph builder in tests + ::testing::Values(true), ::testing::Values(MHA::default_thread_count), ::testing::Values(1), ::testing::Values(1), @@ -340,6 +386,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(precision_f32(3)), ::testing::Values(ov::element::f32), ::testing::ValuesIn({true}), // Need to support False for graph builder in tests + ::testing::Values(true), ::testing::Values(MHA::default_thread_count), ::testing::Values(1), ::testing::Values(1), @@ -354,6 +401,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(precision_bf16(3)), ::testing::Values(ov::element::f32), ::testing::ValuesIn({true}), // Need to support False for graph builder in tests + ::testing::Values(true), ::testing::Values(MHA::default_thread_count), ::testing::Values(5), // MHA + 4 extra Converts on inputs and output ::testing::Values(5), // MHA + 4 extra Converts on inputs and output @@ -368,6 +416,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(precision_bf16(3)), ::testing::Values(ov::element::f32), ::testing::ValuesIn({true}), // Need to support False for graph builder in tests + ::testing::Values(true), ::testing::Values(MHA::default_thread_count), ::testing::Values(5), // MHA + 4 extra Converts on inputs and output ::testing::Values(5), // MHA + 4 extra Converts on inputs and output @@ -382,6 +431,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(precision_f32(3)), ::testing::Values(ov::element::bf16), ::testing::ValuesIn({true}), // Need to support False for graph builder in tests + ::testing::Values(true), ::testing::Values(MHA::default_thread_count), ::testing::Values(5), // MHA + 4 extra Converts on inputs and output ::testing::Values(5), // MHA + 4 extra Converts on inputs and output @@ -396,6 +446,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(precision_f32(3)), ::testing::Values(ov::element::bf16), ::testing::ValuesIn({true}), // Need to support False for graph builder in tests + ::testing::Values(true), ::testing::Values(MHA::default_thread_count), ::testing::Values(5), // MHA + 4 extra Converts on inputs and output ::testing::Values(5), // MHA + 4 extra Converts on inputs and output @@ -411,6 +462,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(std::vector{}), ::testing::Values(ov::element::f32), ::testing::Values(false), // The graph doesn't contain Multiply + ::testing::Values(true), ::testing::Values(MHA::default_thread_count), ::testing::Values(6), // FQx3 on inputs + MHA + Transpose on output + Deq Mul ::testing::Values(5), // FQx3 on inputs + MHA + Deq Mul @@ -426,6 +478,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(std::vector{}), ::testing::Values(ov::element::f32), ::testing::Values(false), // The graph doesn't contain Multiply + ::testing::Values(true), ::testing::Values(MHA::default_thread_count), ::testing::Values(9), // FQx2 on inputs + MHA + Transpose on output + 4 Reshapes + Deq Mul ::testing::Values(4), // FQx2 on inputs + MHA + Deq Mul @@ -439,6 +492,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAFQAfterMatMul_4D, ::testing::Values(std::vector{}), ::testing::Values(ov::element::f32), ::testing::Values(false), // The graph doesn't contain Multiply + ::testing::Values(true), ::testing::Values(MHA::default_thread_count), ::testing::Values(3), // MHA + Transpose on output + Deq Mul ::testing::Values(2), // MHA + Deq Mul @@ -456,6 +510,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(std::vector{}), ::testing::Values(ov::element::f32), ::testing::Values(false), // The graph doesn't contain Multiply + ::testing::Values(true), ::testing::Values(MHA::default_thread_count), ::testing::Values(7), // Transposex2 + Subgraphsx5 ::testing::Values(5), // MHA + Deq Mul on output + Deqs on inputs + 2 xFQ on inputs @@ -472,6 +527,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(std::vector{}), ::testing::Values(ov::element::f32), ::testing::ValuesIn({true}), // Need to support False for graph builder in tests + ::testing::Values(true), ::testing::Values(MHA::default_thread_count), ::testing::Values(1), ::testing::Values(1), @@ -498,6 +554,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(std::vector{}), ::testing::Values(ov::element::f32), ::testing::ValuesIn({true}), // Need to support False for graph builder in tests + ::testing::Values(true), ::testing::Values(MHA::default_thread_count), ::testing::Values(2), ::testing::Values(1), @@ -520,6 +577,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(std::vector{}), ::testing::Values(ov::element::f32), ::testing::ValuesIn({true}), // False is not supported for graph builder in tests + ::testing::Values(true), ::testing::Values(MHA::default_thread_count), ::testing::Values(3), // Extracted Add + Extracted Reshape + MHA ::testing::Values(2), // Extracted Add + MHA diff --git a/src/tests/functional/plugin/shared/include/snippets/mha.hpp b/src/tests/functional/plugin/shared/include/snippets/mha.hpp index f73dba5d4ad5ce..19dc94e821d6c1 100644 --- a/src/tests/functional/plugin/shared/include/snippets/mha.hpp +++ b/src/tests/functional/plugin/shared/include/snippets/mha.hpp @@ -15,6 +15,7 @@ typedef std::tuple, // Input shapes std::vector, // Input Element types ov::element::Type, // Inference precision bool, // With Multiply + bool, // True if second input of Mul is Const size_t, // Thread count size_t, // Expected num nodes size_t, // Expected num subgraphs @@ -38,6 +39,7 @@ class MHA : public testing::WithParamInterface, virtual std::shared_ptr get_subgraph(); bool m_with_mul = false; + bool m_is_mul_const = true; size_t m_thread_count; std::vector m_input_types; }; diff --git a/src/tests/functional/plugin/shared/src/snippets/mha.cpp b/src/tests/functional/plugin/shared/src/snippets/mha.cpp index 9b5cbe2bafaf43..1a76f1ac5b4f33 100644 --- a/src/tests/functional/plugin/shared/src/snippets/mha.cpp +++ b/src/tests/functional/plugin/shared/src/snippets/mha.cpp @@ -18,20 +18,13 @@ std::string MHA::getTestCaseName(testing::TestParamInfo input_shapes; std::vector elem_types; ov::element::Type prc; - bool with_mul; + bool with_mul, mul_is_const; size_t thread_count; std::string target_device; size_t num_nodes, num_subgraphs; ov::AnyMap additional_config; - std::tie(input_shapes, - elem_types, - prc, - with_mul, - thread_count, - num_nodes, - num_subgraphs, - target_device, - additional_config) = obj.param; + std::tie(input_shapes, elem_types, prc, with_mul, mul_is_const, + thread_count, num_nodes, num_subgraphs, target_device, additional_config) = obj.param; std::ostringstream result; for (size_t i = 0; i < input_shapes.size(); i++) @@ -39,6 +32,8 @@ std::string MHA::getTestCaseName(testing::TestParamInfo input_shapes; ov::element::Type prc; ov::AnyMap additional_config; - std::tie(input_shapes, - m_input_types, - prc, - m_with_mul, - m_thread_count, - ref_num_nodes, - ref_num_subgraphs, - targetDevice, - additional_config) = this->GetParam(); + std::tie(input_shapes, m_input_types, prc, m_with_mul, m_is_mul_const, m_thread_count, + ref_num_nodes, ref_num_subgraphs, targetDevice, additional_config) = this->GetParam(); init_input_shapes(input_shapes); const auto subgraph_model = get_subgraph(); @@ -109,7 +97,7 @@ void MHA::generate_inputs(const std::vector& targetInputStaticShapes) std::shared_ptr MHA::get_subgraph() { bool is_with_reshape = std::all_of(inputDynamicShapes.begin(), inputDynamicShapes.end(), [](const PartialShape& ps){ return ps.is_static(); }); - return std::make_shared(inputDynamicShapes, m_input_types, m_with_mul, is_with_reshape); + return std::make_shared(inputDynamicShapes, m_input_types, m_with_mul, is_with_reshape, m_is_mul_const); } void MHASelect::generate_inputs(const std::vector& targetInputStaticShapes) { diff --git a/src/tests/ov_helpers/ov_snippets_models/include/subgraph_mha.hpp b/src/tests/ov_helpers/ov_snippets_models/include/subgraph_mha.hpp index 60af12e27e5f48..91dbfc7e5ea3d5 100644 --- a/src/tests/ov_helpers/ov_snippets_models/include/subgraph_mha.hpp +++ b/src/tests/ov_helpers/ov_snippets_models/include/subgraph_mha.hpp @@ -43,10 +43,11 @@ namespace snippets { class MHAFunction : public SnippetsFunctionBase { public: explicit MHAFunction(const std::vector& inputShapes, const std::vector& precisions, - bool with_mul = true, bool with_reshape = true) - : SnippetsFunctionBase(inputShapes), with_mul(with_mul), with_reshape(with_reshape), precisions(precisions) { - OPENVINO_ASSERT(input_shapes.size() == 4, "Got invalid number of input shapes"); - OPENVINO_ASSERT(precisions.size() == 4, "Got invalid number of input precisions"); + bool with_mul = true, bool with_reshape = true, bool is_mul_const = true) + : SnippetsFunctionBase(inputShapes), with_mul(with_mul), with_reshape(with_reshape), is_mul_const(is_mul_const), precisions(precisions) { + const size_t count = with_mul && !is_mul_const ? 5 : 4; + OPENVINO_ASSERT(input_shapes.size() == count, "Got invalid number of input shapes"); + OPENVINO_ASSERT(precisions.size() == count, "Got invalid number of input precisions"); } protected: std::shared_ptr initOriginal() const override; @@ -54,6 +55,7 @@ class MHAFunction : public SnippetsFunctionBase { const bool with_mul = true; const bool with_reshape = true; + const bool is_mul_const = true; const std::vector precisions; }; diff --git a/src/tests/ov_helpers/ov_snippets_models/src/subgraph_mha.cpp b/src/tests/ov_helpers/ov_snippets_models/src/subgraph_mha.cpp index 3ca78918f5e925..26d125c0d33dd1 100644 --- a/src/tests/ov_helpers/ov_snippets_models/src/subgraph_mha.cpp +++ b/src/tests/ov_helpers/ov_snippets_models/src/subgraph_mha.cpp @@ -50,10 +50,11 @@ std::vector get_decomposed_order_after_split_m(size_t rank) { } // namespace std::shared_ptr MHAFunction::initOriginal() const { + const auto shift = !is_mul_const; auto transpose0Param = std::make_shared(precisions[0], input_shapes[0]); auto transpose1Param = std::make_shared(precisions[1], input_shapes[1]); - auto addParam = std::make_shared(precisions[2], input_shapes[2]); - auto transpose2Param = std::make_shared(precisions[3], input_shapes[3]); + auto addParam = std::make_shared(precisions[shift + 2], input_shapes[shift + 2]); + auto transpose2Param = std::make_shared(precisions[shift + 3], input_shapes[shift + 3]); ov::ParameterVector ngraphParam = {transpose0Param, transpose1Param, addParam, transpose2Param}; const auto rank = input_shapes[0].size(); @@ -69,12 +70,18 @@ std::shared_ptr MHAFunction::initOriginal() const { const auto transpose1 = std::make_shared(transpose1Param, transpose1Const); std::shared_ptr matmul_parent1 = transpose1; if (with_mul) { - ov::Shape shape(rank, 1); - if (transpose1->get_output_partial_shape(0).is_static()) { - shape[rank - 3] = transpose1->get_output_shape(0)[rank - 3]; + if (is_mul_const) { + ov::Shape shape(rank, 1); + if (transpose1->get_output_partial_shape(0).is_static()) { + shape[rank - 3] = transpose1->get_output_shape(0)[rank - 3]; + } + const auto mulConst = ov::test::utils::make_constant(precisions[1], shape); + matmul_parent1 = std::make_shared(transpose1, mulConst); + } else { + const auto mulParam = std::make_shared(precisions[2], input_shapes[2]); + matmul_parent1 = std::make_shared(transpose1, mulParam); + ngraphParam.insert(ngraphParam.cbegin() + 2, mulParam); } - const auto mulConst = ov::test::utils::make_constant(precisions[1], shape); - matmul_parent1 = std::make_shared(transpose1, mulConst); } const auto matMul0 = std::make_shared(transpose0, matmul_parent1); const auto add = std::make_shared(matMul0, addParam); @@ -105,17 +112,18 @@ std::shared_ptr MHAFunction::initOriginal() const { return std::make_shared(results, ngraphParam, "mha"); } std::shared_ptr MHAFunction::initReference() const { + const auto shift = !is_mul_const; auto data0 = std::make_shared(precisions[0], input_shapes[0]); auto data1 = std::make_shared(precisions[1], input_shapes[1]); - auto data2 = std::make_shared(precisions[2], input_shapes[2]); - auto data3 = std::make_shared(precisions[3], input_shapes[3]); + auto data2 = std::make_shared(precisions[shift + 2], input_shapes[shift + 2]); + auto data3 = std::make_shared(precisions[shift + 3], input_shapes[shift + 3]); ov::ParameterVector ngraphParams = {data0, data1, data2, data3}; NodeVector subgraph_inputs = {data0, data1, data2, data3}; auto transpose0Param = std::make_shared(precisions[0], input_shapes[0]); auto transpose1Param = std::make_shared(precisions[1], input_shapes[1]); - auto addParam = std::make_shared(precisions[2], input_shapes[2]); - auto transpose2Param = std::make_shared(precisions[3], input_shapes[3]); + auto addParam = std::make_shared(precisions[shift + 2], input_shapes[shift + 2]); + auto transpose2Param = std::make_shared(precisions[shift + 3], input_shapes[shift + 3]); ov::ParameterVector subgraph_params = {transpose0Param, transpose1Param, addParam, transpose2Param}; @@ -132,19 +140,28 @@ std::shared_ptr MHAFunction::initReference() const { const auto transpose1 = std::make_shared(transpose1Param, transpose1Const); std::shared_ptr matmul_parent1 = transpose1; if (with_mul) { - ov::Shape shape(rank, 1); - if (transpose1->get_output_partial_shape(0).is_static()) { - shape[rank - 3] = transpose1->get_output_shape(0)[rank - 3]; - } - const auto mulConst = ov::test::utils::make_constant(precisions[1], shape); - - if (ov::shape_size(shape) > 1) { - const auto mulParam = std::make_shared(precisions[1], mulConst->get_shape()); - matmul_parent1 = std::make_shared(transpose1, mulParam); - subgraph_params = {transpose0Param, transpose1Param, mulParam, addParam, transpose2Param}; - subgraph_inputs = {data0, data1, mulConst, data2, data3}; + if (is_mul_const) { + ov::Shape shape(rank, 1); + if (transpose1->get_output_partial_shape(0).is_static()) { + shape[rank - 3] = transpose1->get_output_shape(0)[rank - 3]; + } + const auto mulConst = ov::test::utils::make_constant(precisions[1], shape); + + if (ov::shape_size(shape) > 1) { + const auto mulParam = std::make_shared(precisions[1], mulConst->get_shape()); + matmul_parent1 = std::make_shared(transpose1, mulParam); + subgraph_params = {transpose0Param, transpose1Param, mulParam, addParam, transpose2Param}; + subgraph_inputs = {data0, data1, mulConst, data2, data3}; + } else { + matmul_parent1 = std::make_shared(transpose1, mulConst); + } } else { - matmul_parent1 = std::make_shared(transpose1, mulConst); + auto dataMul = std::make_shared(precisions[2], input_shapes[2]); + auto paramMul = std::make_shared(precisions[2], input_shapes[2]); + ngraphParams.insert(ngraphParams.begin() + 2, dataMul); + subgraph_inputs.insert(subgraph_inputs.begin() + 2, dataMul); + subgraph_params.insert(subgraph_params.begin() + 2, paramMul); + matmul_parent1 = std::make_shared(transpose1, paramMul); } }