Skip to content

Commit

Permalink
Added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
iMartyan committed Dec 16, 2024
1 parent 579a87e commit ee70d98
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 0 deletions.
16 changes: 16 additions & 0 deletions tests/unit_tests/rng/device/include/rng_device_test_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,22 @@ struct statistics_device<oneapi::math::rng::device::bernoulli<Fp, Method>> {
}
};

template <typename Fp, typename Method>
struct statistics_device<oneapi::math::rng::device::geometric<Fp, Method>> {
template <typename AllocType>
bool check(const std::vector<Fp, AllocType>& r,
const oneapi::math::rng::device::geometric<Fp, Method>& distr) {
double tM, tD, tQ;
double p = static_cast<double>(distr.p());

tM = (1.0 - p) / p;
tD = (1.0 - p) / (p * p);
tQ = (1.0 - p) * (p * p - 9.0 * p + 9.0) / (p * p * p * p);

return compare_moments(r, tM, tD, tQ);
}
};

template <typename Fp, typename Method>
struct statistics_device<oneapi::math::rng::device::beta<Fp, Method>> {
template <typename AllocType>
Expand Down
95 changes: 95 additions & 0 deletions tests/unit_tests/rng/device/moments/moments.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1416,4 +1416,99 @@ INSTANTIATE_TEST_SUITE_P(Philox4x32x10BernoulliIcdfDeviceMomentsTestsSuite,
Philox4x32x10BernoulliIcdfDeviceMomentsTests, ::testing::ValuesIn(devices),
::DeviceNamePrint());

class Philox4x32x10GeometricIcdfDeviceMomentsTests
: public ::testing::TestWithParam<sycl::device*> {};

TEST_P(Philox4x32x10GeometricIcdfDeviceMomentsTests, IntegerPrecision) {
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<1>,
oneapi::math::rng::device::geometric<
std::int32_t, oneapi::math::rng::device::geometric_method::icdf>>>
test1;
EXPECT_TRUEORSKIP((test1(GetParam())));
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<4>,
oneapi::math::rng::device::geometric<
std::int32_t, oneapi::math::rng::device::geometric_method::icdf>>>
test2;
EXPECT_TRUEORSKIP((test2(GetParam())));
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<16>,
oneapi::math::rng::device::geometric<
std::int32_t, oneapi::math::rng::device::geometric_method::icdf>>>
test3;
EXPECT_TRUEORSKIP((test3(GetParam())));
}

TEST_P(Philox4x32x10GeometricIcdfDeviceMomentsTests, UnsignedIntegerPrecision) {
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<1>,
oneapi::math::rng::device::geometric<
std::uint32_t, oneapi::math::rng::device::geometric_method::icdf>>>
test1;
EXPECT_TRUEORSKIP((test1(GetParam())));
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<4>,
oneapi::math::rng::device::geometric<
std::uint32_t, oneapi::math::rng::device::geometric_method::icdf>>>
test2;
EXPECT_TRUEORSKIP((test2(GetParam())));
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<16>,
oneapi::math::rng::device::geometric<
std::uint32_t, oneapi::math::rng::device::geometric_method::icdf>>>
test3;
EXPECT_TRUEORSKIP((test3(GetParam())));
}

TEST_P(Philox4x32x10GeometricIcdfDeviceMomentsTests, Integer16Precision) {
CHECK_DOUBLE_ON_DEVICE(GetParam());

rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<1>,
oneapi::math::rng::device::geometric<
std::int64_t, oneapi::math::rng::device::geometric_method::icdf>>>
test1;
EXPECT_TRUEORSKIP((test1(GetParam())));
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<4>,
oneapi::math::rng::device::geometric<
std::int64_t, oneapi::math::rng::device::geometric_method::icdf>>>
test2;
EXPECT_TRUEORSKIP((test2(GetParam())));
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<16>,
oneapi::math::rng::device::geometric<
std::int64_t, oneapi::math::rng::device::geometric_method::icdf>>>
test3;
EXPECT_TRUEORSKIP((test3(GetParam())));
}

TEST_P(Philox4x32x10GeometricIcdfDeviceMomentsTests, UnsignedInteger16Precision) {
CHECK_DOUBLE_ON_DEVICE(GetParam());

rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<1>,
oneapi::math::rng::device::geometric<
std::uint64_t, oneapi::math::rng::device::geometric_method::icdf>>>
test1;
EXPECT_TRUEORSKIP((test1(GetParam())));
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<4>,
oneapi::math::rng::device::geometric<
std::uint64_t, oneapi::math::rng::device::geometric_method::icdf>>>
test2;
EXPECT_TRUEORSKIP((test2(GetParam())));
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<16>,
oneapi::math::rng::device::geometric<
std::uint64_t, oneapi::math::rng::device::geometric_method::icdf>>>
test3;
EXPECT_TRUEORSKIP((test3(GetParam())));
}

INSTANTIATE_TEST_SUITE_P(Philox4x32x10GeometricIcdfDeviceMomentsTestsSuite,
Philox4x32x10GeometricIcdfDeviceMomentsTests, ::testing::ValuesIn(devices),
::DeviceNamePrint());

} // anonymous namespace

0 comments on commit ee70d98

Please sign in to comment.