Skip to content

Commit

Permalink
We need to set the context properly before destroying cublasHandles
Browse files Browse the repository at this point in the history
  • Loading branch information
konradkusiak97 committed Nov 8, 2024
1 parent 46a2661 commit 5996320
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 8 deletions.
13 changes: 11 additions & 2 deletions src/blas/backends/cublas/cublas_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,27 @@
#ifndef CUBLAS_HANDLE_HPP
#define CUBLAS_HANDLE_HPP
#include <unordered_map>
#include "cublas_helper.hpp"

namespace oneapi {
namespace mkl {
namespace blas {
namespace cublas {

template <typename T>
struct cublas_handle {
using handle_container_t = std::unordered_map<T, cublasHandle_t>;
using handle_container_t = std::unordered_map<CUdevice, cublasHandle_t>;
handle_container_t cublas_handle_mapper_{};
~cublas_handle() noexcept(false) {
CUresult err;
CUcontext original;
CUDA_ERROR_FUNC(cuCtxGetCurrent, err, &original);
for (auto& handle_pair : cublas_handle_mapper_) {
CUcontext desired;
CUDA_ERROR_FUNC(cuDevicePrimaryCtxRetain, err, &desired, handle_pair.first);
if (original != desired) {
// Sets the desired context as the active one for the thread in order to destroy its corresponding cublasHandle_t.
CUDA_ERROR_FUNC(cuCtxSetCurrent, err, desired);
}
cublasStatus_t err;
CUBLAS_ERROR_FUNC(cublasDestroy, err, handle_pair.second);
}
Expand Down
3 changes: 1 addition & 2 deletions src/blas/backends/cublas/cublas_scope_handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ namespace cublas {
* takes place if no other element in the container has a key equivalent to
* the one being emplaced (keys in a map container are unique).
*/
thread_local cublas_handle<CUdevice> CublasScopedContextHandler::handle_helper =
cublas_handle<CUdevice>{};
thread_local cublas_handle CublasScopedContextHandler::handle_helper = cublas_handle{};

CublasScopedContextHandler::CublasScopedContextHandler(sycl::interop_handle& ih) : ih(ih) {}

Expand Down
2 changes: 1 addition & 1 deletion src/blas/backends/cublas/cublas_scope_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ the handle must be destroyed when the context goes out of scope. This will bind

class CublasScopedContextHandler {
sycl::interop_handle& ih;
static thread_local cublas_handle<CUdevice> handle_helper;
static thread_local cublas_handle handle_helper;
CUstream get_stream(const sycl::queue& queue);
sycl::context get_context(const sycl::queue& queue);

Expand Down
4 changes: 2 additions & 2 deletions src/blas/backends/cublas/cublas_scope_handle_hipsycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ namespace mkl {
namespace blas {
namespace cublas {

thread_local cublas_handle<int> CublasScopedContextHandler::handle_helper = cublas_handle<int>{};
thread_local cublas_handle CublasScopedContextHandler::handle_helper = cublas_handle{};

CublasScopedContextHandler::CublasScopedContextHandler(sycl::queue queue, sycl::interop_handle& ih)
: interop_h(ih) {}

cublasHandle_t CublasScopedContextHandler::get_handle(const sycl::queue& queue) {
sycl::device device = queue.get_device();
int current_device = interop_h.get_native_device<sycl::backend::cuda>();
CUdevice current_device = interop_h.get_native_device<sycl::backend::cuda>();
CUstream streamId = get_stream(queue);
cublasStatus_t err;
auto it = handle_helper.cublas_handle_mapper_.find(current_device);
Expand Down
2 changes: 1 addition & 1 deletion src/blas/backends/cublas/cublas_scope_handle_hipsycl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ the handle must be destroyed when the context goes out of scope. This will bind

class CublasScopedContextHandler {
sycl::interop_handle interop_h;
static thread_local cublas_handle<int> handle_helper;
static thread_local cublas_handle handle_helper;
sycl::context get_context(const sycl::queue& queue);
CUstream get_stream(const sycl::queue& queue);

Expand Down

0 comments on commit 5996320

Please sign in to comment.