Skip to content

Commit

Permalink
[BLAS][portBLAS] Add bindings for half and some gemm_batch group APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
Rbiessy committed Sep 30, 2024
1 parent afb9d5c commit 3f683b3
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 21 deletions.
3 changes: 3 additions & 0 deletions docs/building_the_project_with_dpcpp.rst
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,9 @@ portBLAS relies heavily on JIT compilation. This may cause time-outs on some
systems. To avoid this issue, use ahead-of-time compilation through tuning
targets or ``sycl-targets``.

Support for the ``sycl::half`` type is disabled by default. To enable it, set
``-DPORTBLAS_ENABLE_HALF=ON`` when configuring the build.

.. _build_for_portfft_dpcpp:

Building for portFFT
Expand Down
5 changes: 5 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ foreach(domain ${TARGET_DOMAINS})
add_subdirectory(${domain})
endforeach()

# Translate the PORTBLAS_ENABLE_HALF build option into ENABLE_PORTBLAS_HALF.
if (PORTBLAS_ENABLE_HALF)
# ENABLE_PORTBLAS_HALF is the name picked up by "#cmakedefine" in config.hpp.in,
# which is configured just below; it becomes the C++ macro of the same name.
set(ENABLE_PORTBLAS_HALF ON)
endif()

# Generate header with enabled backends for testing
configure_file(config.hpp.in "${CMAKE_CURRENT_BINARY_DIR}/oneapi/mkl/config.hpp.configured")
file(GENERATE
Expand Down
8 changes: 5 additions & 3 deletions src/blas/backends/portblas/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@
set(LIB_NAME onemkl_blas_portblas)
set(LIB_OBJ ${LIB_NAME}_obj)

if(NOT DEFINED PORTBLAS_TUNING_TARGET)
option(PORTBLAS_TUNING_TARGET "Set a TUNING_TARGET for portBLAS" "")
endif()
option(PORTBLAS_TUNING_TARGET "Set a TUNING_TARGET for portBLAS" "")
option(PORTBLAS_ENABLE_HALF "Enable half support with the portBLAS backend" OFF)

# Parse compiler flags and return a list of SYCL targets
# The list is empty if no targets are set
Expand Down Expand Up @@ -152,6 +151,9 @@ if (NOT PORTBLAS_FOUND)
# Following variable TUNING_TARGET will be used in portBLAS internal configuration
set(TUNING_TARGET ${PORTBLAS_TUNING_TARGET})
set(BLAS_ENABLE_COMPLEX ON)
if (PORTBLAS_ENABLE_HALF)
set(BLAS_ENABLE_HALF ON)
endif()
# Set the policy to forward variables to portBLAS configure step
set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
set(FETCHCONTENT_BASE_DIR "${CMAKE_BINARY_DIR}/deps")
Expand Down
81 changes: 63 additions & 18 deletions src/blas/backends/portblas/portblas_batch.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,12 @@ void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::
sycl::buffer<sycl::half, 1> &b, std::int64_t ldb, std::int64_t stride_b,
sycl::half beta, sycl::buffer<sycl::half, 1> &c, std::int64_t ldc,
std::int64_t stride_c, std::int64_t batch_size) {
throw unimplemented("blas", "gemm_batch", " for complex");
#ifdef ENABLE_PORTBLAS_HALF
CALL_PORTBLAS_FN(::blas::_gemm_strided_batched, queue, transa, transb, m, n, k, alpha, a, lda,
stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size);
#else
throw unimplemented("blas", "gemm_batch", " for half");
#endif
}

void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb,
Expand All @@ -219,7 +224,12 @@ void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::
sycl::buffer<sycl::half, 1> &b, std::int64_t ldb, std::int64_t stride_b, float beta,
sycl::buffer<float, 1> &c, std::int64_t ldc, std::int64_t stride_c,
std::int64_t batch_size) {
throw unimplemented("blas", "gemm_batch", " for unsupported dtype");
#ifdef ENABLE_PORTBLAS_HALF
CALL_PORTBLAS_FN(::blas::_gemm_strided_batched, queue, transa, transb, m, n, k, alpha, a, lda,
stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size);
#else
throw unimplemented("blas", "gemm_batch", " for half");
#endif
}

