diff --git a/tests/unit_tests/rng/device/include/rng_device_test_common.hpp b/tests/unit_tests/rng/device/include/rng_device_test_common.hpp index 5e373e2cf..1bc954e49 100644 --- a/tests/unit_tests/rng/device/include/rng_device_test_common.hpp +++ b/tests/unit_tests/rng/device/include/rng_device_test_common.hpp @@ -352,6 +352,22 @@ struct statistics_device> { } }; +template +struct statistics_device> { + template + bool check(const std::vector& r, + const oneapi::math::rng::device::geometric& distr) { + double tM, tD, tQ; + double p = static_cast(distr.p()); + + tM = (1.0 - p) / p; + tD = (1.0 - p) / (p * p); + tQ = (1.0 - p) * (p * p - 9.0 * p + 9.0) / (p * p * p * p); + + return compare_moments(r, tM, tD, tQ); + } +}; + template struct statistics_device> { template diff --git a/tests/unit_tests/rng/device/moments/moments.cpp b/tests/unit_tests/rng/device/moments/moments.cpp index a191b67df..3183f0855 100644 --- a/tests/unit_tests/rng/device/moments/moments.cpp +++ b/tests/unit_tests/rng/device/moments/moments.cpp @@ -1416,4 +1416,99 @@ INSTANTIATE_TEST_SUITE_P(Philox4x32x10BernoulliIcdfDeviceMomentsTestsSuite, Philox4x32x10BernoulliIcdfDeviceMomentsTests, ::testing::ValuesIn(devices), ::DeviceNamePrint()); +class Philox4x32x10GeometricIcdfDeviceMomentsTests + : public ::testing::TestWithParam {}; + +TEST_P(Philox4x32x10GeometricIcdfDeviceMomentsTests, IntegerPrecision) { + rng_device_test< + moments_test, + oneapi::math::rng::device::geometric< + std::int32_t, oneapi::math::rng::device::geometric_method::icdf>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::math::rng::device::geometric< + std::int32_t, oneapi::math::rng::device::geometric_method::icdf>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::math::rng::device::geometric< + std::int32_t, oneapi::math::rng::device::geometric_method::icdf>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10GeometricIcdfDeviceMomentsTests, UnsignedIntegerPrecision) { + rng_device_test< + moments_test, + oneapi::math::rng::device::geometric< + std::uint32_t, oneapi::math::rng::device::geometric_method::icdf>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::math::rng::device::geometric< + std::uint32_t, oneapi::math::rng::device::geometric_method::icdf>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::math::rng::device::geometric< + std::uint32_t, oneapi::math::rng::device::geometric_method::icdf>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10GeometricIcdfDeviceMomentsTests, Integer16Precision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + + rng_device_test< + moments_test, + oneapi::math::rng::device::geometric< + std::int64_t, oneapi::math::rng::device::geometric_method::icdf>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::math::rng::device::geometric< + std::int64_t, oneapi::math::rng::device::geometric_method::icdf>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::math::rng::device::geometric< + std::int64_t, oneapi::math::rng::device::geometric_method::icdf>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10GeometricIcdfDeviceMomentsTests, UnsignedInteger16Precision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + + rng_device_test< + moments_test, + oneapi::math::rng::device::geometric< + std::uint64_t, oneapi::math::rng::device::geometric_method::icdf>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::math::rng::device::geometric< + std::uint64_t, oneapi::math::rng::device::geometric_method::icdf>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::math::rng::device::geometric< + std::uint64_t, oneapi::math::rng::device::geometric_method::icdf>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +INSTANTIATE_TEST_SUITE_P(Philox4x32x10GeometricIcdfDeviceMomentsTestsSuite, + Philox4x32x10GeometricIcdfDeviceMomentsTests, ::testing::ValuesIn(devices), + ::DeviceNamePrint()); + } // anonymous namespace