Skip to content

Commit

Permalink
applied feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
andreyfe1 committed Sep 30, 2024
1 parent 21dd603 commit 2bb30f5
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 47 deletions.
16 changes: 8 additions & 8 deletions tests/unit_tests/rng/device/include/rng_device_test_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,8 @@ struct statistics_device<oneapi::mkl::rng::device::uniform<Fp, Method>> {
template <typename Method>
struct statistics_device<oneapi::mkl::rng::device::uniform<std::int32_t, Method>> {
template <typename AllocType>
bool check(const std::vector<int32_t, AllocType>& r,
const oneapi::mkl::rng::device::uniform<int32_t, Method>& distr) {
bool check(const std::vector<std::int32_t, AllocType>& r,
const oneapi::mkl::rng::device::uniform<std::int32_t, Method>& distr) {
double tM, tD, tQ;
double a = distr.a();
double b = distr.b();
Expand All @@ -210,8 +210,8 @@ struct statistics_device<oneapi::mkl::rng::device::uniform<std::int32_t, Method>
template <typename Method>
struct statistics_device<oneapi::mkl::rng::device::uniform<std::uint32_t, Method>> {
template <typename AllocType>
bool check(const std::vector<uint32_t, AllocType>& r,
const oneapi::mkl::rng::device::uniform<uint32_t, Method>& distr) {
bool check(const std::vector<std::uint32_t, AllocType>& r,
const oneapi::mkl::rng::device::uniform<std::uint32_t, Method>& distr) {
double tM, tD, tQ;
double a = distr.a();
double b = distr.b();
Expand All @@ -229,8 +229,8 @@ struct statistics_device<oneapi::mkl::rng::device::uniform<std::uint32_t, Method
template <typename Method>
struct statistics_device<oneapi::mkl::rng::device::uniform<std::int64_t, Method>> {
template <typename AllocType>
bool check(const std::vector<int64_t, AllocType>& r,
const oneapi::mkl::rng::device::uniform<int64_t, Method>& distr) {
bool check(const std::vector<std::int64_t, AllocType>& r,
const oneapi::mkl::rng::device::uniform<std::int64_t, Method>& distr) {
double tM, tD, tQ;
double a = distr.a();
double b = distr.b();
Expand All @@ -248,8 +248,8 @@ struct statistics_device<oneapi::mkl::rng::device::uniform<std::int64_t, Method>
template <typename Method>
struct statistics_device<oneapi::mkl::rng::device::uniform<std::uint64_t, Method>> {
template <typename AllocType>
bool check(const std::vector<uint64_t, AllocType>& r,
const oneapi::mkl::rng::device::uniform<uint64_t, Method>& distr) {
bool check(const std::vector<std::uint64_t, AllocType>& r,
const oneapi::mkl::rng::device::uniform<std::uint64_t, Method>& distr) {
double tM, tD, tQ;
double a = distr.a();
double b = distr.b();
Expand Down
90 changes: 51 additions & 39 deletions tests/unit_tests/rng/device/moments/moments.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1055,9 +1055,7 @@ class Philox4x32x10BetaCjaDeviceMomentsTests
class Philox4x32x10BetaCjaAccDeviceMomentsTests
: public ::testing::TestWithParam<sycl::device*> {};

// implementation uses double precision for accuracy
TEST_P(Philox4x32x10BetaCjaDeviceMomentsTests, RealDoublePrecision) {
CHECK_DOUBLE_ON_DEVICE(GetParam());
TEST_P(Philox4x32x10BetaCjaDeviceMomentsTests, RealSinglePrecision) {

rng_device_test<moments_test<oneapi::mkl::rng::device::philox4x32x10<1>,
oneapi::mkl::rng::device::beta<
Expand All @@ -1074,26 +1072,29 @@ TEST_P(Philox4x32x10BetaCjaDeviceMomentsTests, RealDoublePrecision) {
float, oneapi::mkl::rng::device::beta_method::cja>>>
test3;
EXPECT_TRUEORSKIP((test3(GetParam())));
}

TEST_P(Philox4x32x10BetaCjaDeviceMomentsTests, RealDoublePrecision) {
CHECK_DOUBLE_ON_DEVICE(GetParam());

rng_device_test<moments_test<oneapi::mkl::rng::device::philox4x32x10<1>,
oneapi::mkl::rng::device::beta<
double, oneapi::mkl::rng::device::beta_method::cja>>>
test4;
EXPECT_TRUEORSKIP((test4(GetParam())));
test1;
EXPECT_TRUEORSKIP((test1(GetParam())));
rng_device_test<moments_test<oneapi::mkl::rng::device::philox4x32x10<4>,
oneapi::mkl::rng::device::beta<
double, oneapi::mkl::rng::device::beta_method::cja>>>
test5;
EXPECT_TRUEORSKIP((test5(GetParam())));
test2;
EXPECT_TRUEORSKIP((test2(GetParam())));
rng_device_test<moments_test<oneapi::mkl::rng::device::philox4x32x10<16>,
oneapi::mkl::rng::device::beta<
double, oneapi::mkl::rng::device::beta_method::cja>>>
test6;
EXPECT_TRUEORSKIP((test6(GetParam())));
test3;
EXPECT_TRUEORSKIP((test3(GetParam())));
}

// implementation uses double precision for accuracy
TEST_P(Philox4x32x10BetaCjaAccDeviceMomentsTests, RealDoublePrecision) {
CHECK_DOUBLE_ON_DEVICE(GetParam());
TEST_P(Philox4x32x10BetaCjaAccDeviceMomentsTests, RealSinglePrecision) {

rng_device_test<
moments_test<oneapi::mkl::rng::device::philox4x32x10<1>,
Expand All @@ -1113,24 +1114,29 @@ TEST_P(Philox4x32x10BetaCjaAccDeviceMomentsTests, RealDoublePrecision) {
float, oneapi::mkl::rng::device::beta_method::cja_accurate>>>
test3;
EXPECT_TRUEORSKIP((test3(GetParam())));
}

TEST_P(Philox4x32x10BetaCjaAccDeviceMomentsTests, RealDoublePrecision) {
CHECK_DOUBLE_ON_DEVICE(GetParam());

rng_device_test<
moments_test<oneapi::mkl::rng::device::philox4x32x10<1>,
oneapi::mkl::rng::device::beta<
double, oneapi::mkl::rng::device::beta_method::cja_accurate>>>
test4;
EXPECT_TRUEORSKIP((test4(GetParam())));
test1;
EXPECT_TRUEORSKIP((test1(GetParam())));
rng_device_test<
moments_test<oneapi::mkl::rng::device::philox4x32x10<4>,
oneapi::mkl::rng::device::beta<
double, oneapi::mkl::rng::device::beta_method::cja_accurate>>>
test5;
EXPECT_TRUEORSKIP((test5(GetParam())));
test2;
EXPECT_TRUEORSKIP((test2(GetParam())));
rng_device_test<
moments_test<oneapi::mkl::rng::device::philox4x32x10<16>,
oneapi::mkl::rng::device::beta<
double, oneapi::mkl::rng::device::beta_method::cja_accurate>>>
test6;
EXPECT_TRUEORSKIP((test6(GetParam())));
test3;
EXPECT_TRUEORSKIP((test3(GetParam())));
}

INSTANTIATE_TEST_SUITE_P(Philox4x32x10BetaCjaDeviceMomentsTestsSuite,
Expand All @@ -1147,9 +1153,7 @@ class Philox4x32x10GammaMarsagliaDeviceMomentsTests
class Philox4x32x10GammaMarsagliaAccDeviceMomentsTests
: public ::testing::TestWithParam<sycl::device*> {};

// implementation uses double precision for accuracy
TEST_P(Philox4x32x10GammaMarsagliaDeviceMomentsTests, RealDoublePrecision) {
CHECK_DOUBLE_ON_DEVICE(GetParam());
TEST_P(Philox4x32x10GammaMarsagliaDeviceMomentsTests, RealSinglePrecision) {

rng_device_test<moments_test<oneapi::mkl::rng::device::philox4x32x10<1>,
oneapi::mkl::rng::device::gamma<
Expand All @@ -1166,26 +1170,29 @@ TEST_P(Philox4x32x10GammaMarsagliaDeviceMomentsTests, RealDoublePrecision) {
float, oneapi::mkl::rng::device::gamma_method::marsaglia>>>
test3;
EXPECT_TRUEORSKIP((test3(GetParam())));
}

TEST_P(Philox4x32x10GammaMarsagliaDeviceMomentsTests, RealDoublePrecision) {
CHECK_DOUBLE_ON_DEVICE(GetParam());

rng_device_test<moments_test<oneapi::mkl::rng::device::philox4x32x10<1>,
oneapi::mkl::rng::device::gamma<
double, oneapi::mkl::rng::device::gamma_method::marsaglia>>>
test4;
EXPECT_TRUEORSKIP((test4(GetParam())));
test1;
EXPECT_TRUEORSKIP((test1(GetParam())));
rng_device_test<moments_test<oneapi::mkl::rng::device::philox4x32x10<4>,
oneapi::mkl::rng::device::gamma<
double, oneapi::mkl::rng::device::gamma_method::marsaglia>>>
test5;
EXPECT_TRUEORSKIP((test5(GetParam())));
test2;
EXPECT_TRUEORSKIP((test2(GetParam())));
rng_device_test<moments_test<oneapi::mkl::rng::device::philox4x32x10<16>,
oneapi::mkl::rng::device::gamma<
double, oneapi::mkl::rng::device::gamma_method::marsaglia>>>
test6;
EXPECT_TRUEORSKIP((test6(GetParam())));
test3;
EXPECT_TRUEORSKIP((test3(GetParam())));
}

// implementation uses double precision for accuracy
TEST_P(Philox4x32x10GammaMarsagliaAccDeviceMomentsTests, RealDoublePrecision) {
CHECK_DOUBLE_ON_DEVICE(GetParam());
TEST_P(Philox4x32x10GammaMarsagliaAccDeviceMomentsTests, RealSinglePrecision) {

rng_device_test<
moments_test<oneapi::mkl::rng::device::philox4x32x10<1>,
Expand All @@ -1205,24 +1212,29 @@ TEST_P(Philox4x32x10GammaMarsagliaAccDeviceMomentsTests, RealDoublePrecision) {
float, oneapi::mkl::rng::device::gamma_method::marsaglia_accurate>>>
test3;
EXPECT_TRUEORSKIP((test3(GetParam())));
}

TEST_P(Philox4x32x10GammaMarsagliaAccDeviceMomentsTests, RealDoublePrecision) {
CHECK_DOUBLE_ON_DEVICE(GetParam());

rng_device_test<
moments_test<oneapi::mkl::rng::device::philox4x32x10<1>,
oneapi::mkl::rng::device::gamma<
double, oneapi::mkl::rng::device::gamma_method::marsaglia_accurate>>>
test4;
EXPECT_TRUEORSKIP((test4(GetParam())));
test1;
EXPECT_TRUEORSKIP((test1(GetParam())));
rng_device_test<
moments_test<oneapi::mkl::rng::device::philox4x32x10<4>,
oneapi::mkl::rng::device::gamma<
double, oneapi::mkl::rng::device::gamma_method::marsaglia_accurate>>>
test5;
EXPECT_TRUEORSKIP((test5(GetParam())));
test2;
EXPECT_TRUEORSKIP((test2(GetParam())));
rng_device_test<
moments_test<oneapi::mkl::rng::device::philox4x32x10<16>,
oneapi::mkl::rng::device::gamma<
double, oneapi::mkl::rng::device::gamma_method::marsaglia_accurate>>>
test6;
EXPECT_TRUEORSKIP((test6(GetParam())));
test3;
EXPECT_TRUEORSKIP((test3(GetParam())));
}

INSTANTIATE_TEST_SUITE_P(Philox4x32x10GammaMarsagliaDeviceMomentsTestsSuite,
Expand Down Expand Up @@ -1257,17 +1269,17 @@ TEST_P(Philox4x32x10PoissonDevroyeDeviceMomentsTests, IntegerPrecision) {
TEST_P(Philox4x32x10PoissonDevroyeDeviceMomentsTests, UnsignedIntegerPrecision) {
rng_device_test<moments_test<oneapi::mkl::rng::device::philox4x32x10<1>,
oneapi::mkl::rng::device::poisson<
uint32_t, oneapi::mkl::rng::device::poisson_method::devroye>>>
std::uint32_t, oneapi::mkl::rng::device::poisson_method::devroye>>>
test1;
EXPECT_TRUEORSKIP((test1(GetParam())));
rng_device_test<moments_test<oneapi::mkl::rng::device::philox4x32x10<4>,
oneapi::mkl::rng::device::poisson<
uint32_t, oneapi::mkl::rng::device::poisson_method::devroye>>>
std::uint32_t, oneapi::mkl::rng::device::poisson_method::devroye>>>
test2;
EXPECT_TRUEORSKIP((test2(GetParam())));
rng_device_test<moments_test<oneapi::mkl::rng::device::philox4x32x10<16>,
oneapi::mkl::rng::device::poisson<
uint32_t, oneapi::mkl::rng::device::poisson_method::devroye>>>
std::uint32_t, oneapi::mkl::rng::device::poisson_method::devroye>>>
test3;
EXPECT_TRUEORSKIP((test3(GetParam())));
}
Expand Down

0 comments on commit 2bb30f5

Please sign in to comment.