void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb,
Expand All @@ -228,7 +238,7 @@ void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::
sycl::buffer<std::int8_t, 1> &b, std::int64_t ldb, std::int64_t stride_b,
float beta, sycl::buffer<float, 1> &c, std::int64_t ldc, std::int64_t stride_c,
std::int64_t batch_size) {
throw unimplemented("blas", "gemm_batch", " for unsupported dtype");
throw unimplemented("blas", "gemm_batch", " for int8");
}

void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb,
Expand All @@ -237,7 +247,7 @@ void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::
sycl::buffer<std::int8_t, 1> &b, std::int64_t ldb, std::int64_t stride_b,
float beta, sycl::buffer<std::int32_t, 1> &c, std::int64_t ldc,
std::int64_t stride_c, std::int64_t batch_size) {
throw unimplemented("blas", "gemm_batch", " for unsupported dtype");
throw unimplemented("blas", "gemm_batch", " for int8");
}

void trsm_batch(sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower,
Expand Down Expand Up @@ -686,7 +696,12 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa,
const float **b, std::int64_t *ldb, float *beta, float **c,
std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size,
const std::vector<sycl::event> &dependencies) {
throw unimplemented("blas", "gemm_batch", " for USM");
if (group_count != 1) {
throw unimplemented("blas", "gemm_batch", " using group API and group_count != 1");
}
CALL_PORTBLAS_USM_FN(::blas::_gemm_batched, queue, transa[0], transb[0], m[0], n[0], k[0],
alpha[0], a[0], lda[0], b[0], ldb[0], beta[0], c[0], ldc[0], group_size[0],
::blas::gemm_batch_type_t::strided, dependencies);
}

sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa,
Expand All @@ -695,7 +710,7 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa,
const double **b, std::int64_t *ldb, double *beta, double **c,
std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size,
const std::vector<sycl::event> &dependencies) {
throw unimplemented("blas", "gemm_batch", " for USM");
throw unimplemented("blas", "gemm_batch", " for USM using double");
}

sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa,
Expand All @@ -705,7 +720,7 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa,
std::complex<float> *beta, std::complex<float> **c, std::int64_t *ldc,
std::int64_t group_count, std::int64_t *group_size,
const std::vector<sycl::event> &dependencies) {
throw unimplemented("blas", "gemm_batch", " for USM");
throw unimplemented("blas", "gemm_batch", " for USM using complex");
}

sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa,
Expand All @@ -715,7 +730,7 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa,
std::complex<double> *beta, std::complex<double> **c, std::int64_t *ldc,
std::int64_t group_count, std::int64_t *group_size,
const std::vector<sycl::event> &dependencies) {
throw unimplemented("blas", "gemm_batch", " for USM");
throw unimplemented("blas", "gemm_batch", " for USM using complex");
}

sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa,
Expand All @@ -724,7 +739,16 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa,
const sycl::half **b, std::int64_t *ldb, sycl::half *beta, sycl::half **c,
std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size,
const std::vector<sycl::event> &dependencies) {
throw unimplemented("blas", "gemm_batch", " for USM");
#ifdef ENABLE_PORTBLAS_HALF
if (group_count != 1) {
throw unimplemented("blas", "gemm_batch", " using group API and group_count != 1");
}
CALL_PORTBLAS_USM_FN(::blas::_gemm_batched, queue, transa[0], transb[0], m[0], n[0], k[0],
alpha[0], a[0], lda[0], b[0], ldb[0], beta[0], c[0], ldc[0], group_size[0],
::blas::gemm_batch_type_t::strided, dependencies);
#else
throw unimplemented("blas", "gemm_batch", " for USM using half");
#endif
}

sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa,
Expand All @@ -733,7 +757,16 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa,
const sycl::half **b, std::int64_t *ldb, float *beta, float **c,
std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size,
const std::vector<sycl::event> &dependencies) {
throw unimplemented("blas", "gemm_batch", " for USM");
#ifdef ENABLE_PORTBLAS_HALF
if (group_count != 1) {
throw unimplemented("blas", "gemm_batch", " using group API and group_count != 1");
}
CALL_PORTBLAS_USM_FN(::blas::_gemm_batched, queue, transa[0], transb[0], m[0], n[0], k[0],
alpha[0], a[0], lda[0], b[0], ldb[0], beta[0], c[0], ldc[0], group_size[0],
::blas::gemm_batch_type_t::strided, dependencies);
#else
throw unimplemented("blas", "gemm_batch", " for USM using half");
#endif
}

sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa,
Expand All @@ -742,7 +775,7 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa,
const std::int8_t **b, std::int64_t *ldb, float *beta, float **c,
std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size,
const std::vector<sycl::event> &dependencies) {
throw unimplemented("blas", "gemm_batch", " for USM");
throw unimplemented("blas", "gemm_batch", " for USM using int8");
}

sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa,
Expand All @@ -751,7 +784,7 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa,
const std::int8_t **b, std::int64_t *ldb, float *beta, std::int32_t **c,
std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size,
const std::vector<sycl::event> &dependencies) {
throw unimplemented("blas", "gemm_batch", " for USM");
throw unimplemented("blas", "gemm_batch", " for USM using int8");
}

sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa,
Expand Down Expand Up @@ -785,7 +818,7 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa,
std::int64_t ldb, std::int64_t stride_b, std::complex<float> beta,
std::complex<float> *c, std::int64_t ldc, std::int64_t stride_c,
std::int64_t batch_size, const std::vector<sycl::event> &dependencies) {
throw unimplemented("blas", "gemm_batch", " for USM");
throw unimplemented("blas", "gemm_batch", " for USM using complex");
}

sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa,
Expand All @@ -795,7 +828,7 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa,
std::int64_t ldb, std::int64_t stride_b, std::complex<double> beta,
std::complex<double> *c, std::int64_t ldc, std::int64_t stride_c,
std::int64_t batch_size, const std::vector<sycl::event> &dependencies) {
throw unimplemented("blas", "gemm_batch", " for USM");
throw unimplemented("blas", "gemm_batch", " for USM using complex");
}

sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa,
Expand All @@ -805,7 +838,13 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa,
std::int64_t stride_b, sycl::half beta, sycl::half *c, std::int64_t ldc,
std::int64_t stride_c, std::int64_t batch_size,
const std::vector<sycl::event> &dependencies) {
throw unimplemented("blas", "gemm_batch", " for USM");
#ifdef ENABLE_PORTBLAS_HALF
CALL_PORTBLAS_USM_FN(::blas::_gemm_strided_batched, queue, transa, transb, m, n, k, alpha, a,
lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size,
dependencies);
#else
throw unimplemented("blas", "gemm_batch", " for USM using half");
#endif
}

sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa,
Expand All @@ -815,7 +854,13 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa,
std::int64_t stride_b, float beta, float *c, std::int64_t ldc,
std::int64_t stride_c, std::int64_t batch_size,
const std::vector<sycl::event> &dependencies) {
throw unimplemented("blas", "gemm_batch", " for USM");
#ifdef ENABLE_PORTBLAS_HALF
CALL_PORTBLAS_USM_FN(::blas::_gemm_strided_batched, queue, transa, transb, m, n, k, alpha, a,
lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size,
dependencies);
#else
throw unimplemented("blas", "gemm_batch", " for USM using half");
#endif
}

sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa,
Expand All @@ -825,7 +870,7 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa,
std::int64_t stride_b, float beta, float *c, std::int64_t ldc,
std::int64_t stride_c, std::int64_t batch_size,
const std::vector<sycl::event> &dependencies) {
throw unimplemented("blas", "gemm_batch", " for USM");
throw unimplemented("blas", "gemm_batch", " for USM using int8");
}

sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa,
Expand All @@ -835,7 +880,7 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa,
std::int64_t stride_b, float beta, std::int32_t *c, std::int64_t ldc,
std::int64_t stride_c, std::int64_t batch_size,
const std::vector<sycl::event> &dependencies) {
throw unimplemented("blas", "gemm_batch", " for USM");
throw unimplemented("blas", "gemm_batch", " for USM using int8");
}

