Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace raft::random calls to not use deprecated API #1867

Merged
merged 3 commits into from
Oct 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cpp/bench/prims/cluster/kmeans_balanced.cu
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ struct KMeansBalanced : public fixture {
constexpr T kRangeMin = std::is_integral_v<T> ? std::numeric_limits<T>::min() : T(-1);
if constexpr (std::is_integral_v<T>) {
raft::random::uniformInt(
rng, X.data_handle(), params.data.rows * params.data.cols, kRangeMin, kRangeMax, stream);
handle, rng, X.data_handle(), params.data.rows * params.data.cols, kRangeMin, kRangeMax);
} else {
raft::random::uniform(
rng, X.data_handle(), params.data.rows * params.data.cols, kRangeMin, kRangeMax, stream);
handle, rng, X.data_handle(), params.data.rows * params.data.cols, kRangeMin, kRangeMax);
}
resource::sync_stream(handle, stream);
}
Expand Down
6 changes: 3 additions & 3 deletions cpp/bench/prims/distance/kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ struct GramMatrix : public fixture {
A.resize(params.m * params.k, stream);
B.resize(params.k * params.n, stream);
C.resize(params.m * params.n, stream);
raft::random::Rng r(123456ULL);
r.uniform(A.data(), params.m * params.k, T(-1.0), T(1.0), stream);
r.uniform(B.data(), params.k * params.n, T(-1.0), T(1.0), stream);
raft::random::RngState rng(123456ULL);
raft::random::uniform(handle, rng, A.data(), params.m * params.k, T(-1.0), T(1.0));
raft::random::uniform(handle, rng, B.data(), params.k * params.n, T(-1.0), T(1.0));
}

