Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RNG] Add geometric distribution to Device API #622

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/oneapi/math/rng/device/detail/distribution_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ class poisson;
template <typename IntType = std::uint32_t, typename Method = bernoulli_method::by_default>
class bernoulli;

template <typename IntType = std::uint32_t, typename Method = geometric_method::by_default>
class geometric;

} // namespace oneapi::math::rng::device

#include "oneapi/math/rng/device/detail/uniform_impl.hpp"
Expand All @@ -75,6 +78,7 @@ class bernoulli;
#include "oneapi/math/rng/device/detail/exponential_impl.hpp"
#include "oneapi/math/rng/device/detail/poisson_impl.hpp"
#include "oneapi/math/rng/device/detail/bernoulli_impl.hpp"
#include "oneapi/math/rng/device/detail/geometric_impl.hpp"
#include "oneapi/math/rng/device/detail/beta_impl.hpp"
#include "oneapi/math/rng/device/detail/gamma_impl.hpp"

Expand Down
99 changes: 99 additions & 0 deletions include/oneapi/math/rng/device/detail/geometric_impl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*******************************************************************************
* Copyright 2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions
* and limitations under the License.
*
*
* SPDX-License-Identifier: Apache-2.0
*******************************************************************************/

#ifndef ONEMATH_RNG_DEVICE_GEOMETRIC_IMPL_HPP_
#define ONEMATH_RNG_DEVICE_GEOMETRIC_IMPL_HPP_

namespace oneapi::math::rng::device::detail {

template <typename IntType, typename Method>
class distribution_base<oneapi::math::rng::device::geometric<IntType, Method>> {
public:
struct param_type {
param_type(float p) : p_(p) {}
float p_;
};

distribution_base(float p) : p_(p) {
#ifndef __SYCL_DEVICE_ONLY__
if ((p > 1.0f) || (p < 0.0f)) {
throw oneapi::math::invalid_argument("rng", "geometric", "p < 0 || p > 1");
}
#endif
}

float p() const {
return p_;
}

param_type param() const {
return param_type(p_);
}

void param(const param_type& pt) {
#ifndef __SYCL_DEVICE_ONLY__
if ((pt.p_ > 1.0f) || (pt.p_ < 0.0f)) {
throw oneapi::math::invalid_argument("rng", "geometric", "p < 0 || p > 1");
}
#endif
p_ = pt.p_;
}

protected:
template <typename EngineType>
auto generate(EngineType& engine) ->
typename std::conditional<EngineType::vec_size == 1, IntType,
sycl::vec<IntType, EngineType::vec_size>>::type {
using FpType = typename std::conditional<std::is_same_v<IntType, std::uint64_t> ||
std::is_same_v<IntType, std::int64_t>,
double, float>::type;

auto uni_res = engine.generate(FpType(0.0), FpType(1.0));
FpType inv_ln = ln_wrapper(FpType(1.0) - p_);
inv_ln = FpType(1.0) / inv_ln;
if constexpr (EngineType::vec_size == 1) {
return static_cast<IntType>(sycl::floor(ln_wrapper(uni_res) * inv_ln));
}
else {
sycl::vec<IntType, EngineType::vec_size> vec_out;
for (int i = 0; i < EngineType::vec_size; i++) {
vec_out[i] = static_cast<IntType>(sycl::floor(ln_wrapper(uni_res[i]) * inv_ln));
}
return vec_out;
}
}

template <typename EngineType>
IntType generate_single(EngineType& engine) {
using FpType = typename std::conditional<std::is_same_v<IntType, std::uint64_t> ||
std::is_same_v<IntType, std::int64_t>,
double, float>::type;

FpType uni_res = engine.generate_single(FpType(0.0), FpType(1.0));
FpType inv_ln = ln_wrapper(FpType(1.0) - p_);
inv_ln = FpType(1.0) / inv_ln;
return static_cast<IntType>(sycl::floor(ln_wrapper(uni_res) * inv_ln));
}

float p_;
};

} // namespace oneapi::math::rng::device::detail