sycl::event trsm_batch(sycl::queue &queue, oneapi::mkl::side left_right,
Expand Down
50 changes: 50 additions & 0 deletions src/blas/backends/portblas/portblas_level3_half.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <CL/sycl.hpp>
#endif

#include "portblas_common.hpp"
#include "oneapi/mkl/exceptions.hpp"
#include "oneapi/mkl/blas/detail/portblas/onemkl_blas_portblas.hpp"

Expand All @@ -32,19 +33,33 @@ namespace blas {
namespace portblas {
namespace column_major {

// Compile-time layout tag for the column_major namespace: always true.
// NOTE(review): presumably consulted by the CALL_PORTBLAS_* macros from
// portblas_common.hpp - confirm there.
constexpr bool is_column_major() { return true; }

// BUFFER
// Buffer-API GEMM where alpha, beta and the A, B and C buffers are all sycl::half.
//
// With ENABLE_PORTBLAS_HALF defined, the arguments are passed in their original order to
// ::blas::_gemm through the CALL_PORTBLAS_FN macro. Without it, the routine throws
// unimplemented.
void gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb,
          std::int64_t m, std::int64_t n, std::int64_t k, sycl::half alpha,
          sycl::buffer<sycl::half, 1> &a, std::int64_t lda, sycl::buffer<sycl::half, 1> &b,
          std::int64_t ldb, sycl::half beta, sycl::buffer<sycl::half, 1> &c, std::int64_t ldc) {
#ifdef ENABLE_PORTBLAS_HALF
    CALL_PORTBLAS_FN(::blas::_gemm, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c,
                     ldc);
#else
    // Reached only when ENABLE_PORTBLAS_HALF is not defined; name the missing data type.
    throw unimplemented("blas", "gemm", " for half");
#endif
}

// Buffer-API GEMM for mixed precision: A and B are sycl::half buffers, while alpha, beta and
// the C buffer are float.
//
// The fallback is listed first: when ENABLE_PORTBLAS_HALF is not defined the routine throws
// unimplemented; otherwise the arguments go to ::blas::_gemm through CALL_PORTBLAS_FN.
void gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb,
          std::int64_t m, std::int64_t n, std::int64_t k, float alpha,
          sycl::buffer<sycl::half, 1> &a, std::int64_t lda, sycl::buffer<sycl::half, 1> &b,
          std::int64_t ldb, float beta, sycl::buffer<float, 1> &c, std::int64_t ldc) {
#ifndef ENABLE_PORTBLAS_HALF
    throw unimplemented("blas", "gemm", " for different argument data types");
#else
    CALL_PORTBLAS_FN(::blas::_gemm, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c,
                     ldc);
#endif
}

// USM
Expand All @@ -53,31 +68,56 @@ sycl::event gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl:
const sycl::half *a, std::int64_t lda, const sycl::half *b, std::int64_t ldb,
sycl::half beta, sycl::half *c, std::int64_t ldc,
const std::vector<sycl::event> &dependencies) {
#ifdef ENABLE_PORTBLAS_HALF
CALL_PORTBLAS_USM_FN(::blas::_gemm, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta,
c, ldc, dependencies);
#else
throw unimplemented("blas", "gemm", " for USM");
#endif
}

// USM GEMM for mixed precision: A and B point to sycl::half data, while alpha, beta and C are
// float. `dependencies` is forwarded together with the other arguments.
//
// With ENABLE_PORTBLAS_HALF defined, the call is routed to ::blas::_gemm through the
// CALL_PORTBLAS_USM_FN macro. Without it, the routine throws unimplemented.
sycl::event gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb,
                 std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const sycl::half *a,
                 std::int64_t lda, const sycl::half *b, std::int64_t ldb, float beta, float *c,
                 std::int64_t ldc, const std::vector<sycl::event> &dependencies) {
#ifdef ENABLE_PORTBLAS_HALF
    CALL_PORTBLAS_USM_FN(::blas::_gemm, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta,
                         c, ldc, dependencies);
#else
    // Reached only when ENABLE_PORTBLAS_HALF is not defined, so the message names half rather
    // than USM alone.
    throw unimplemented("blas", "gemm", " for USM using half");
#endif
}
} // namespace column_major

namespace row_major {

// Compile-time layout tag for the row_major namespace: always false.
// NOTE(review): presumably consulted by the CALL_PORTBLAS_* macros from
// portblas_common.hpp - confirm there.
constexpr bool is_column_major() { return false; }

// BUFFER
// Buffer-API GEMM (row_major namespace) where alpha, beta and the A, B and C buffers are all
// sycl::half.
//
// With ENABLE_PORTBLAS_HALF defined, the arguments are passed in their original order to
// ::blas::_gemm through the CALL_PORTBLAS_FN macro. Without it, the routine throws
// unimplemented.
void gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb,
          std::int64_t m, std::int64_t n, std::int64_t k, sycl::half alpha,
          sycl::buffer<sycl::half, 1> &a, std::int64_t lda, sycl::buffer<sycl::half, 1> &b,
          std::int64_t ldb, sycl::half beta, sycl::buffer<sycl::half, 1> &c, std::int64_t ldc) {
#ifdef ENABLE_PORTBLAS_HALF
    CALL_PORTBLAS_FN(::blas::_gemm, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c,
                     ldc);
#else
    // Reached only when ENABLE_PORTBLAS_HALF is not defined; name the missing data type.
    throw unimplemented("blas", "gemm", " for half");
#endif
}