~GramMatrix()
Expand Down
2 changes: 1 addition & 1 deletion cpp/bench/prims/linalg/norm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ struct rowNorm : public fixture {
rowNorm(const norm_input<IdxT>& p) : params(p), in(p.rows * p.cols, stream), dots(p.rows, stream)
{
raft::random::RngState rng{1234};
raft::random::uniform(rng, in.data(), p.rows * p.cols, (T)-10.0, (T)10.0, stream);
raft::random::uniform(handle, rng, in.data(), p.rows * p.cols, (T)-10.0, (T)10.0);
}

void run_benchmark(::benchmark::State& state) override
Expand Down
2 changes: 1 addition & 1 deletion cpp/bench/prims/linalg/normalize.cu
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ struct rowNormalize : public fixture {
: params(p), in(p.rows * p.cols, stream), out(p.rows * p.cols, stream)
{
raft::random::RngState rng{1234};
raft::random::uniform(rng, in.data(), p.rows * p.cols, (T)-10.0, (T)10.0, stream);
raft::random::uniform(handle, rng, in.data(), p.rows * p.cols, (T)-10.0, (T)10.0);
}

void run_benchmark(::benchmark::State& state) override
Expand Down
2 changes: 1 addition & 1 deletion cpp/bench/prims/linalg/reduce_cols_by_key.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ struct reduce_cols_by_key : public fixture {
: params(p), in(p.rows * p.cols, stream), out(p.rows * p.keys, stream), keys(p.cols, stream)
{
raft::random::RngState rng{42};
raft::random::uniformInt(rng, keys.data(), p.cols, (KeyT)0, (KeyT)p.keys, stream);
raft::random::uniformInt(handle, rng, keys.data(), p.cols, (KeyT)0, (KeyT)p.keys);
}

void run_benchmark(::benchmark::State& state) override
Expand Down
2 changes: 1 addition & 1 deletion cpp/bench/prims/linalg/reduce_rows_by_key.cu
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ struct reduce_rows_by_key : public fixture {
workspace(p.rows, stream)
{
raft::random::RngState rng{42};
raft::random::uniformInt(rng, keys.data(), p.rows, (KeyT)0, (KeyT)p.keys, stream);
raft::random::uniformInt(handle, rng, keys.data(), p.rows, (KeyT)0, (KeyT)p.keys);
}

void run_benchmark(::benchmark::State& state) override
Expand Down
2 changes: 1 addition & 1 deletion cpp/bench/prims/matrix/argmin.cu
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ struct Argmin : public fixture {

raft::random::RngState rng{1234};
raft::random::uniform(
rng, matrix.data_handle(), params.rows * params.cols, T(-1), T(1), stream);
handle, rng, matrix.data_handle(), params.rows * params.cols, T(-1), T(1));
resource::sync_stream(handle, stream);
}

Expand Down
4 changes: 2 additions & 2 deletions cpp/bench/prims/matrix/gather.cu
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ struct Gather : public fixture {

raft::random::RngState rng{1234};
raft::random::uniform(
rng, matrix.data_handle(), params.rows * params.cols, T(-1), T(1), stream);
handle, rng, matrix.data_handle(), params.rows * params.cols, T(-1), T(1));
raft::random::uniformInt(
handle, rng, map.data_handle(), params.map_length, (MapT)0, (MapT)params.rows);
if constexpr (Conditional) {
raft::random::uniform(rng, stencil.data_handle(), params.map_length, T(-1), T(1), stream);
raft::random::uniform(handle, rng, stencil.data_handle(), params.map_length, T(-1), T(1));
}
resource::sync_stream(handle, stream);
}
Expand Down
10 changes: 5 additions & 5 deletions cpp/bench/prims/neighbors/cagra_bench.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -62,20 +62,20 @@ struct CagraBench : public fixture {
constexpr T kRangeMin = std::is_integral_v<T> ? std::numeric_limits<T>::min() : T(-1);
if constexpr (std::is_integral_v<T>) {
raft::random::uniformInt(
state, dataset_.data_handle(), dataset_.size(), kRangeMin, kRangeMax, stream);
handle, state, dataset_.data_handle(), dataset_.size(), kRangeMin, kRangeMax);
raft::random::uniformInt(
state, queries_.data_handle(), queries_.size(), kRangeMin, kRangeMax, stream);
handle, state, queries_.data_handle(), queries_.size(), kRangeMin, kRangeMax);
} else {
raft::random::uniform(
state, dataset_.data_handle(), dataset_.size(), kRangeMin, kRangeMax, stream);
handle, state, dataset_.data_handle(), dataset_.size(), kRangeMin, kRangeMax);
raft::random::uniform(
state, queries_.data_handle(), queries_.size(), kRangeMin, kRangeMax, stream);
handle, state, queries_.data_handle(), queries_.size(), kRangeMin, kRangeMax);
}

// Generate random knn graph

raft::random::uniformInt<IdxT>(
state, knn_graph_.data_handle(), knn_graph_.size(), 0, ps.n_samples - 1, stream);
handle, state, knn_graph_.data_handle(), knn_graph_.size(), 0, ps.n_samples - 1);

auto metric = raft::distance::DistanceType::L2Expanded;

Expand Down
4 changes: 2 additions & 2 deletions cpp/bench/prims/neighbors/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,9 @@ struct knn : public fixture {
constexpr T kRangeMax = std::is_integral_v<T> ? std::numeric_limits<T>::max() : T(1);
constexpr T kRangeMin = std::is_integral_v<T> ? std::numeric_limits<T>::min() : T(-1);
if constexpr (std::is_integral_v<T>) {
raft::random::uniformInt(state, vec.data(), n, kRangeMin, kRangeMax, stream);
raft::random::uniformInt(handle, state, vec.data(), n, kRangeMin, kRangeMax);
} else {
raft::random::uniform(state, vec.data(), n, kRangeMin, kRangeMax, stream);
raft::random::uniform(handle, state, vec.data(), n, kRangeMin, kRangeMax);
}
}

