Skip to content

Commit

Permalink
revert changes for host api
Browse files Browse the repository at this point in the history
  • Loading branch information
andreyfe1 committed Sep 30, 2024
1 parent b59c371 commit b897436
Showing 1 changed file with 32 additions and 11 deletions.
43 changes: 32 additions & 11 deletions include/oneapi/mkl/rng/distributions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,24 +61,17 @@ template <typename Type = float, typename Method = uniform_method::by_default>
class uniform {
public:
static_assert(std::is_same<Method, uniform_method::standard>::value ||
std::is_same<Method, uniform_method::accurate>::value,
(std::is_same<Method, uniform_method::accurate>::value &&
!std::is_same<Type, std::int32_t>::value),
"rng uniform distribution method is incorrect");

static_assert(std::is_same<Type, float>::value || std::is_same<Type, double>::value ||
std::is_same<Type, std::int32_t>::value ||
std::is_same<Type, std::uint32_t>::value,
static_assert(std::is_same<Type, float>::value || std::is_same<Type, double>::value,
"rng uniform distribution type is not supported");

using method_type = Method;
using result_type = Type;

uniform()
: uniform(static_cast<Type>(0.0f),
std::is_integral<Type>::value
? (std::is_same<Method, uniform_method::standard>::value
? (1 << 23)
: (std::numeric_limits<Type>::max)())
: static_cast<Type>(1.0f)) {}
uniform() : uniform(static_cast<Type>(0.0f), static_cast<Type>(1.0f)) {}

explicit uniform(Type a, Type b) : a_(a), b_(b) {
if (a >= b) {
Expand All @@ -100,6 +93,34 @@ class uniform {
Type b_;
};

template <typename Method>
class uniform<std::int32_t, Method> {
public:
using method_type = Method;
using result_type = std::int32_t;

uniform() : uniform(0, std::numeric_limits<std::int32_t>::max()) {}

explicit uniform(std::int32_t a, std::int32_t b) : a_(a), b_(b) {
if (a >= b) {
throw oneapi::mkl::invalid_argument("rng", "uniform",
"parameters are incorrect, a >= b");
}
}

std::int32_t a() const {
return a_;
}

std::int32_t b() const {
return b_;
}

private:
std::int32_t a_;
std::int32_t b_;
};

// Class template oneapi::mkl::rng::gaussian
//
// Represents continuous normal random number distribution
Expand Down

0 comments on commit b897436

Please sign in to comment.