Skip to content

Commit

Permalink
[RNG] Fixed some bugs for device API (#470)
Browse files Browse the repository at this point in the history
  • Loading branch information
andreyfe1 authored May 28, 2024
1 parent 90bc218 commit f9983ee
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 36 deletions.
2 changes: 1 addition & 1 deletion examples/rng/device/uniform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ bool isDoubleSupported(sycl::device my_dev) {
}

// example parameters
constexpr int seed = 777;
constexpr std::uint64_t seed = 777;
constexpr std::size_t n = 1024;
constexpr int n_print = 10;

Expand Down
16 changes: 4 additions & 12 deletions include/oneapi/mkl/rng/device/detail/mcg31m1_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,10 @@ static inline void skip_ahead(engine_state<oneapi::mkl::rng::device::mcg31m1<Vec

template <std::int32_t VecSize>
static inline void init(engine_state<oneapi::mkl::rng::device::mcg31m1<VecSize>>& state,
std::uint64_t n, const std::uint32_t* seed_ptr, std::uint64_t offset) {
if (n == 0)
std::uint32_t seed, std::uint64_t offset) {
state.s = custom_mod<std::uint32_t>(seed);
if (state.s == 0)
state.s = 1;
else {
state.s = custom_mod<std::uint32_t>(seed_ptr[0]);
if (state.s == 0)
state.s = 1;
}
skip_ahead(state, offset);
}

Expand Down Expand Up @@ -183,11 +179,7 @@ template <std::int32_t VecSize>
class engine_base<oneapi::mkl::rng::device::mcg31m1<VecSize>> {
protected:
engine_base(std::uint32_t seed, std::uint64_t offset = 0) {
mcg31m1_impl::init(this->state_, 1, &seed, offset);
}

engine_base(std::uint64_t n, const std::uint32_t* seed, std::uint64_t offset = 0) {
mcg31m1_impl::init(this->state_, n, seed, offset);
mcg31m1_impl::init(this->state_, seed, offset);
}

template <typename RealType>
Expand Down
20 changes: 4 additions & 16 deletions include/oneapi/mkl/rng/device/detail/mcg59_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,16 +111,8 @@ static inline void skip_ahead(engine_state<oneapi::mkl::rng::device::mcg59<VecSi

template <std::int32_t VecSize>
static inline void init(engine_state<oneapi::mkl::rng::device::mcg59<VecSize>>& state,
std::uint64_t n, std::uint32_t* seed_ptr, std::uint64_t offset) {
if (n < 1) {
state.s = 1;
}
else if (n == 1) {
state.s = static_cast<uint64_t>(seed_ptr[0]) & mcg59_param::m_64;
}
else {
state.s = *(reinterpret_cast<std::uint64_t*>(&seed_ptr[0])) & mcg59_param::m_64;
}
std::uint64_t seed, std::uint64_t offset) {
state.s = seed & mcg59_param::m_64;
if (state.s == 0)
state.s = 1;

Expand Down Expand Up @@ -154,12 +146,8 @@ static inline std::uint64_t generate_single(
template <std::int32_t VecSize>
class engine_base<oneapi::mkl::rng::device::mcg59<VecSize>> {
protected:
engine_base(std::uint32_t seed, std::uint64_t offset = 0) {
mcg59_impl::init(this->state_, 1, &seed, offset);
}

engine_base(std::uint64_t n, const std::uint32_t* seed, std::uint64_t offset = 0) {
mcg59_impl::init(this->state_, n, seed, offset);
engine_base(std::uint64_t seed, std::uint64_t offset = 0) {
mcg59_impl::init(this->state_, seed, offset);
}

template <typename RealType>
Expand Down
8 changes: 1 addition & 7 deletions include/oneapi/mkl/rng/device/engines.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,6 @@ class mcg31m1 : detail::engine_base<mcg31m1<VecSize>> {
mcg31m1(std::uint32_t seed, std::uint64_t offset = 0)
: detail::engine_base<mcg31m1<VecSize>>(seed, offset) {}

mcg31m1(std::initializer_list<std::uint32_t> seed, std::uint64_t offset = 0)
: detail::engine_base<mcg31m1<VecSize>>(seed.size(), seed.begin(), offset) {}

private:
template <typename Engine>
friend void skip_ahead(Engine& engine, std::uint64_t num_to_skip);
Expand All @@ -157,12 +154,9 @@ class mcg59 : detail::engine_base<mcg59<VecSize>> {

mcg59() : detail::engine_base<mcg59<VecSize>>(default_seed) {}

mcg59(std::uint32_t seed, std::uint64_t offset = 0)
mcg59(std::uint64_t seed, std::uint64_t offset = 0)
: detail::engine_base<mcg59<VecSize>>(seed, offset) {}

mcg59(std::initializer_list<std::uint32_t> seed, std::uint64_t offset = 0)
: detail::engine_base<mcg59<VecSize>>(seed.size(), seed.begin(), offset) {}

private:
template <typename Engine>
friend void skip_ahead(Engine& engine, std::uint64_t num_to_skip);
Expand Down

0 comments on commit f9983ee

Please sign in to comment.