Expand Down
4 changes: 2 additions & 2 deletions cpp/include/raft/neighbors/detail/ivf_pq_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ inline void make_rotation_matrix(raft::resources const& handle,
uint32_t n_rows,
uint32_t n_cols,
float* rotation_matrix,
raft::random::Rng rng = raft::random::Rng(7ULL))
raft::random::RngState rng = raft::random::RngState(7ULL))
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"ivf_pq::make_rotation_matrix(%u * %u)", n_rows, n_cols);
Expand All @@ -134,7 +134,7 @@ inline void make_rotation_matrix(raft::resources const& handle,
if (force_random_rotation || !inplace) {
rmm::device_uvector<float> buf(inplace ? 0 : n * n, stream);
float* mat = inplace ? rotation_matrix : buf.data();
rng.normal(mat, n * n, 0.0f, 1.0f, stream);
raft::random::normal(handle, rng, mat, n * n, 0.0f, 1.0f);
linalg::detail::qrGetQ_inplace(handle, mat, n, n, stream);
if (!inplace) {
RAFT_CUDA_TRY(cudaMemcpy2DAsync(rotation_matrix,
Expand Down
14 changes: 9 additions & 5 deletions cpp/internal/raft_internal/neighbors/refine_helper.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,20 @@ class RefineHelper {
refined_distances_host(handle),
refined_indices_host(handle)
{
raft::random::Rng r(1234ULL);
raft::random::RngState rng(1234ULL);

dataset = raft::make_device_matrix<DataT, IdxT>(handle_, p.n_rows, p.dim);
queries = raft::make_device_matrix<DataT, IdxT>(handle_, p.n_queries, p.dim);
if constexpr (std::is_same<DataT, float>{}) {
r.uniform(dataset.data_handle(), dataset.size(), DataT(-10.0), DataT(10.0), stream_);
r.uniform(queries.data_handle(), queries.size(), DataT(-10.0), DataT(10.0), stream_);
raft::random::uniform(
handle, rng, dataset.data_handle(), dataset.size(), DataT(-10.0), DataT(10.0));
raft::random::uniform(
handle, rng, queries.data_handle(), queries.size(), DataT(-10.0), DataT(10.0));
} else {
r.uniformInt(dataset.data_handle(), dataset.size(), DataT(1), DataT(20), stream_);
r.uniformInt(queries.data_handle(), queries.size(), DataT(1), DataT(20), stream_);
raft::random::uniformInt(
handle, rng, dataset.data_handle(), dataset.size(), DataT(1), DataT(20));
raft::random::uniformInt(
handle, rng, queries.data_handle(), queries.size(), DataT(1), DataT(20));
}

refined_distances = raft::make_device_matrix<DistanceT, IdxT>(handle_, p.n_queries, p.k);
Expand Down
6 changes: 3 additions & 3 deletions cpp/test/distance/gram.cu
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,9 @@ class GramMatrixTest : public ::testing::TestWithParam<GramMatrixInputs> {
gram_host.resize(gram.size());
std::fill(gram_host.begin(), gram_host.end(), 0);

raft::random::Rng r(42137ULL);
r.uniform(x1.data(), x1.size(), math_t(0), math_t(1), stream);
r.uniform(x2.data(), x2.size(), math_t(0), math_t(1), stream);
raft::random::RngState rng(42137ULL);
raft::random::uniform(handle, rng, x1.data(), x1.size(), math_t(0), math_t(1));
raft::random::uniform(handle, rng, x2.data(), x2.size(), math_t(0), math_t(1));
}

~GramMatrixTest() override {}
Expand Down
2 changes: 1 addition & 1 deletion cpp/test/linalg/reduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class ReduceTest : public ::testing::TestWithParam<ReduceInputs<InType, OutType,
raft::random::RngState r(params.seed);
IdxType rows = params.rows, cols = params.cols;
IdxType len = rows * cols;
gen_uniform(data.data(), r, len, stream);
gen_uniform(handle, data.data(), r, len);

MainLambda main_op;
ReduceLambda reduce_op;
Expand Down
42 changes: 25 additions & 17 deletions cpp/test/neighbors/ann_cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -125,18 +125,19 @@ __global__ void GenerateRoundingErrorFreeDataset_kernel(float* const ptr,
ptr[tid] = u32 / resolution;
}

void GenerateRoundingErrorFreeDataset(float* const ptr,
void GenerateRoundingErrorFreeDataset(const raft::resources& handle,
float* const ptr,
const uint32_t n_row,
const uint32_t dim,
raft::random::Rng& rng,
cudaStream_t cuda_stream)
raft::random::RngState& rng)
{
auto cuda_stream = resource::get_cuda_stream(handle);
const uint32_t size = n_row * dim;
const uint32_t block_size = 256;
const uint32_t grid_size = (size + block_size - 1) / block_size;

const uint32_t resolution = 1u << static_cast<unsigned>(std::floor((24 - std::log2(dim)) / 2));
rng.uniformInt(reinterpret_cast<uint32_t*>(ptr), size, 0u, resolution - 1, cuda_stream);
raft::random::uniformInt(handle, rng, reinterpret_cast<uint32_t*>(ptr), size, 0u, resolution - 1);

GenerateRoundingErrorFreeDataset_kernel<<<grid_size, block_size, 0, cuda_stream>>>(
ptr, size, resolution);
Expand Down Expand Up @@ -293,13 +294,16 @@ class AnnCagraTest : public ::testing::TestWithParam<AnnCagraInputs> {
{
database.resize(((size_t)ps.n_rows) * ps.dim, stream_);
search_queries.resize(ps.n_queries * ps.dim, stream_);
raft::random::Rng r(1234ULL);
raft::random::RngState r(1234ULL);
if constexpr (std::is_same<DataT, float>{}) {
r.normal(database.data(), ps.n_rows * ps.dim, DataT(0.1), DataT(2.0), stream_);
r.normal(search_queries.data(), ps.n_queries * ps.dim, DataT(0.1), DataT(2.0), stream_);
raft::random::normal(handle_, r, database.data(), ps.n_rows * ps.dim, DataT(0.1), DataT(2.0));
raft::random::normal(
handle_, r, search_queries.data(), ps.n_queries * ps.dim, DataT(0.1), DataT(2.0));
} else {
r.uniformInt(database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20), stream_);
r.uniformInt(search_queries.data(), ps.n_queries * ps.dim, DataT(1), DataT(20), stream_);
raft::random::uniformInt(
handle_, r, database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20));
raft::random::uniformInt(
handle_, r, search_queries.data(), ps.n_queries * ps.dim, DataT(1), DataT(20));
}
resource::sync_stream(handle_);
}
Expand Down Expand Up @@ -379,11 +383,12 @@ class AnnCagraSortTest : public ::testing::TestWithParam<AnnCagraInputs> {
void SetUp() override
{
database.resize(((size_t)ps.n_rows) * ps.dim, handle_.get_stream());
raft::random::Rng r(1234ULL);
raft::random::RngState r(1234ULL);
if constexpr (std::is_same<DataT, float>{}) {
GenerateRoundingErrorFreeDataset(database.data(), ps.n_rows, ps.dim, r, handle_.get_stream());
GenerateRoundingErrorFreeDataset(handle_, database.data(), ps.n_rows, ps.dim, r);
} else {
r.uniformInt(database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20), handle_.get_stream());
raft::random::uniformInt(
handle_, r, database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20));
}
handle_.sync_stream();
}
Expand Down Expand Up @@ -643,13 +648,16 @@ class AnnCagraFilterTest : public ::testing::TestWithParam<AnnCagraInputs> {
{
database.resize(((size_t)ps.n_rows) * ps.dim, stream_);
search_queries.resize(ps.n_queries * ps.dim, stream_);
raft::random::Rng r(1234ULL);
raft::random::RngState r(1234ULL);
if constexpr (std::is_same<DataT, float>{}) {
r.normal(database.data(), ps.n_rows * ps.dim, DataT(0.1), DataT(2.0), stream_);
r.normal(search_queries.data(), ps.n_queries * ps.dim, DataT(0.1), DataT(2.0), stream_);
raft::random::normal(handle_, r, database.data(), ps.n_rows * ps.dim, DataT(0.1), DataT(2.0));
raft::random::normal(
handle_, r, search_queries.data(), ps.n_queries * ps.dim, DataT(0.1), DataT(2.0));
} else {
r.uniformInt(database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20), stream_);
r.uniformInt(search_queries.data(), ps.n_queries * ps.dim, DataT(1), DataT(20), stream_);
raft::random::uniformInt(
handle_, r, database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20));
raft::random::uniformInt(
handle_, r, search_queries.data(), ps.n_queries * ps.dim, DataT(1), DataT(20));
}
resource::sync_stream(handle_);
}
Expand Down
14 changes: 9 additions & 5 deletions cpp/test/neighbors/ann_ivf_flat.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -411,13 +411,17 @@ class AnnIVFFlatTest : public ::testing::TestWithParam<AnnIvfFlatInputs<IdxT>> {
database.resize(ps.num_db_vecs * ps.dim, stream_);
search_queries.resize(ps.num_queries * ps.dim, stream_);

raft::random::Rng r(1234ULL);
raft::random::RngState r(1234ULL);
if constexpr (std::is_same<DataT, float>{}) {
r.uniform(database.data(), ps.num_db_vecs * ps.dim, DataT(0.1), DataT(2.0), stream_);
r.uniform(search_queries.data(), ps.num_queries * ps.dim, DataT(0.1), DataT(2.0), stream_);
raft::random::uniform(
handle_, r, database.data(), ps.num_db_vecs * ps.dim, DataT(0.1), DataT(2.0));
raft::random::uniform(
handle_, r, search_queries.data(), ps.num_queries * ps.dim, DataT(0.1), DataT(2.0));
} else {
r.uniformInt(database.data(), ps.num_db_vecs * ps.dim, DataT(1), DataT(20), stream_);
r.uniformInt(search_queries.data(), ps.num_queries * ps.dim, DataT(1), DataT(20), stream_);
raft::random::uniformInt(
handle_, r, database.data(), ps.num_db_vecs * ps.dim, DataT(1), DataT(20));
raft::random::uniformInt(
handle_, r, search_queries.data(), ps.num_queries * ps.dim, DataT(1), DataT(20));
}
resource::sync_stream(handle_);
}
Expand Down
14 changes: 9 additions & 5 deletions cpp/test/neighbors/ann_ivf_pq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -170,13 +170,17 @@ class ivf_pq_test : public ::testing::TestWithParam<ivf_pq_inputs> {
database.resize(size_t{ps.num_db_vecs} * size_t{ps.dim}, stream_);
search_queries.resize(size_t{ps.num_queries} * size_t{ps.dim}, stream_);

raft::random::Rng r(1234ULL);
raft::random::RngState r(1234ULL);
if constexpr (std::is_same<DataT, float>{}) {
r.uniform(database.data(), ps.num_db_vecs * ps.dim, DataT(0.1), DataT(2.0), stream_);
r.uniform(search_queries.data(), ps.num_queries * ps.dim, DataT(0.1), DataT(2.0), stream_);
raft::random::uniform(
handle_, r, database.data(), ps.num_db_vecs * ps.dim, DataT(0.1), DataT(2.0));
raft::random::uniform(
handle_, r, search_queries.data(), ps.num_queries * ps.dim, DataT(0.1), DataT(2.0));
} else {
r.uniformInt(database.data(), ps.num_db_vecs * ps.dim, DataT(1), DataT(20), stream_);
r.uniformInt(search_queries.data(), ps.num_queries * ps.dim, DataT(1), DataT(20), stream_);
raft::random::uniformInt(
handle_, r, database.data(), ps.num_db_vecs * ps.dim, DataT(1), DataT(20));
raft::random::uniformInt(
handle_, r, search_queries.data(), ps.num_queries * ps.dim, DataT(1), DataT(20));
}
resource::sync_stream(handle_);
}
Expand Down
7 changes: 4 additions & 3 deletions cpp/test/neighbors/ann_nn_descent.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,12 @@ class AnnNNDescentTest : public ::testing::TestWithParam<AnnNNDescentInputs> {
void SetUp() override
{
database.resize(((size_t)ps.n_rows) * ps.dim, stream_);
raft::random::Rng r(1234ULL);
raft::random::RngState r(1234ULL);
if constexpr (std::is_same<DataT, float>{}) {
r.normal(database.data(), ps.n_rows * ps.dim, DataT(0.1), DataT(2.0), stream_);
raft::random::normal(handle_, r, database.data(), ps.n_rows * ps.dim, DataT(0.1), DataT(2.0));
} else {
r.uniformInt(database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20), stream_);
raft::random::uniformInt(
handle_, r, database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20));
}
resource::sync_stream(handle_);
}
Expand Down
4 changes: 2 additions & 2 deletions cpp/test/random/rmat_rectangular_generator.cu
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ class RmatGenTest : public ::testing::TestWithParam<RmatInputs> {
max_scale{std::max(params.r_scale, params.c_scale)}
{
theta.resize(4 * max_scale, stream);
uniform<float>(state, theta.data(), theta.size(), 0.0f, 1.0f, stream);
uniform<float>(handle, state, theta.data(), theta.size(), 0.0f, 1.0f);
normalize<float, float>(theta.data(),
theta.data(),
max_scale,
Expand Down Expand Up @@ -271,7 +271,7 @@ class RmatGenMdspanTest : public ::testing::TestWithParam<RmatInputs> {
max_scale{std::max(params.r_scale, params.c_scale)}
{
theta.resize(4 * max_scale, stream);
uniform<float>(state, theta.data(), theta.size(), 0.0f, 1.0f, stream);
uniform<float>(handle, state, theta.data(), theta.size(), 0.0f, 1.0f);
normalize<float, float>(theta.data(),
theta.data(),
max_scale,
Expand Down
12 changes: 5 additions & 7 deletions cpp/test/sparse/gram.cu
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ class GramMatrixTest : public ::testing::TestWithParam<GramMatrixInputs> {
protected:
GramMatrixTest()
: params(GetParam()),
stream(0),
stream(resource::get_cuda_stream(handle)),
x1(0, stream),
x2(0, stream),
x1_csr_indptr(0, stream),
Expand All @@ -137,8 +137,6 @@ class GramMatrixTest : public ::testing::TestWithParam<GramMatrixInputs> {
gram(0, stream),
gram_host(0)
{
RAFT_CUDA_TRY(cudaStreamCreate(&stream));

if (params.ld1 == 0) { params.ld1 = params.is_row_major ? params.n_cols : params.n1; }
if (params.ld2 == 0) { params.ld2 = params.is_row_major ? params.n_cols : params.n2; }
if (params.ld_out == 0) { params.ld_out = params.is_row_major ? params.n2 : params.n1; }
Expand All @@ -154,14 +152,14 @@ class GramMatrixTest : public ::testing::TestWithParam<GramMatrixInputs> {
gram_host.resize(gram.size());
std::fill(gram_host.begin(), gram_host.end(), 0);

raft::random::Rng r(42137ULL);
r.uniform(x1.data(), x1.size(), math_t(0), math_t(1), stream);
r.uniform(x2.data(), x2.size(), math_t(0), math_t(1), stream);
raft::random::RngState r(42137ULL);
raft::random::uniform(handle, r, x1.data(), x1.size(), math_t(0), math_t(1));
raft::random::uniform(handle, r, x2.data(), x2.size(), math_t(0), math_t(1));

RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
}

~GramMatrixTest() override { RAFT_CUDA_TRY_NO_THROW(cudaStreamDestroy(stream)); }
~GramMatrixTest() override {}

int prepareCsr(math_t* dense, int n_rows, int ld, int* indptr, int* indices, math_t* data)
{
Expand Down
Loading