// Buffer-API GEMM (row_major namespace) for mixed precision: A and B are sycl::half buffers,
// while alpha, beta and the C buffer are float.
//
// The fallback is listed first: when ENABLE_PORTBLAS_HALF is not defined the routine throws
// unimplemented; otherwise the arguments go to ::blas::_gemm through CALL_PORTBLAS_FN.
void gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb,
          std::int64_t m, std::int64_t n, std::int64_t k, float alpha,
          sycl::buffer<sycl::half, 1> &a, std::int64_t lda, sycl::buffer<sycl::half, 1> &b,
          std::int64_t ldb, float beta, sycl::buffer<float, 1> &c, std::int64_t ldc) {
#ifndef ENABLE_PORTBLAS_HALF
    throw unimplemented("blas", "gemm", " for different argument data types");
#else
    CALL_PORTBLAS_FN(::blas::_gemm, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c,
                     ldc);
#endif
}

// USM
Expand All @@ -86,14 +126,24 @@ sycl::event gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl:
const sycl::half *a, std::int64_t lda, const sycl::half *b, std::int64_t ldb,
sycl::half beta, sycl::half *c, std::int64_t ldc,
const std::vector<sycl::event> &dependencies) {
#ifdef ENABLE_PORTBLAS_HALF
CALL_PORTBLAS_USM_FN(::blas::_gemm, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta,
c, ldc, dependencies);
#else
throw unimplemented("blas", "gemm", " for USM");
#endif
}

// USM GEMM (row_major namespace) for mixed precision: A and B point to sycl::half data, while
// alpha, beta and C are float. `dependencies` is forwarded together with the other arguments.
//
// With ENABLE_PORTBLAS_HALF defined, the call is routed to ::blas::_gemm through the
// CALL_PORTBLAS_USM_FN macro. Without it, the routine throws unimplemented.
sycl::event gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb,
                 std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const sycl::half *a,
                 std::int64_t lda, const sycl::half *b, std::int64_t ldb, float beta, float *c,
                 std::int64_t ldc, const std::vector<sycl::event> &dependencies) {
#ifdef ENABLE_PORTBLAS_HALF
    CALL_PORTBLAS_USM_FN(::blas::_gemm, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta,
                         c, ldc, dependencies);
#else
    // Reached only when ENABLE_PORTBLAS_HALF is not defined, so the message names half rather
    // than USM alone.
    throw unimplemented("blas", "gemm", " for USM using half");
#endif
}

} // namespace row_major
Expand Down
1 change: 1 addition & 0 deletions src/config.hpp.in
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#cmakedefine ENABLE_PORTBLAS_BACKEND_INTEL_CPU
#cmakedefine ENABLE_PORTBLAS_BACKEND_INTEL_GPU
#cmakedefine ENABLE_PORTBLAS_BACKEND_NVIDIA_GPU
#cmakedefine ENABLE_PORTBLAS_HALF
#cmakedefine ENABLE_PORTFFT_BACKEND
#cmakedefine ENABLE_ROCBLAS_BACKEND
#cmakedefine ENABLE_ROCFFT_BACKEND
Expand Down
3 changes: 3 additions & 0 deletions tests/unit_tests/blas/include/test_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,10 @@ struct ref_type_info<int32_t> {
// Random initialization.
// Returns a pseudo-random scalar of type fp, computed in fp arithmetic as
// std::rand() / RAND_MAX - 0.5. The sequence is driven by std::rand(), so it follows whatever
// seed was installed with std::srand().
template <typename fp>
static fp rand_scalar() {
// Converting RAND_MAX to a floating-point type may be inexact, which clang reports through
// -Wimplicit-const-int-float-conversion; that rounding is accepted here. The pragmas are
// clang-specific, so they are only emitted when compiling with a clang-based compiler and do
// not trigger unknown-pragma warnings elsewhere.
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wimplicit-const-int-float-conversion"
#endif
    return fp(std::rand()) / fp(RAND_MAX) - fp(0.5);
#if defined(__clang__)
#pragma clang diagnostic pop
#endif
}
template <typename fp>
static std::complex<fp> rand_complex_scalar() {
Expand Down

0 comments on commit 3f683b3

Please sign in to comment.