Skip to content

Commit

Permalink
Use out-of-place trmm functions if rocBLAS version >= 4.0.0
Browse files Browse the repository at this point in the history
  • Loading branch information
nilsfriess committed Feb 20, 2024
1 parent 6774231 commit 0e397a1
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 0 deletions.
3 changes: 3 additions & 0 deletions cmake/FindrocBLAS.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ list(APPEND CMAKE_PREFIX_PATH

# Look for HIP without printing diagnostics (QUIET); rocBLAS itself is mandatory
# (REQUIRED), so configuration stops here if it cannot be found.
find_package(HIP QUIET)
find_package(rocblas REQUIRED)
# Copy the version reported by the rocblas package into ROCBLAS_VERSION, the
# variable that is handed to find_package_handle_standard_args as VERSION_VAR.
set(ROCBLAS_VERSION ${rocblas_VERSION})

# Workaround: define HIP_NO_HALF to avoid the half type being created twice,
# once by HIP and once by SYCL.
add_compile_definitions(HIP_NO_HALF)
Expand All @@ -47,6 +48,8 @@ find_package_handle_standard_args(rocBLAS
HIP_LIBRARIES
ROCBLAS_INCLUDE_DIR
ROCBLAS_LIBRARIES
VERSION_VAR
ROCBLAS_VERSION
)
# OPENCL_INCLUDE_DIR
if(NOT TARGET ONEMKL::rocBLAS::rocBLAS)
Expand Down
4 changes: 4 additions & 0 deletions src/blas/backends/rocblas/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ target_include_directories(${LIB_OBJ}
${ONEMKL_GENERATED_INCLUDE_PATH}
)

# rocBLAS 4.0.0 removed the legacy BLAS trmm implementation, so from that
# version on the sources must call the out-of-place trmm functions; this is
# selected at compile time through ROCBLAS_NO_LEGACY_TRMM.
# The version is expanded inside quotes so that an empty or undefined
# ROCBLAS_VERSION makes the condition false instead of collapsing into the
# malformed command `if( VERSION_GREATER_EQUAL "4.0")`.
if("${ROCBLAS_VERSION}" VERSION_GREATER_EQUAL "4.0")
  target_compile_definitions(${LIB_OBJ} PRIVATE ROCBLAS_NO_LEGACY_TRMM)
endif()

if(NOT ${ONEMKL_SYCL_IMPLEMENTATION} STREQUAL "hipsycl")
target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT})
target_compile_options(ONEMKL::SYCL::SYCL INTERFACE
Expand Down
18 changes: 18 additions & 0 deletions src/blas/backends/rocblas/rocblas_level3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -381,10 +381,19 @@ inline void trmm(Func func, sycl::queue &queue, side left_right, uplo upper_lowe
auto a_ = sc.get_mem<rocDataType *>(a_acc);
auto b_ = sc.get_mem<rocDataType *>(b_acc);
rocblas_status err;

// rocblas version 4.0.0 removed the legacy BLAS trmm implementation
#ifdef ROCBLAS_NO_LEGACY_TRMM
ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_side_mode(left_right),
get_rocblas_fill_mode(upper_lower),
get_rocblas_operation(trans), get_rocblas_diag_type(unit_diag),
m, n, (rocDataType *)&alpha, a_, lda, b_, ldb, b_, ldb);
#else
ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_side_mode(left_right),
get_rocblas_fill_mode(upper_lower),
get_rocblas_operation(trans), get_rocblas_diag_type(unit_diag),
m, n, (rocDataType *)&alpha, a_, lda, b_, ldb);
#endif
});
});
}
Expand Down Expand Up @@ -805,10 +814,19 @@ inline sycl::event trmm(Func func, sycl::queue &queue, side left_right, uplo upp
auto a_ = reinterpret_cast<const rocDataType *>(a);
auto b_ = reinterpret_cast<rocDataType *>(b);
rocblas_status err;

// rocblas version 4.0.0 removed the legacy BLAS trmm implementation
#ifdef ROCBLAS_NO_LEGACY_TRMM
ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_side_mode(left_right),
get_rocblas_fill_mode(upper_lower),
get_rocblas_operation(trans), get_rocblas_diag_type(unit_diag),
m, n, (rocDataType *)&alpha, a_, lda, b_, ldb, b_, ldb);
#else
ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_side_mode(left_right),
get_rocblas_fill_mode(upper_lower),
get_rocblas_operation(trans), get_rocblas_diag_type(unit_diag),
m, n, (rocDataType *)&alpha, a_, lda, b_, ldb);
#endif
});
});

Expand Down

0 comments on commit 0e397a1

Please sign in to comment.