Skip to content

Commit

Permalink
[CUDA][HIP] Use device to get native context (#425)
Browse files — browse the repository at this point in the history
  • Loading branch information
hdelan authored Apr 1, 2024
1 parent 3339418 commit 4635cad
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 18 deletions.
12 changes: 8 additions & 4 deletions src/blas/backends/cublas/cublas_scope_handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,11 @@ CublasScopedContextHandler::CublasScopedContextHandler(sycl::queue queue, sycl::
: ih(ih),
needToRecover_(false) {
placedContext_ = new sycl::context(queue.get_context());
auto device = queue.get_device();
auto desired = sycl::get_native<sycl::backend::ext_oneapi_cuda>(*placedContext_);
auto cudaDevice = ih.get_native_device<sycl::backend::ext_oneapi_cuda>();
CUresult err;
CUcontext desired;
CUDA_ERROR_FUNC(cuCtxGetCurrent, err, &original_);
CUDA_ERROR_FUNC(cuDevicePrimaryCtxRetain, err, &desired, cudaDevice);
if (original_ != desired) {
// Sets the desired context as the active one for the thread
CUDA_ERROR_FUNC(cuCtxSetCurrent, err, desired);
Expand Down Expand Up @@ -87,8 +88,11 @@ void ContextCallback(void *userData) {
}

cublasHandle_t CublasScopedContextHandler::get_handle(const sycl::queue &queue) {
auto piPlacedContext_ = reinterpret_cast<pi_context>(
sycl::get_native<sycl::backend::ext_oneapi_cuda>(*placedContext_));
auto cudaDevice = ih.get_native_device<sycl::backend::ext_oneapi_cuda>();
CUresult cuErr;
CUcontext desired;
CUDA_ERROR_FUNC(cuDevicePrimaryCtxRetain, cuErr, &desired, cudaDevice);
auto piPlacedContext_ = reinterpret_cast<pi_context>(desired);
CUstream streamId = get_stream(queue);
cublasStatus_t err;
auto it = handle_helper.cublas_handle_mapper_.find(piPlacedContext_);
Expand Down
12 changes: 8 additions & 4 deletions src/blas/backends/rocblas/rocblas_scope_handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,11 @@ RocblasScopedContextHandler::RocblasScopedContextHandler(sycl::queue queue,
: interop_h(ih),
needToRecover_(false) {
placedContext_ = new sycl::context(queue.get_context());
auto device = queue.get_device();
auto desired = sycl::get_native<sycl::backend::ext_oneapi_hip>(*placedContext_);
auto hipDevice = ih.get_native_device<sycl::backend::ext_oneapi_hip>();
hipError_t err;
hipCtx_t desired;
HIP_ERROR_FUNC(hipCtxGetCurrent, err, &original_);
HIP_ERROR_FUNC(hipDevicePrimaryCtxRetain, err, &desired, hipDevice);
if (original_ != desired) {
// Sets the desired context as the active one for the thread
HIP_ERROR_FUNC(hipCtxSetCurrent, err, desired);
Expand Down Expand Up @@ -103,8 +104,11 @@ void ContextCallback(void *userData) {
}

rocblas_handle RocblasScopedContextHandler::get_handle(const sycl::queue &queue) {
auto piPlacedContext_ = reinterpret_cast<pi_context>(
sycl::get_native<sycl::backend::ext_oneapi_hip>(*placedContext_));
auto hipDevice = interop_h.get_native_device<sycl::backend::ext_oneapi_hip>();
hipError_t hipErr;
hipCtx_t desired;
HIP_ERROR_FUNC(hipDevicePrimaryCtxRetain, hipErr, &desired, hipDevice);
auto piPlacedContext_ = reinterpret_cast<pi_context>(desired);
hipStream_t streamId = get_stream(queue);
rocblas_status err;
auto it = handle_helper.rocblas_handle_container_mapper_.find(piPlacedContext_);
Expand Down
7 changes: 4 additions & 3 deletions src/dft/backends/cufft/commit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,10 @@ class cufft_commit final : public dft::detail::commit_impl<prec, dom> {
}
if (fix_context) {
// cufftDestroy changes the context so change it back.
CUcontext interopContext =
sycl::get_native<sycl::backend::ext_oneapi_cuda>(this->get_queue().get_context());
if (cuCtxSetCurrent(interopContext) != CUDA_SUCCESS) {
CUdevice interopDevice =
sycl::get_native<sycl::backend::ext_oneapi_cuda>(this->get_queue().get_device());
CUcontext interopContext;
if (cuDevicePrimaryCtxRetain(&interopContext, interopDevice) != CUDA_SUCCESS) {
throw mkl::exception("dft/backends/cufft", __FUNCTION__,
"Failed to change cuda context.");
}
Expand Down
12 changes: 8 additions & 4 deletions src/lapack/backends/cusolver/cusolver_scope_handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,11 @@ CusolverScopedContextHandler::CusolverScopedContextHandler(sycl::queue queue,
: ih(ih),
needToRecover_(false) {
placedContext_ = new sycl::context(queue.get_context());
auto device = queue.get_device();
auto desired = sycl::get_native<sycl::backend::ext_oneapi_cuda>(*placedContext_);
auto cudaDevice = ih.get_native_device<sycl::backend::ext_oneapi_cuda>();
CUresult err;
CUcontext desired;
CUDA_ERROR_FUNC(cuCtxGetCurrent, err, &original_);
CUDA_ERROR_FUNC(cuDevicePrimaryCtxRetain, err, &desired, cudaDevice);
if (original_ != desired) {
// Sets the desired context as the active one for the thread
CUDA_ERROR_FUNC(cuCtxSetCurrent, err, desired);
Expand Down Expand Up @@ -88,8 +89,11 @@ void ContextCallback(void *userData) {
}

cusolverDnHandle_t CusolverScopedContextHandler::get_handle(const sycl::queue &queue) {
auto piPlacedContext_ = reinterpret_cast<pi_context>(
sycl::get_native<sycl::backend::ext_oneapi_cuda>(*placedContext_));
auto cudaDevice = ih.get_native_device<sycl::backend::ext_oneapi_cuda>();
CUresult cuErr;
CUcontext desired;
CUDA_ERROR_FUNC(cuDevicePrimaryCtxRetain, cuErr, &desired, cudaDevice);
auto piPlacedContext_ = reinterpret_cast<pi_context>(desired);
CUstream streamId = get_stream(queue);
cusolverStatus_t err;
auto it = handle_helper.cusolver_handle_mapper_.find(piPlacedContext_);
Expand Down
11 changes: 8 additions & 3 deletions src/lapack/backends/rocsolver/rocsolver_scope_handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,11 @@ RocsolverScopedContextHandler::RocsolverScopedContextHandler(sycl::queue queue,
: ih(ih),
needToRecover_(false) {
placedContext_ = new sycl::context(queue.get_context());
auto desired = sycl::get_native<sycl::backend::ext_oneapi_hip>(*placedContext_);
auto hipDevice = ih.get_native_device<sycl::backend::ext_oneapi_hip>();
hipError_t err;
hipCtx_t desired;
HIP_ERROR_FUNC(hipCtxGetCurrent, err, &original_);
HIP_ERROR_FUNC(hipDevicePrimaryCtxRetain, err, &desired, hipDevice);
if (original_ != desired) {
// Sets the desired context as the active one for the thread
HIP_ERROR_FUNC(hipCtxSetCurrent, err, desired);
Expand Down Expand Up @@ -89,8 +91,11 @@ void ContextCallback(void *userData) {
}

rocblas_handle RocsolverScopedContextHandler::get_handle(const sycl::queue &queue) {
auto piPlacedContext_ = reinterpret_cast<pi_context>(
sycl::get_native<sycl::backend::ext_oneapi_hip>(*placedContext_));
auto hipDevice = ih.get_native_device<sycl::backend::ext_oneapi_hip>();
hipError_t hipErr;
hipCtx_t desired;
HIP_ERROR_FUNC(hipDevicePrimaryCtxRetain, hipErr, &desired, hipDevice);
auto piPlacedContext_ = reinterpret_cast<pi_context>(desired);
hipStream_t streamId = get_stream(queue);
rocblas_status err;
auto it = handle_helper.rocsolver_handle_mapper_.find(piPlacedContext_);
Expand Down

0 comments on commit 4635cad

Please sign in to comment.