From 2bb30f5ab92b75813ac920dd4d08d9eaec16e59f Mon Sep 17 00:00:00 2001 From: "Fedorov, Andrey" Date: Mon, 30 Sep 2024 07:56:35 -0700 Subject: [PATCH] applied feedback --- .../device/include/rng_device_test_common.hpp | 16 ++-- .../unit_tests/rng/device/moments/moments.cpp | 90 +++++++++++-------- 2 files changed, 59 insertions(+), 47 deletions(-) diff --git a/tests/unit_tests/rng/device/include/rng_device_test_common.hpp b/tests/unit_tests/rng/device/include/rng_device_test_common.hpp index aa2a54e09..33533255e 100644 --- a/tests/unit_tests/rng/device/include/rng_device_test_common.hpp +++ b/tests/unit_tests/rng/device/include/rng_device_test_common.hpp @@ -191,8 +191,8 @@ struct statistics_device> { template struct statistics_device> { template - bool check(const std::vector& r, - const oneapi::mkl::rng::device::uniform& distr) { + bool check(const std::vector& r, + const oneapi::mkl::rng::device::uniform& distr) { double tM, tD, tQ; double a = distr.a(); double b = distr.b(); @@ -210,8 +210,8 @@ struct statistics_device template struct statistics_device> { template - bool check(const std::vector& r, - const oneapi::mkl::rng::device::uniform& distr) { + bool check(const std::vector& r, + const oneapi::mkl::rng::device::uniform& distr) { double tM, tD, tQ; double a = distr.a(); double b = distr.b(); @@ -229,8 +229,8 @@ struct statistics_device struct statistics_device> { template - bool check(const std::vector& r, - const oneapi::mkl::rng::device::uniform& distr) { + bool check(const std::vector& r, + const oneapi::mkl::rng::device::uniform& distr) { double tM, tD, tQ; double a = distr.a(); double b = distr.b(); @@ -248,8 +248,8 @@ struct statistics_device template struct statistics_device> { template - bool check(const std::vector& r, - const oneapi::mkl::rng::device::uniform& distr) { + bool check(const std::vector& r, + const oneapi::mkl::rng::device::uniform& distr) { double tM, tD, tQ; double a = distr.a(); double b = distr.b(); diff --git a/tests/unit_tests/rng/device/moments/moments.cpp b/tests/unit_tests/rng/device/moments/moments.cpp index 5fc14b7fd..8e5e55239 100644 --- a/tests/unit_tests/rng/device/moments/moments.cpp +++ b/tests/unit_tests/rng/device/moments/moments.cpp @@ -1055,9 +1055,7 @@ class Philox4x32x10BetaCjaDeviceMomentsTests class Philox4x32x10BetaCjaAccDeviceMomentsTests : public ::testing::TestWithParam {}; -// implementation uses double precision for accuracy -TEST_P(Philox4x32x10BetaCjaDeviceMomentsTests, RealDoublePrecision) { - CHECK_DOUBLE_ON_DEVICE(GetParam()); +TEST_P(Philox4x32x10BetaCjaDeviceMomentsTests, RealSinglePrecision) { rng_device_test, oneapi::mkl::rng::device::beta< @@ -1074,26 +1072,29 @@ TEST_P(Philox4x32x10BetaCjaDeviceMomentsTests, RealDoublePrecision) { float, oneapi::mkl::rng::device::beta_method::cja>>> test3; EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10BetaCjaDeviceMomentsTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + rng_device_test, oneapi::mkl::rng::device::beta< double, oneapi::mkl::rng::device::beta_method::cja>>> - test4; - EXPECT_TRUEORSKIP((test4(GetParam()))); + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); rng_device_test, oneapi::mkl::rng::device::beta< double, oneapi::mkl::rng::device::beta_method::cja>>> - test5; - EXPECT_TRUEORSKIP((test5(GetParam()))); + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); rng_device_test, oneapi::mkl::rng::device::beta< double, oneapi::mkl::rng::device::beta_method::cja>>> - test6; - EXPECT_TRUEORSKIP((test6(GetParam()))); + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); } -// implementation uses double precision for accuracy -TEST_P(Philox4x32x10BetaCjaAccDeviceMomentsTests, RealDoublePrecision) { - CHECK_DOUBLE_ON_DEVICE(GetParam()); +TEST_P(Philox4x32x10BetaCjaAccDeviceMomentsTests, RealSinglePrecision) { rng_device_test< moments_test, @@ -1113,24 +1114,29 @@ TEST_P(Philox4x32x10BetaCjaAccDeviceMomentsTests, RealDoublePrecision) { float, oneapi::mkl::rng::device::beta_method::cja_accurate>>> test3; EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10BetaCjaAccDeviceMomentsTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + rng_device_test< moments_test, oneapi::mkl::rng::device::beta< double, oneapi::mkl::rng::device::beta_method::cja_accurate>>> - test4; - EXPECT_TRUEORSKIP((test4(GetParam()))); + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); rng_device_test< moments_test, oneapi::mkl::rng::device::beta< double, oneapi::mkl::rng::device::beta_method::cja_accurate>>> - test5; - EXPECT_TRUEORSKIP((test5(GetParam()))); + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); rng_device_test< moments_test, oneapi::mkl::rng::device::beta< double, oneapi::mkl::rng::device::beta_method::cja_accurate>>> - test6; - EXPECT_TRUEORSKIP((test6(GetParam()))); + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); } INSTANTIATE_TEST_SUITE_P(Philox4x32x10BetaCjaDeviceMomentsTestsSuite, @@ -1147,9 +1153,7 @@ class Philox4x32x10GammaMarsagliaDeviceMomentsTests class Philox4x32x10GammaMarsagliaAccDeviceMomentsTests : public ::testing::TestWithParam {}; -// implementation uses double precision for accuracy -TEST_P(Philox4x32x10GammaMarsagliaDeviceMomentsTests, RealDoublePrecision) { - CHECK_DOUBLE_ON_DEVICE(GetParam()); +TEST_P(Philox4x32x10GammaMarsagliaDeviceMomentsTests, RealSinglePrecision) { rng_device_test, oneapi::mkl::rng::device::gamma< @@ -1166,26 +1170,29 @@ TEST_P(Philox4x32x10GammaMarsagliaDeviceMomentsTests, RealDoublePrecision) { float, oneapi::mkl::rng::device::gamma_method::marsaglia>>> test3; EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10GammaMarsagliaDeviceMomentsTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + rng_device_test, oneapi::mkl::rng::device::gamma< double, oneapi::mkl::rng::device::gamma_method::marsaglia>>> - test4; - EXPECT_TRUEORSKIP((test4(GetParam()))); + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); rng_device_test, oneapi::mkl::rng::device::gamma< double, oneapi::mkl::rng::device::gamma_method::marsaglia>>> - test5; - EXPECT_TRUEORSKIP((test5(GetParam()))); + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); rng_device_test, oneapi::mkl::rng::device::gamma< double, oneapi::mkl::rng::device::gamma_method::marsaglia>>> - test6; - EXPECT_TRUEORSKIP((test6(GetParam()))); + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); } -// implementation uses double precision for accuracy -TEST_P(Philox4x32x10GammaMarsagliaAccDeviceMomentsTests, RealDoublePrecision) { - CHECK_DOUBLE_ON_DEVICE(GetParam()); +TEST_P(Philox4x32x10GammaMarsagliaAccDeviceMomentsTests, RealSinglePrecision) { rng_device_test< moments_test, @@ -1205,24 +1212,29 @@ TEST_P(Philox4x32x10GammaMarsagliaAccDeviceMomentsTests, RealDoublePrecision) { float, oneapi::mkl::rng::device::gamma_method::marsaglia_accurate>>> test3; EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10GammaMarsagliaAccDeviceMomentsTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + rng_device_test< moments_test, oneapi::mkl::rng::device::gamma< double, oneapi::mkl::rng::device::gamma_method::marsaglia_accurate>>> - test4; - EXPECT_TRUEORSKIP((test4(GetParam()))); + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); rng_device_test< moments_test, oneapi::mkl::rng::device::gamma< double, oneapi::mkl::rng::device::gamma_method::marsaglia_accurate>>> - test5; - EXPECT_TRUEORSKIP((test5(GetParam()))); + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); rng_device_test< moments_test, oneapi::mkl::rng::device::gamma< double, oneapi::mkl::rng::device::gamma_method::marsaglia_accurate>>> - test6; - EXPECT_TRUEORSKIP((test6(GetParam()))); + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); } INSTANTIATE_TEST_SUITE_P(Philox4x32x10GammaMarsagliaDeviceMomentsTestsSuite, @@ -1257,17 +1269,17 @@ TEST_P(Philox4x32x10PoissonDevroyeDeviceMomentsTests, IntegerPrecision) { TEST_P(Philox4x32x10PoissonDevroyeDeviceMomentsTests, UnsignedIntegerPrecision) { rng_device_test, oneapi::mkl::rng::device::poisson< - uint32_t, oneapi::mkl::rng::device::poisson_method::devroye>>> + std::uint32_t, oneapi::mkl::rng::device::poisson_method::devroye>>> test1; EXPECT_TRUEORSKIP((test1(GetParam()))); rng_device_test, oneapi::mkl::rng::device::poisson< - uint32_t, oneapi::mkl::rng::device::poisson_method::devroye>>> + std::uint32_t, oneapi::mkl::rng::device::poisson_method::devroye>>> test2; EXPECT_TRUEORSKIP((test2(GetParam()))); rng_device_test, oneapi::mkl::rng::device::poisson< - uint32_t, oneapi::mkl::rng::device::poisson_method::devroye>>> + std::uint32_t, oneapi::mkl::rng::device::poisson_method::devroye>>> test3; EXPECT_TRUEORSKIP((test3(GetParam()))); }