#endif // ONEMATH_RNG_DEVICE_GEOMETRIC_IMPL_HPP_
58 changes: 58 additions & 0 deletions include/oneapi/math/rng/device/distributions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,64 @@ class bernoulli : detail::distribution_base<bernoulli<IntType, Method>> {
friend typename Distr::result_type generate_single(Distr& distr, Engine& engine);
};

// Class template oneapi::math::rng::device::geometric
//
// Represents discrete geometric random number distribution
//
// Supported types:
// std::uint32_t
// std::int32_t
// std::uint64_t
// std::int64_t
//
// Supported methods:
// oneapi::math::rng::geometric_method::icdf;
//
// Input arguments:
// p - success probablity of a trial. 0.5 by default
//
template <typename IntType, typename Method>
class geometric : detail::distribution_base<geometric<IntType, Method>> {
public:
static_assert(std::is_same<Method, geometric_method::icdf>::value,
"oneMath: rng/geometric: method is incorrect");

static_assert(std::is_same<IntType, std::int32_t>::value ||
std::is_same<IntType, std::uint32_t>::value ||
std::is_same<IntType, std::int64_t>::value ||
std::is_same<IntType, std::uint64_t>::value,
"oneMath: rng/geometric: type is not supported");

using method_type = Method;
using result_type = IntType;
using param_type = typename detail::distribution_base<geometric<IntType, Method>>::param_type;

geometric() : detail::distribution_base<geometric<IntType, Method>>(0.5f) {}

explicit geometric(float p) : detail::distribution_base<geometric<IntType, Method>>(p) {}
explicit geometric(const param_type& pt)
: detail::distribution_base<geometric<IntType, Method>>(pt.p_) {}

float p() const {
return detail::distribution_base<geometric<IntType, Method>>::p();
}

param_type param() const {
return detail::distribution_base<geometric<IntType, Method>>::param();
}

void param(const param_type& pt) {
detail::distribution_base<geometric<IntType, Method>>::param(pt);
}

template <typename Distr, typename Engine>
friend auto generate(Distr& distr, Engine& engine) ->
typename std::conditional<Engine::vec_size == 1, typename Distr::result_type,
sycl::vec<typename Distr::result_type, Engine::vec_size>>::type;
template <typename Distr, typename Engine>
friend typename Distr::result_type generate_single(Distr& distr, Engine& engine);
};

} // namespace oneapi::math::rng::device

#endif // ONEMATH_RNG_DEVICE_DISTRIBUTIONS_HPP_
5 changes: 5 additions & 0 deletions include/oneapi/math/rng/device/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ struct icdf {};
using by_default = icdf;
} // namespace bernoulli_method

namespace geometric_method {
struct icdf {};
using by_default = icdf;
} // namespace geometric_method

namespace beta_method {
struct cja {};
struct cja_accurate {};
Expand Down
16 changes: 16 additions & 0 deletions tests/unit_tests/rng/device/include/rng_device_test_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,22 @@ struct statistics_device<oneapi::math::rng::device::bernoulli<Fp, Method>> {
}
};

template <typename Fp, typename Method>
struct statistics_device<oneapi::math::rng::device::geometric<Fp, Method>> {
template <typename AllocType>
bool check(const std::vector<Fp, AllocType>& r,
const oneapi::math::rng::device::geometric<Fp, Method>& distr) {
double tM, tD, tQ;
double p = static_cast<double>(distr.p());

tM = (1.0 - p) / p;
tD = (1.0 - p) / (p * p);
tQ = (1.0 - p) * (p * p - 9.0 * p + 9.0) / (p * p * p * p);

return compare_moments(r, tM, tD, tQ);
}
};

