From 898cea42585cc0d429dedd51fad33cc454e1c95d Mon Sep 17 00:00:00 2001 From: Andrei Fedorov Date: Tue, 28 May 2024 17:26:23 +0200 Subject: [PATCH] [RNG] Fixed some bugs for device API (#470) --- examples/rng/device/uniform.cpp | 2 +- .../mkl/rng/device/detail/mcg31m1_impl.hpp | 16 ++++----------- .../mkl/rng/device/detail/mcg59_impl.hpp | 20 ++++--------------- include/oneapi/mkl/rng/device/engines.hpp | 8 +------- 4 files changed, 10 insertions(+), 36 deletions(-) diff --git a/examples/rng/device/uniform.cpp b/examples/rng/device/uniform.cpp index 65e362e6f..a1c097bba 100644 --- a/examples/rng/device/uniform.cpp +++ b/examples/rng/device/uniform.cpp @@ -46,7 +46,7 @@ bool isDoubleSupported(sycl::device my_dev) { } // example parameters -constexpr int seed = 777; +constexpr std::uint64_t seed = 777; constexpr std::size_t n = 1024; constexpr int n_print = 10; diff --git a/include/oneapi/mkl/rng/device/detail/mcg31m1_impl.hpp b/include/oneapi/mkl/rng/device/detail/mcg31m1_impl.hpp index 6473bdce2..8f1294ac2 100644 --- a/include/oneapi/mkl/rng/device/detail/mcg31m1_impl.hpp +++ b/include/oneapi/mkl/rng/device/detail/mcg31m1_impl.hpp @@ -142,14 +142,10 @@ static inline void skip_ahead(engine_state static inline void init(engine_state>& state, - std::uint64_t n, const std::uint32_t* seed_ptr, std::uint64_t offset) { - if (n == 0) + std::uint32_t seed, std::uint64_t offset) { + state.s = custom_mod(seed); + if (state.s == 0) state.s = 1; - else { - state.s = custom_mod(seed_ptr[0]); - if (state.s == 0) - state.s = 1; - } skip_ahead(state, offset); } @@ -183,11 +179,7 @@ template class engine_base> { protected: engine_base(std::uint32_t seed, std::uint64_t offset = 0) { - mcg31m1_impl::init(this->state_, 1, &seed, offset); - } - - engine_base(std::uint64_t n, const std::uint32_t* seed, std::uint64_t offset = 0) { - mcg31m1_impl::init(this->state_, n, seed, offset); + mcg31m1_impl::init(this->state_, seed, offset); } template diff --git a/include/oneapi/mkl/rng/device/detail/mcg59_impl.hpp b/include/oneapi/mkl/rng/device/detail/mcg59_impl.hpp index f04e8ac3e..0c2a11b31 100644 --- a/include/oneapi/mkl/rng/device/detail/mcg59_impl.hpp +++ b/include/oneapi/mkl/rng/device/detail/mcg59_impl.hpp @@ -111,16 +111,8 @@ static inline void skip_ahead(engine_state static inline void init(engine_state>& state, - std::uint64_t n, std::uint32_t* seed_ptr, std::uint64_t offset) { - if (n < 1) { - state.s = 1; - } - else if (n == 1) { - state.s = static_cast(seed_ptr[0]) & mcg59_param::m_64; - } - else { - state.s = *(reinterpret_cast(&seed_ptr[0])) & mcg59_param::m_64; - } + std::uint64_t seed, std::uint64_t offset) { + state.s = seed & mcg59_param::m_64; if (state.s == 0) state.s = 1; @@ -154,12 +146,8 @@ static inline std::uint64_t generate_single( template class engine_base> { protected: - engine_base(std::uint32_t seed, std::uint64_t offset = 0) { - mcg59_impl::init(this->state_, 1, &seed, offset); - } - - engine_base(std::uint64_t n, const std::uint32_t* seed, std::uint64_t offset = 0) { - mcg59_impl::init(this->state_, n, seed, offset); + engine_base(std::uint64_t seed, std::uint64_t offset = 0) { + mcg59_impl::init(this->state_, seed, offset); } template diff --git a/include/oneapi/mkl/rng/device/engines.hpp b/include/oneapi/mkl/rng/device/engines.hpp index d3ea72022..f1bcfd1b0 100644 --- a/include/oneapi/mkl/rng/device/engines.hpp +++ b/include/oneapi/mkl/rng/device/engines.hpp @@ -130,9 +130,6 @@ class mcg31m1 : detail::engine_base> { mcg31m1(std::uint32_t seed, std::uint64_t offset = 0) : detail::engine_base>(seed, offset) {} - mcg31m1(std::initializer_list seed, std::uint64_t offset = 0) - : detail::engine_base>(seed.size(), seed.begin(), offset) {} - private: template friend void skip_ahead(Engine& engine, std::uint64_t num_to_skip); @@ -157,12 +154,9 @@ class mcg59 : detail::engine_base> { mcg59() : detail::engine_base>(default_seed) {} - mcg59(std::uint32_t seed, std::uint64_t offset = 0) + mcg59(std::uint64_t seed, std::uint64_t offset = 0) : detail::engine_base>(seed, offset) {} - mcg59(std::initializer_list seed, std::uint64_t offset = 0) - : detail::engine_base>(seed.size(), seed.begin(), offset) {} - private: template friend void skip_ahead(Engine& engine, std::uint64_t num_to_skip);