diff --git a/cpp/bench/common/ml_benchmark.hpp b/cpp/bench/common/ml_benchmark.hpp
index e98007b055..fcdd8c535d 100644
--- a/cpp/bench/common/ml_benchmark.hpp
+++ b/cpp/bench/common/ml_benchmark.hpp
@@ -19,10 +19,9 @@
 #include
 #include
+#include
 #include
-#include
-
 #include
 #include
@@ -165,7 +164,7 @@ class Fixture : public ::benchmark::Fixture {
   void alloc(T*& ptr, size_t len, bool init = false)
   {
     auto nBytes  = len * sizeof(T);
-    auto d_alloc = rmm::mr::get_current_device_resource();
+    auto d_alloc = raft::resource::get_current_device_resource();
     ptr          = (T*)d_alloc->allocate(nBytes, stream);
     if (init) { RAFT_CUDA_TRY(cudaMemsetAsync(ptr, 0, nBytes, stream)); }
   }
@@ -173,7 +172,7 @@
   template
   void dealloc(T* ptr, size_t len)
   {
-    auto d_alloc = rmm::mr::get_current_device_resource();
+    auto d_alloc = raft::resource::get_current_device_resource();
     d_alloc->deallocate(ptr, len * sizeof(T), stream);
   }
 
diff --git a/cpp/examples/symreg/symreg_example.cpp b/cpp/examples/symreg/symreg_example.cpp
index 7624f7b66b..fed570050a 100644
--- a/cpp/examples/symreg/symreg_example.cpp
+++ b/cpp/examples/symreg/symreg_example.cpp
@@ -21,9 +21,10 @@
 #include
+#include
+#include
 #include
 #include
-#include
 #include
 #include
@@ -200,11 +201,11 @@ int main(int argc, char* argv[])
   /* ======================= Begin GPU memory allocation ======================= */
   std::cout << "***************************************" << std::endl;
 
-  cudaStream_t stream;
+  rmm::cuda_stream stream;
   raft::handle_t handle{stream};
 
   // Begin recording time
-  cudaEventRecord(start, stream);
+  CUDA_RT_CALL(cudaEventRecord(start, stream.value()));
 
   rmm::device_uvector dX_train(n_cols * n_train_rows, stream);
   rmm::device_uvector dy_train(n_train_rows, stream);
@@ -215,46 +216,54 @@ int main(int argc, char* argv[])
   rmm::device_uvector dy_pred(n_test_rows, stream);
   rmm::device_scalar d_score{stream};
 
-  cg::program_t d_finalprogs;  // pointer to last generation ASTs on device
-
   CUDA_RT_CALL(cudaMemcpyAsync(dX_train.data(),
                                X_train.data(),
                                sizeof(float) * dX_train.size(),
                                cudaMemcpyHostToDevice,
-                               stream));
+                               stream.value()));
 
   CUDA_RT_CALL(cudaMemcpyAsync(dy_train.data(),
                                y_train.data(),
                                sizeof(float) * dy_train.size(),
                                cudaMemcpyHostToDevice,
-                               stream));
+                               stream.value()));
 
   CUDA_RT_CALL(cudaMemcpyAsync(dw_train.data(),
                                w_train.data(),
                                sizeof(float) * dw_train.size(),
                                cudaMemcpyHostToDevice,
-                               stream));
+                               stream.value()));
 
-  CUDA_RT_CALL(cudaMemcpyAsync(
-    dX_test.data(), X_test.data(), sizeof(float) * dX_test.size(), cudaMemcpyHostToDevice, stream));
+  CUDA_RT_CALL(cudaMemcpyAsync(dX_test.data(),
+                               X_test.data(),
+                               sizeof(float) * dX_test.size(),
+                               cudaMemcpyHostToDevice,
+                               stream.value()));
 
-  CUDA_RT_CALL(cudaMemcpyAsync(
-    dy_test.data(), y_test.data(), sizeof(float) * dy_test.size(), cudaMemcpyHostToDevice, stream));
+  CUDA_RT_CALL(cudaMemcpyAsync(dy_test.data(),
+                               y_test.data(),
+                               sizeof(float) * dy_test.size(),
+                               cudaMemcpyHostToDevice,
+                               stream.value()));
 
-  CUDA_RT_CALL(cudaMemcpyAsync(
-    dw_test.data(), w_test.data(), sizeof(float) * n_test_rows, cudaMemcpyHostToDevice, stream));
+  CUDA_RT_CALL(cudaMemcpyAsync(dw_test.data(),
+                               w_test.data(),
+                               sizeof(float) * n_test_rows,
+                               cudaMemcpyHostToDevice,
+                               stream.value()));
 
   // Initialize AST
-  auto curr_mr = rmm::mr::get_current_device_resource();
-  d_finalprogs = static_cast(curr_mr->allocate(params.population_size, stream));
+  auto prog_buffer = rmm::device_buffer(params.population_size, stream);
+  // pointer to last generation ASTs on device
+  cg::program_t d_finalprogs = static_cast(prog_buffer.data());
 
   std::vector> history;
   history.reserve(params.generations);
 
-  cudaEventRecord(stop, stream);
-  cudaEventSynchronize(stop);
+  CUDA_RT_CALL(cudaEventRecord(stop, stream.value()));
+  CUDA_RT_CALL(cudaEventSynchronize(stop));
 
   float alloc_time;
-  cudaEventElapsedTime(&alloc_time, start, stop);
+  CUDA_RT_CALL(cudaEventElapsedTime(&alloc_time, start, stop));
 
   std::cout << "Allocated device memory in " << std::setw(10) << alloc_time << "ms" << std::endl;
@@ -263,7 +272,7 @@ int main(int argc, char* argv[])
   std::cout << "***************************************" << std::endl;
   std::cout << std::setw(30) << "Beginning training for " << std::setw(15) << params.generations
             << " generations" << std::endl;
-  cudaEventRecord(start, stream);
+  CUDA_RT_CALL(cudaEventRecord(start, stream.value()));
 
   cg::symFit(handle,
              dX_train.data(),
@@ -275,10 +284,10 @@ int main(int argc, char* argv[])
              d_finalprogs,
              history);
 
-  cudaEventRecord(stop, stream);
-  cudaEventSynchronize(stop);
+  CUDA_RT_CALL(cudaEventRecord(stop, stream.value()));
+  CUDA_RT_CALL(cudaEventSynchronize(stop));
 
   float training_time;
-  cudaEventElapsedTime(&training_time, start, stop);
+  CUDA_RT_CALL(cudaEventElapsedTime(&training_time, start, stop));
 
   int n_gen = params.num_epochs;
   std::cout << std::setw(30) << "Convergence achieved in " << std::setw(15) << n_gen
@@ -308,7 +317,7 @@ int main(int argc, char* argv[])
   std::cout << "***************************************" << std::endl;
   std::cout << "Beginning Inference on test dataset " << std::endl;
-  cudaEventRecord(start, stream);
+  CUDA_RT_CALL(cudaEventRecord(start, stream.value()));
 
   cuml::genetic::symRegPredict(
     handle, dX_test.data(), n_test_rows, d_finalprogs + best_idx, dy_pred.data());
@@ -319,10 +328,10 @@ int main(int argc, char* argv[])
   cuml::genetic::compute_metric(
     handle, n_test_rows, 1, dy_test.data(), dy_pred.data(), dw_test.data(), d_score.data(), params);
 
-  cudaEventRecord(stop, stream);
-  cudaEventSynchronize(stop);
+  CUDA_RT_CALL(cudaEventRecord(stop, stream.value()));
+  CUDA_RT_CALL(cudaEventSynchronize(stop));
 
   float inference_time;
-  cudaEventElapsedTime(&inference_time, start, stop);
+  CUDA_RT_CALL(cudaEventElapsedTime(&inference_time, start, stop));
 
   // Output fitness score
   std::cout << "Inference score = " << d_score.value(stream) << std::endl;
@@ -336,9 +345,6 @@ int main(int argc, char* argv[])
   std::copy(y_test.begin(), y_test.begin() + 5, std::ostream_iterator(std::cout, ";"));
   std::cout << std::endl;
 
-  /* ======================= Reset data ======================= */
-
-  curr_mr->deallocate(d_finalprogs, params.population_size, stream);
   CUDA_RT_CALL(cudaEventDestroy(start));
   CUDA_RT_CALL(cudaEventDestroy(stop));
   return 0;
diff --git a/cpp/include/cuml/tsa/arima_common.h b/cpp/include/cuml/tsa/arima_common.h
index 1586ca963c..b308fb2c23 100644
--- a/cpp/include/cuml/tsa/arima_common.h
+++ b/cpp/include/cuml/tsa/arima_common.h
@@ -16,10 +16,10 @@
 #pragma once
 
+#include
 #include
 #include
-#include
 #include
 #include
@@ -81,7 +81,7 @@ struct ARIMAParams {
    */
   void allocate(const ARIMAOrder& order, int batch_size, cudaStream_t stream, bool tr = false)
   {
-    rmm::device_async_resource_ref rmm_alloc = rmm::mr::get_current_device_resource();
+    rmm::device_async_resource_ref rmm_alloc = raft::resource::get_current_device_resource();
     if (order.k && !tr)
       mu = (DataT*)rmm_alloc.allocate_async(
         batch_size * sizeof(DataT), rmm::CUDA_ALLOCATION_ALIGNMENT, stream);
@@ -115,7 +115,7 @@
    */
   void deallocate(const ARIMAOrder& order, int batch_size, cudaStream_t stream, bool tr = false)
   {
-    rmm::device_async_resource_ref rmm_alloc = rmm::mr::get_current_device_resource();
+    rmm::device_async_resource_ref rmm_alloc = raft::resource::get_current_device_resource();
     if (order.k && !tr)
       rmm_alloc.deallocate_async(
         mu, batch_size * sizeof(DataT), rmm::CUDA_ALLOCATION_ALIGNMENT, stream);
diff --git a/cpp/src/genetic/genetic.cu b/cpp/src/genetic/genetic.cu
index 947220ceb3..38558a1ed4 100644
--- a/cpp/src/genetic/genetic.cu
+++ b/cpp/src/genetic/genetic.cu
@@ -23,6 +23,7 @@
 #include
 #include
+#include
 #include
 #include
 #include
@@ -30,7 +31,6 @@
 #include
 #include
-#include
 
 #include
@@ -229,17 +229,17 @@ void parallel_evolve(const raft::handle_t& h,
       program tmp(h_nextprogs[i]);
       delete[] tmp.nodes;
 
+      auto mr = raft::resource::get_current_device_resource_ref();
+
       // Set current generation device nodes
-      tmp.nodes = (node*)rmm::mr::get_current_device_resource()->allocate(
-        h_nextprogs[i].len * sizeof(node), stream);
+      tmp.nodes = static_cast(mr.allocate_async(h_nextprogs[i].len * sizeof(node), stream));
       raft::copy(tmp.nodes, h_nextprogs[i].nodes, h_nextprogs[i].len, stream);
       raft::copy(d_nextprogs + i, &tmp, 1, stream);
 
       if (generation > 1) {
         // Free device memory allocated to program nodes in previous generation
         raft::copy(&tmp, d_oldprogs + i, 1, stream);
-        rmm::mr::get_current_device_resource()->deallocate(
-          tmp.nodes, h_nextprogs[i].len * sizeof(node), stream);
+        mr.deallocate_async(tmp.nodes, h_nextprogs[i].len * sizeof(node), stream);
       }
 
       tmp.nodes = nullptr;
@@ -408,8 +408,9 @@ void symFit(const raft::handle_t& handle,
   std::vector h_fitness(params.population_size, 0.0f);
 
   program_t d_currprogs;  // pointer to current programs
-  d_currprogs = (program_t)rmm::mr::get_current_device_resource()->allocate(
-    params.population_size * sizeof(program), stream);
+  auto mr = raft::resource::get_current_device_resource_ref();
+  d_currprogs =
+    static_cast(mr.allocate_async(params.population_size * sizeof(program), stream));
 
   program_t d_nextprogs = final_progs;  // Reuse memory already allocated for final_progs
   final_progs           = nullptr;
@@ -490,8 +491,7 @@ void symFit(const raft::handle_t& handle,
   if (growAuto) { params.terminalRatio = 0.0f; }
 
   // Deallocate the previous generation device memory
-  rmm::mr::get_current_device_resource()->deallocate(
-    d_nextprogs, params.population_size * sizeof(program), stream);
+  mr.deallocate_async(d_nextprogs, params.population_size * sizeof(program), stream);
   d_currprogs = nullptr;
   d_nextprogs = nullptr;
 }
diff --git a/cpp/src/svm/linear.cu b/cpp/src/svm/linear.cu
index ac1d561ed0..96f24e5014 100644
--- a/cpp/src/svm/linear.cu
+++ b/cpp/src/svm/linear.cu
@@ -329,13 +329,13 @@ LinearSVMModel LinearSVMModel::allocate(const raft::handle_t& handle,
                                         const std::size_t nClasses)
 {
   auto stream = handle.get_stream();
-  auto res    = rmm::mr::get_current_device_resource();
+  auto res    = raft::resource::get_current_device_resource_ref();
   const std::size_t coefRows = nCols + params.fit_intercept;
   const std::size_t coefCols = nClasses <= 2 ? 1 : nClasses;
   const std::size_t wSize    = coefRows * coefCols;
   const std::size_t cSize    = nClasses >= 2 ? nClasses : 0;
   const std::size_t pSize    = params.probability ? 2 * coefCols : 0;
-  auto bytes = static_cast(res->allocate(sizeof(T) * (wSize + cSize + pSize), stream));
+  auto bytes = static_cast(res.allocate_async(sizeof(T) * (wSize + cSize + pSize), stream));
   return LinearSVMModel{/* .w */ bytes,
                         /* .classes */ cSize > 0 ? bytes + wSize : nullptr,
                         /* .probScale */ pSize > 0 ? bytes + wSize + cSize : nullptr,
@@ -347,13 +347,13 @@ template
 void LinearSVMModel::free(const raft::handle_t& handle, LinearSVMModel& model)
 {
   auto stream = handle.get_stream();
-  auto res    = rmm::mr::get_current_device_resource();
+  auto res    = raft::resource::get_current_device_resource_ref();
   const std::size_t coefRows = model.coefRows;
   const std::size_t coefCols = model.coefCols();
   const std::size_t wSize    = coefRows * coefCols;
   const std::size_t cSize    = model.nClasses;
   const std::size_t pSize    = model.probScale == nullptr ? 2 * coefCols : 0;
-  res->deallocate(model.w, sizeof(T) * (wSize + cSize + pSize), stream);
+  res.deallocate_async(model.w, sizeof(T) * (wSize + cSize + pSize), stream);
   model.w         = nullptr;
   model.classes   = nullptr;
   model.probScale = nullptr;
diff --git a/cpp/src/svm/results.cuh b/cpp/src/svm/results.cuh
index f33e8c4552..2551004a48 100644
--- a/cpp/src/svm/results.cuh
+++ b/cpp/src/svm/results.cuh
@@ -20,8 +20,10 @@
 #include "ws_util.cuh"
 
 #include
+#include
 #include
+#include
 #include
 #include
 #include
@@ -31,7 +33,6 @@
 #include
 #include
-#include
 
 #include
 #include
@@ -73,7 +74,7 @@ class Results {
           const math_t* y,
           const math_t* C,
           SvmType svmType)
-    : rmm_alloc(rmm::mr::get_current_device_resource()),
+    : rmm_alloc(raft::resource::get_current_device_resource_ref()),
       stream(handle.get_stream()),
       handle(handle),
       n_rows(n_rows),
diff --git a/cpp/src/svm/sparse_util.cuh b/cpp/src/svm/sparse_util.cuh
index c4d0b277e9..5981c45018 100644
--- a/cpp/src/svm/sparse_util.cuh
+++ b/cpp/src/svm/sparse_util.cuh
@@ -762,14 +762,15 @@ void extractRows(raft::device_csr_matrix_view matrix_in,
   math_t* data_in = matrix_in.get_elements().data();
 
   // allocate indptr
-  auto* rmm_alloc = rmm::mr::get_current_device_resource();
-  *indptr_out     = (int*)rmm_alloc->allocate((num_indices + 1) * sizeof(int), stream);
+  auto rmm_alloc = raft::resource::get_current_device_resource_ref();
+  *indptr_out =
+    static_cast(rmm_alloc.allocate_async((num_indices + 1) * sizeof(int), stream));
 
   *nnz = computeIndptrForSubset(indptr_in, *indptr_out, row_indices, num_indices, stream);
 
   // allocate indices, data
-  *indices_out = (int*)rmm_alloc->allocate(*nnz * sizeof(int), stream);
-  *data_out    = (math_t*)rmm_alloc->allocate(*nnz * sizeof(math_t), stream);
+  *indices_out = static_cast(rmm_alloc.allocate_async(*nnz * sizeof(int), stream));
+  *data_out    = static_cast(rmm_alloc.allocate_async(*nnz * sizeof(math_t), stream));
 
   // copy with 1 warp per row for now, blocksize 256
   const dim3 bs(32, 8, 1);
diff --git a/cpp/src/svm/svc_impl.cuh b/cpp/src/svm/svc_impl.cuh
index 3bd27dc6e4..a798e1f3f3 100644
--- a/cpp/src/svm/svc_impl.cuh
+++ b/cpp/src/svm/svc_impl.cuh
@@ -28,13 +28,13 @@
 #include
 #include
+#include
 #include
 #include
 #include
 #include
 #include
-#include
 
 #include
 #include
@@ -71,9 +71,9 @@ void svcFitX(const raft::handle_t& handle,
   cudaStream_t stream = handle_impl.get_stream();
   {
     rmm::device_uvector unique_labels(0, stream);
-    model.n_classes  = raft::label::getUniquelabels(unique_labels, labels, n_rows, stream);
-    rmm::device_async_resource_ref rmm_alloc = rmm::mr::get_current_device_resource();
-    model.unique_labels = (math_t*)rmm_alloc.allocate_async(
+    model.n_classes     = raft::label::getUniquelabels(unique_labels, labels, n_rows, stream);
+    auto rmm_alloc      = raft::resource::get_current_device_resource_ref();
+    model.unique_labels = (math_t*)rmm_alloc.allocate_async(
       model.n_classes * sizeof(math_t), rmm::CUDA_ALLOCATION_ALIGNMENT, stream);
     raft::copy(model.unique_labels, unique_labels.data(), model.n_classes, stream);
     handle_impl.sync_stream(stream);
@@ -356,7 +356,7 @@ template
 void svmFreeBuffers(const raft::handle_t& handle, SvmModel& m)
 {
   cudaStream_t stream = handle.get_stream();
-  rmm::device_async_resource_ref rmm_alloc = rmm::mr::get_current_device_resource();
+  rmm::device_async_resource_ref rmm_alloc = raft::resource::get_current_device_resource_ref();
   if (m.dual_coefs)
     rmm_alloc.deallocate_async(
       m.dual_coefs, m.n_support * sizeof(math_t), rmm::CUDA_ALLOCATION_ALIGNMENT, stream);
diff --git a/cpp/test/mg/knn.cu b/cpp/test/mg/knn.cu
index 30ada2f90e..517a53ab79 100644
--- a/cpp/test/mg/knn.cu
+++ b/cpp/test/mg/knn.cu
@@ -20,11 +20,10 @@
 #include
 #include
+#include
 #include
 #include
-#include
-
 #include
 #include
@@ -63,8 +62,8 @@ class BruteForceKNNTest : public ::testing::TestWithParam {
   bool runTest(const KNNParams& params)
   {
     raft::comms::initialize_mpi_comms(&handle, MPI_COMM_WORLD);
-    const auto& comm     = handle.get_comms();
-    const auto allocator = rmm::mr::get_current_device_resource();
+    const auto& comm = handle.get_comms();
+    auto allocator   = raft::resource::get_current_device_resource_ref();
 
     cudaStream_t stream = handle.get_stream();
@@ -112,14 +111,14 @@ class BruteForceKNNTest : public ::testing::TestWithParam {
     std::vector out_d_parts;
     std::vector*> out_i_parts;
     for (int i = 0; i < query_parts_per_rank; i++) {
-      float* q =
-        (float*)allocator.get()->allocate(params.min_rows * params.n_cols * sizeof(float*), stream);
+      float* q = static_cast(
+        allocator.allocate_async(params.min_rows * params.n_cols * sizeof(float*), stream));
 
-      float* o =
-        (float*)allocator.get()->allocate(params.min_rows * params.k * sizeof(float*), stream);
+      float* o = static_cast(
+        allocator.allocate_async(params.min_rows * params.k * sizeof(float*), stream));
 
-      int64_t* ind =
-        (int64_t*)allocator.get()->allocate(params.min_rows * params.k * sizeof(int64_t), stream);
+      int64_t* ind = static_cast(
+        allocator.allocate_async(params.min_rows * params.k * sizeof(int64_t), stream));
 
       Matrix::Data* query_d = new Matrix::Data(q, params.min_rows * params.n_cols);
diff --git a/cpp/test/sg/genetic/evolution_test.cu b/cpp/test/sg/genetic/evolution_test.cu
index 526acb5280..911f0583dd 100644
--- a/cpp/test/sg/genetic/evolution_test.cu
+++ b/cpp/test/sg/genetic/evolution_test.cu
@@ -21,6 +21,7 @@
 #include
 #include
+#include
 #include
 #include
@@ -268,9 +269,8 @@ class GeneticEvolutionTest : public ::testing::Test {
 TEST_F(GeneticEvolutionTest, SymReg)
 {
   MLCommon::CompareApprox compApprox(tolerance);
-  program_t final_progs;
-  final_progs = (program_t)rmm::mr::get_current_device_resource()->allocate(
-    hyper_params.population_size * sizeof(program), stream);
+  auto prog_buffer = rmm::device_buffer(hyper_params.population_size * sizeof(program), stream);
+  program_t final_progs = static_cast(prog_buffer.data());
 
   std::vector> history;
   history.reserve(hyper_params.generations);
@@ -337,12 +337,11 @@ TEST_F(GeneticEvolutionTest, SymReg)
   for (auto i = 0; i < hyper_params.population_size; ++i) {
     program tmp = program();
     raft::copy(&tmp, final_progs + i, 1, stream);
-    rmm::mr::get_current_device_resource()->deallocate(tmp.nodes, tmp.len * sizeof(node), stream);
+    // TODO: why are we deallocating something that wasn't allocated here?
+    raft::resource::get_current_device_resource_ref().deallocate_async(
+      tmp.nodes, tmp.len * sizeof(node), stream);
     tmp.nodes = nullptr;
   }
-  // deallocate the final programs from device memory
-  rmm::mr::get_current_device_resource()->deallocate(
-    final_progs, hyper_params.population_size * sizeof(program), stream);
 
   ASSERT_TRUE(compApprox(history[n_gen - 1][best_idx].raw_fitness_, 0.0036f));
   std::cout << "Some Predicted test values:" << std::endl;
diff --git a/cpp/test/sg/genetic/program_test.cu b/cpp/test/sg/genetic/program_test.cu
index 1205baf9d9..54b3ed0d88 100644
--- a/cpp/test/sg/genetic/program_test.cu
+++ b/cpp/test/sg/genetic/program_test.cu
@@ -22,8 +22,8 @@
 #include
 #include
+#include
 #include
-#include
 #include
 #include
@@ -38,7 +38,10 @@ namespace genetic {
 class GeneticProgramTest : public ::testing::Test {
  public:
   GeneticProgramTest()
-    : d_data(0, cudaStream_t(0)),
+    : nodes1_buffer(7 * sizeof(node), stream),
+      nodes2_buffer(7 * sizeof(node), stream),
+      progs_buffer(2 * sizeof(program), stream),
+      d_data(0, cudaStream_t(0)),
       d_y(0, cudaStream_t(0)),
       d_lYpred(0, cudaStream_t(0)),
       d_lY(0, cudaStream_t(0)),
@@ -101,10 +104,10 @@ class GeneticProgramTest : public ::testing::Test {
     d_lY.resize(250, stream);
     d_lunitW.resize(250, stream);
     d_lW.resize(250, stream);
-    d_nodes1 = (node*)rmm::mr::get_current_device_resource()->allocate(7 * sizeof(node), stream);
-    d_nodes2 = (node*)rmm::mr::get_current_device_resource()->allocate(7 * sizeof(node), stream);
-    d_progs =
-      (program_t)rmm::mr::get_current_device_resource()->allocate(2 * sizeof(program), stream);
+
+    d_nodes1 = static_cast(nodes1_buffer.data());
+    d_nodes2 = static_cast(nodes2_buffer.data());
+    d_progs  = static_cast(progs_buffer.data());
 
     RAFT_CUDA_TRY(cudaMemcpyAsync(
       d_lYpred.data(), h_lYpred.data(), 500 * sizeof(float), cudaMemcpyHostToDevice, stream));
@@ -155,12 +158,7 @@ class GeneticProgramTest : public ::testing::Test {
       dyp2.data(), hyp2.data(), 10 * sizeof(float), cudaMemcpyHostToDevice, stream));
   }
 
-  void TearDown() override
-  {
-    rmm::mr::get_current_device_resource()->deallocate(d_nodes1, 7 * sizeof(node), stream);
-    rmm::mr::get_current_device_resource()->deallocate(d_nodes2, 7 * sizeof(node), stream);
-    rmm::mr::get_current_device_resource()->deallocate(d_progs, 2 * sizeof(program), stream);
-  }
+  void TearDown() override {}
 
   raft::handle_t handle;
   cudaStream_t stream;
@@ -348,6 +346,9 @@ class GeneticProgramTest : public ::testing::Test {
   node* d_nodes1;
   node* d_nodes2;
   program_t d_progs;
+  rmm::device_buffer nodes1_buffer;
+  rmm::device_buffer nodes2_buffer;
+  rmm::device_buffer progs_buffer;
   rmm::device_uvector d_data;
   rmm::device_uvector d_y;
   rmm::device_uvector d_lYpred;
diff --git a/cpp/test/sg/svc_test.cu b/cpp/test/sg/svc_test.cu
index 0caad107d5..404b22bfa9 100644
--- a/cpp/test/sg/svc_test.cu
+++ b/cpp/test/sg/svc_test.cu
@@ -22,6 +22,7 @@
 #include
 #include
+#include
 #include
 #include
 #include
@@ -502,8 +503,8 @@ class GetResultsTest : public ::testing::Test {
  protected:
   void FreeDenseSupport()
   {
-    rmm::device_async_resource_ref rmm_alloc = rmm::mr::get_current_device_resource();
-    auto stream                              = this->handle.get_stream();
+    auto rmm_alloc = raft::resource::get_current_device_resource_ref();
+    auto stream    = this->handle.get_stream();
     rmm_alloc.deallocate_async(support_matrix.data,
                                n_coefs * n_cols * sizeof(math_t),
                                rmm::CUDA_ALLOCATION_ALIGNMENT,
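
Reviewer note (not part of the diff): the changes above converge on two replacement patterns for the old raw rmm::mr::get_current_device_resource()->allocate()/deallocate() calls. The snippet below is a minimal sketch of those patterns under assumptions, not code from this PR: the include path for raft::resource::get_current_device_resource_ref() is a guess (the #include targets in the diff text above were stripped), and the node struct, allocation_patterns() function, and its parameters are hypothetical stand-ins introduced only to keep the sketch self-contained.

// Sketch only -- illustrates the "after" allocation patterns this PR adopts.
#include <raft/core/resource/device_memory_resource.hpp>  // assumed header for get_current_device_resource_ref()

#include <rmm/device_buffer.hpp>

#include <cuda_runtime_api.h>

#include <cstddef>

namespace {
// Hypothetical stand-in for cuml::genetic::node, just to make the sketch compile on its own.
struct node {
  float value;
};
}  // namespace

// Hypothetical helper, not part of the PR.
void allocation_patterns(cudaStream_t stream, std::size_t len)
{
  // Pattern 1 (tests, symreg example): RAII ownership via rmm::device_buffer.
  // The buffer frees its memory when it goes out of scope, which is why the
  // explicit TearDown()/deallocate calls could simply be deleted above.
  rmm::device_buffer nodes_buffer(len * sizeof(node), stream);
  node* d_nodes = static_cast<node*>(nodes_buffer.data());
  (void)d_nodes;  // would be passed to kernels / raft::copy in real code

  // Pattern 2 (genetic.cu, svm/): explicit stream-ordered allocation through an
  // rmm::device_async_resource_ref obtained from RAFT, instead of dereferencing
  // rmm::mr::get_current_device_resource() directly.
  auto mr   = raft::resource::get_current_device_resource_ref();
  void* raw = mr.allocate_async(len * sizeof(node), stream);
  // ... use raw ...
  mr.deallocate_async(raw, len * sizeof(node), stream);
}

The RAII form suits scopes that fully own the memory (tests, the example's main); the explicit allocate_async/deallocate_async form remains where raw pointers such as program_t have to cross API boundaries and outlive the allocating scope.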