Skip to content

Commit

Permalink
Map BLAS/solver handles to streams (shared across threads), rather than to thread-and-stream pairs
Browse files Browse the repository at this point in the history
Signed-off-by: Joseph Schuchart <[email protected]>
  • Loading branch information
devreal committed Dec 13, 2023
1 parent 8146481 commit b9727a4
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions examples/devblas_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <stdexcept>
#include <optional>
#include <map>
#include <mutex>

#ifdef TTG_HAVE_CUDART

Expand All @@ -20,7 +21,8 @@
template<typename T = int>
inline const cublasHandle_t& cublas_handle(T _ = 0) {
using map_type = std::map<std::pair<int, cudaStream_t>, cublasHandle_t>;
static thread_local map_type handles;
static map_type handles;
static std::mutex handle_mtx;

auto d = ttg::device::current_device();
int device = 0; // assume 0 if we don't have a device
Expand All @@ -30,6 +32,7 @@ inline const cublasHandle_t& cublas_handle(T _ = 0) {

cudaStream_t stream = ttg::device::current_stream();

std::lock_guard g(handle_mtx);
map_type::iterator it;
if ((it = handles.find({device, stream})) == handles.end()){
cublasHandle_t handle;
Expand All @@ -53,7 +56,8 @@ template<typename T = int>
inline const cusolverDnHandle_t& cusolver_handle(T _ = 0) {

using map_type = std::map<std::pair<int, cudaStream_t>, cusolverDnHandle_t>;
static thread_local map_type handles;
static map_type handles;
static std::mutex handle_mtx;

auto d = ttg::device::current_device();
int device = 0; // assume 0 if we don't have a device
Expand All @@ -62,6 +66,7 @@ inline const cusolverDnHandle_t& cusolver_handle(T _ = 0) {
}
cudaStream_t stream = ttg::device::current_stream();

std::lock_guard g(handle_mtx);
map_type::iterator it;
if ((it = handles.find({device, stream})) == handles.end()){
cusolverDnHandle_t handle;
Expand Down Expand Up @@ -95,7 +100,8 @@ inline const cusolverDnHandle_t& cusolver_handle(T _ = 0) {
template<typename T = int>
inline const hipblasHandle_t& hipblas_handle(T _ = 0) {
using map_type = std::map<std::pair<int, hipStream_t>, hipblasHandle_t>;
static thread_local map_type handles;
static map_type handles;
static std::mutex handle_mtx;

auto d = ttg::device::current_device();
int device = 0; // assume 0 if we don't have a device
Expand All @@ -105,6 +111,7 @@ inline const hipblasHandle_t& hipblas_handle(T _ = 0) {

hipStream_t stream = ttg::device::current_stream();

std::lock_guard g(handle_mtx);
map_type::iterator it;
if ((it = handles.find({device, stream})) == handles.end()){
hipblasHandle_t handle;
Expand All @@ -128,7 +135,8 @@ inline const hipblasHandle_t& hipblas_handle(T _ = 0) {
template<typename T = int>
inline const hipsolverDnHandle_t& hipsolver_handle(T _ = 0) {
using map_type = std::map<std::pair<int, hipStream_t>, hipsolverDnHandle_t>;
static thread_local map_type handles;
static map_type handles;
static std::mutex handle_mtx;
auto d = ttg::device::current_device();
int device = 0; // assume 0 if we don't have a device
if (d.is_device()) {
Expand All @@ -137,6 +145,7 @@ inline const hipsolverDnHandle_t& hipsolver_handle(T _ = 0) {

hipStream_t stream = ttg::device::current_stream();

std::lock_guard g(handle_mtx);
map_type::iterator it;
if ((it = handles.find({device, stream})) == handles.end()){
hipsolverDnHandle_t handle;
Expand Down

0 comments on commit b9727a4

Please sign in to comment.