From 783f000955dd994c9cf42b11945594386ee71493 Mon Sep 17 00:00:00 2001 From: "romain.biessy" Date: Wed, 2 Oct 2024 09:50:30 +0100 Subject: [PATCH] Test group API with group_count=1 --- .../unit_tests/blas/batch/gemm_batch_usm.cpp | 55 +++++++++++++------ 1 file changed, 38 insertions(+), 17 deletions(-) diff --git a/tests/unit_tests/blas/batch/gemm_batch_usm.cpp b/tests/unit_tests/blas/batch/gemm_batch_usm.cpp index a651f9ae3..c697e644a 100644 --- a/tests/unit_tests/blas/batch/gemm_batch_usm.cpp +++ b/tests/unit_tests/blas/batch/gemm_batch_usm.cpp @@ -364,58 +364,79 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { } class GemmBatchUsmTests - : public ::testing::TestWithParam> {}; + : public ::testing::TestWithParam> {}; TEST_P(GemmBatchUsmTests, RealHalfPrecision) { + int group_count = std::get<2>(GetParam()); EXPECT_TRUEORSKIP((test( - std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); + std::get<0>(GetParam()), std::get<1>(GetParam()), group_count))); } TEST_P(GemmBatchUsmTests, HalfHalfFloatPrecision) { - EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), - std::get<1>(GetParam()), 5))); + int group_count = std::get<2>(GetParam()); + EXPECT_TRUEORSKIP((test( + std::get<0>(GetParam()), std::get<1>(GetParam()), group_count))); } TEST_P(GemmBatchUsmTests, Int8Int8SinglePrecision) { - EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), - std::get<1>(GetParam()), 5))); + int group_count = std::get<2>(GetParam()); + EXPECT_TRUEORSKIP((test( + std::get<0>(GetParam()), std::get<1>(GetParam()), group_count))); } TEST_P(GemmBatchUsmTests, Int8Int8Int32Precision) { + int group_count = std::get<2>(GetParam()); EXPECT_TRUEORSKIP((test( - std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); + std::get<0>(GetParam()), std::get<1>(GetParam()), group_count))); } TEST_P(GemmBatchUsmTests, RealSinglePrecision) { - EXPECT_TRUEORSKIP( - (test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); + int group_count = std::get<2>(GetParam()); + EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), + std::get<1>(GetParam()), group_count))); } TEST_P(GemmBatchUsmTests, RealDoublePrecision) { CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); - EXPECT_TRUEORSKIP(( - test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); + int group_count = std::get<2>(GetParam()); + EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), + std::get<1>(GetParam()), group_count))); } TEST_P(GemmBatchUsmTests, ComplexSinglePrecision) { + int group_count = std::get<2>(GetParam()); EXPECT_TRUEORSKIP( (test, std::complex, std::complex, std::complex>( - std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); + std::get<0>(GetParam()), std::get<1>(GetParam()), group_count))); } TEST_P(GemmBatchUsmTests, ComplexDoublePrecision) { CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); - EXPECT_TRUEORSKIP( - (test, std::complex, std::complex, - std::complex>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); + int group_count = std::get<2>(GetParam()); + EXPECT_TRUEORSKIP(( + test, std::complex, std::complex, + std::complex>(std::get<0>(GetParam()), std::get<1>(GetParam()), group_count))); } +class GemmBatchGroupNamePrint { +public: + std::string operator()( + testing::TestParamInfo> params) const { + std::string base_name = LayoutDeviceNamePrint()( + { { std::get<0>(params.param), std::get<1>(params.param) }, 0 }); + std::string group_name = "GroupCount_" + std::to_string(std::get<2>(params.param)); + std::string info_name = base_name + "_" + group_name; + return info_name; + } +}; + INSTANTIATE_TEST_SUITE_P(GemmBatchUsmTestSuite, GemmBatchUsmTests, ::testing::Combine(testing::ValuesIn(devices), testing::Values(oneapi::mkl::layout::col_major, - oneapi::mkl::layout::row_major)), - ::LayoutDeviceNamePrint()); + oneapi::mkl::layout::row_major), + testing::Values(1, 5)), + ::GemmBatchGroupNamePrint()); } // anonymous namespace