template <typename Fp, typename Method>
struct statistics_device<oneapi::math::rng::device::beta<Fp, Method>> {
template <typename AllocType>
Expand Down
95 changes: 95 additions & 0 deletions tests/unit_tests/rng/device/moments/moments.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1416,4 +1416,99 @@ INSTANTIATE_TEST_SUITE_P(Philox4x32x10BernoulliIcdfDeviceMomentsTestsSuite,
Philox4x32x10BernoulliIcdfDeviceMomentsTests, ::testing::ValuesIn(devices),
::DeviceNamePrint());

class Philox4x32x10GeometricIcdfDeviceMomentsTests
: public ::testing::TestWithParam<sycl::device*> {};

TEST_P(Philox4x32x10GeometricIcdfDeviceMomentsTests, IntegerPrecision) {
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<1>,
oneapi::math::rng::device::geometric<
std::int32_t, oneapi::math::rng::device::geometric_method::icdf>>>
test1;
EXPECT_TRUEORSKIP((test1(GetParam())));
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<4>,
oneapi::math::rng::device::geometric<
std::int32_t, oneapi::math::rng::device::geometric_method::icdf>>>
test2;
EXPECT_TRUEORSKIP((test2(GetParam())));
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<16>,
oneapi::math::rng::device::geometric<
std::int32_t, oneapi::math::rng::device::geometric_method::icdf>>>
test3;
EXPECT_TRUEORSKIP((test3(GetParam())));
}

TEST_P(Philox4x32x10GeometricIcdfDeviceMomentsTests, UnsignedIntegerPrecision) {
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<1>,
oneapi::math::rng::device::geometric<
std::uint32_t, oneapi::math::rng::device::geometric_method::icdf>>>
test1;
EXPECT_TRUEORSKIP((test1(GetParam())));
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<4>,
oneapi::math::rng::device::geometric<
std::uint32_t, oneapi::math::rng::device::geometric_method::icdf>>>
test2;
EXPECT_TRUEORSKIP((test2(GetParam())));
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<16>,
oneapi::math::rng::device::geometric<
std::uint32_t, oneapi::math::rng::device::geometric_method::icdf>>>
test3;
EXPECT_TRUEORSKIP((test3(GetParam())));
}

TEST_P(Philox4x32x10GeometricIcdfDeviceMomentsTests, Integer64Precision) {
CHECK_DOUBLE_ON_DEVICE(GetParam());

rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<1>,
oneapi::math::rng::device::geometric<
std::int64_t, oneapi::math::rng::device::geometric_method::icdf>>>
test1;
EXPECT_TRUEORSKIP((test1(GetParam())));
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<4>,
oneapi::math::rng::device::geometric<
std::int64_t, oneapi::math::rng::device::geometric_method::icdf>>>
test2;
EXPECT_TRUEORSKIP((test2(GetParam())));
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<16>,
oneapi::math::rng::device::geometric<
std::int64_t, oneapi::math::rng::device::geometric_method::icdf>>>
test3;
EXPECT_TRUEORSKIP((test3(GetParam())));
}

TEST_P(Philox4x32x10GeometricIcdfDeviceMomentsTests, UnsignedInteger64Precision) {
CHECK_DOUBLE_ON_DEVICE(GetParam());

rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<1>,
oneapi::math::rng::device::geometric<
std::uint64_t, oneapi::math::rng::device::geometric_method::icdf>>>
test1;
EXPECT_TRUEORSKIP((test1(GetParam())));
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<4>,
oneapi::math::rng::device::geometric<
std::uint64_t, oneapi::math::rng::device::geometric_method::icdf>>>
test2;
EXPECT_TRUEORSKIP((test2(GetParam())));
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<16>,
oneapi::math::rng::device::geometric<
std::uint64_t, oneapi::math::rng::device::geometric_method::icdf>>>
test3;
EXPECT_TRUEORSKIP((test3(GetParam())));
}

INSTANTIATE_TEST_SUITE_P(Philox4x32x10GeometricIcdfDeviceMomentsTestsSuite,
Philox4x32x10GeometricIcdfDeviceMomentsTests, ::testing::ValuesIn(devices),
::DeviceNamePrint());

} // anonymous namespace
Loading