Skip to content

Commit

Permalink
Test group API with group_count=1
Browse files Browse the repository at this point in the history
  • Loading branch information
Rbiessy committed Oct 2, 2024
1 parent e0e23d0 commit 783f000
Showing 1 changed file with 38 additions and 17 deletions.
55 changes: 38 additions & 17 deletions tests/unit_tests/blas/batch/gemm_batch_usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,58 +364,79 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) {
}

class GemmBatchUsmTests
: public ::testing::TestWithParam<std::tuple<sycl::device *, oneapi::mkl::layout>> {};
: public ::testing::TestWithParam<std::tuple<sycl::device *, oneapi::mkl::layout, int>> {};

TEST_P(GemmBatchUsmTests, RealHalfPrecision) {
int group_count = std::get<2>(GetParam());
EXPECT_TRUEORSKIP((test<sycl::half, sycl::half, sycl::half, sycl::half>(
std::get<0>(GetParam()), std::get<1>(GetParam()), 5)));
std::get<0>(GetParam()), std::get<1>(GetParam()), group_count)));
}

TEST_P(GemmBatchUsmTests, HalfHalfFloatPrecision) {
EXPECT_TRUEORSKIP((test<sycl::half, sycl::half, float, float>(std::get<0>(GetParam()),
std::get<1>(GetParam()), 5)));
int group_count = std::get<2>(GetParam());
EXPECT_TRUEORSKIP((test<sycl::half, sycl::half, float, float>(
std::get<0>(GetParam()), std::get<1>(GetParam()), group_count)));
}

TEST_P(GemmBatchUsmTests, Int8Int8SinglePrecision) {
EXPECT_TRUEORSKIP((test<std::int8_t, std::int8_t, float, float>(std::get<0>(GetParam()),
std::get<1>(GetParam()), 5)));
int group_count = std::get<2>(GetParam());
EXPECT_TRUEORSKIP((test<std::int8_t, std::int8_t, float, float>(
std::get<0>(GetParam()), std::get<1>(GetParam()), group_count)));
}

TEST_P(GemmBatchUsmTests, Int8Int8Int32Precision) {
int group_count = std::get<2>(GetParam());
EXPECT_TRUEORSKIP((test<std::int8_t, std::int8_t, std::int32_t, float>(
std::get<0>(GetParam()), std::get<1>(GetParam()), 5)));
std::get<0>(GetParam()), std::get<1>(GetParam()), group_count)));
}

TEST_P(GemmBatchUsmTests, RealSinglePrecision) {
EXPECT_TRUEORSKIP(
(test<float, float, float, float>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)));
int group_count = std::get<2>(GetParam());
EXPECT_TRUEORSKIP((test<float, float, float, float>(std::get<0>(GetParam()),
std::get<1>(GetParam()), group_count)));
}

TEST_P(GemmBatchUsmTests, RealDoublePrecision) {
CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam()));

EXPECT_TRUEORSKIP((
test<double, double, double, double>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)));
int group_count = std::get<2>(GetParam());
EXPECT_TRUEORSKIP((test<double, double, double, double>(std::get<0>(GetParam()),
std::get<1>(GetParam()), group_count)));
}

TEST_P(GemmBatchUsmTests, ComplexSinglePrecision) {
int group_count = std::get<2>(GetParam());
EXPECT_TRUEORSKIP(
(test<std::complex<float>, std::complex<float>, std::complex<float>, std::complex<float>>(
std::get<0>(GetParam()), std::get<1>(GetParam()), 5)));
std::get<0>(GetParam()), std::get<1>(GetParam()), group_count)));
}

TEST_P(GemmBatchUsmTests, ComplexDoublePrecision) {
CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam()));

EXPECT_TRUEORSKIP(
(test<std::complex<double>, std::complex<double>, std::complex<double>,
std::complex<double>>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)));
int group_count = std::get<2>(GetParam());
EXPECT_TRUEORSKIP((
test<std::complex<double>, std::complex<double>, std::complex<double>,
std::complex<double>>(std::get<0>(GetParam()), std::get<1>(GetParam()), group_count)));
}

class GemmBatchGroupNamePrint {
public:
std::string operator()(
testing::TestParamInfo<std::tuple<sycl::device *, oneapi::mkl::layout, int>> params) const {
std::string base_name = LayoutDeviceNamePrint()(
{ { std::get<0>(params.param), std::get<1>(params.param) }, 0 });
std::string group_name = "GroupCount_" + std::to_string(std::get<2>(params.param));
std::string info_name = base_name + "_" + group_name;
return info_name;
}
};

INSTANTIATE_TEST_SUITE_P(GemmBatchUsmTestSuite, GemmBatchUsmTests,
::testing::Combine(testing::ValuesIn(devices),
testing::Values(oneapi::mkl::layout::col_major,
oneapi::mkl::layout::row_major)),
::LayoutDeviceNamePrint());
oneapi::mkl::layout::row_major),
testing::Values(1, 5)),
::GemmBatchGroupNamePrint());

} // anonymous namespace

0 comments on commit 783f000

Please sign in to comment.