diff --git a/docs/building_the_project_with_dpcpp.rst b/docs/building_the_project_with_dpcpp.rst index 2fea9395f..0e46e8dc0 100644 --- a/docs/building_the_project_with_dpcpp.rst +++ b/docs/building_the_project_with_dpcpp.rst @@ -287,6 +287,9 @@ portBLAS relies heavily on JIT compilation. This may cause time-outs on some systems. To avoid this issue, use ahead-of-time compilation through tuning targets or ``sycl-targets``. +The ``sycl::half`` type can be supported by setting +``-DPORTBLAS_ENABLE_HALF=ON``. + .. _build_for_portfft_dpcpp: Building for portFFT diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 0beadc3ec..31c1b49ad 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -45,6 +45,11 @@ foreach(domain ${TARGET_DOMAINS}) add_subdirectory(${domain}) endforeach() +if (PORTBLAS_ENABLE_HALF) + # Set the variable used for C++ macro + set(ENABLE_PORTBLAS_HALF ON) +endif() + # Generate header with enabled backends for testing configure_file(config.hpp.in "${CMAKE_CURRENT_BINARY_DIR}/oneapi/mkl/config.hpp.configured") file(GENERATE diff --git a/src/blas/backends/portblas/CMakeLists.txt b/src/blas/backends/portblas/CMakeLists.txt index 03fddbb38..730cca92b 100644 --- a/src/blas/backends/portblas/CMakeLists.txt +++ b/src/blas/backends/portblas/CMakeLists.txt @@ -20,9 +20,8 @@ set(LIB_NAME onemkl_blas_portblas) set(LIB_OBJ ${LIB_NAME}_obj) -if(NOT DEFINED PORTBLAS_TUNING_TARGET) - option(PORTBLAS_TUNING_TARGET "Set a TUNING_TARGET for portBLAS" "") -endif() +option(PORTBLAS_TUNING_TARGET "Set a TUNING_TARGET for portBLAS" "") +option(PORTBLAS_ENABLE_HALF "Enable half support with the portBLAS backend" OFF) # Parse compiler flags and return a list of SYCL targets # The list is empty if no targets are set @@ -152,6 +151,9 @@ if (NOT PORTBLAS_FOUND) # Following variable TUNING_TARGET will be used in portBLAS internal configuration set(TUNING_TARGET ${PORTBLAS_TUNING_TARGET}) set(BLAS_ENABLE_COMPLEX ON) + if (PORTBLAS_ENABLE_HALF) + set(BLAS_ENABLE_HALF ON) + endif() # Set the policy to forward variables to portBLAS configure step set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) set(FETCHCONTENT_BASE_DIR "${CMAKE_BINARY_DIR}/deps") diff --git a/src/blas/backends/portblas/portblas_batch.cxx b/src/blas/backends/portblas/portblas_batch.cxx index 28c7ee5dc..29b31fe5b 100644 --- a/src/blas/backends/portblas/portblas_batch.cxx +++ b/src/blas/backends/portblas/portblas_batch.cxx @@ -210,7 +210,12 @@ void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl:: sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, sycl::half beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { - throw unimplemented("blas", "gemm_batch", " for complex"); +#ifdef ENABLE_PORTBLAS_HALF + CALL_PORTBLAS_FN(::blas::_gemm_strided_batched, queue, transa, transb, m, n, k, alpha, a, lda, + stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size); +#else + throw unimplemented("blas", "gemm_batch", " for half"); +#endif } void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, @@ -219,7 +224,12 @@ void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl:: sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { - throw unimplemented("blas", "gemm_batch", " for unsupported dtype"); +#ifdef ENABLE_PORTBLAS_HALF + CALL_PORTBLAS_FN(::blas::_gemm_strided_batched, queue, transa, transb, m, n, k, alpha, a, lda, + stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size); +#else + throw unimplemented("blas", "gemm_batch", " for half"); +#endif } void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, @@ -228,7 +238,7 @@ void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl:: sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { - throw unimplemented("blas", "gemm_batch", " for unsupported dtype"); + throw unimplemented("blas", "gemm_batch", " for int8"); } void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, @@ -237,7 +247,7 @@ void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl:: sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { - throw unimplemented("blas", "gemm_batch", " for unsupported dtype"); + throw unimplemented("blas", "gemm_batch", " for int8"); } void trsm_batch(sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, @@ -686,7 +696,12 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, const float **b, std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemm_batch", " for USM"); + if (group_count != 1) { + throw unimplemented("blas", "gemm_batch", " using group API and group_count != 1"); + } + CALL_PORTBLAS_USM_FN(::blas::_gemm_batched, queue, transa[0], transb[0], m[0], n[0], k[0], + alpha[0], a[0], lda[0], b[0], ldb[0], beta[0], c[0], ldc[0], group_size[0], + ::blas::gemm_batch_type_t::strided, dependencies); } sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, @@ -695,7 +710,7 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, const double **b, std::int64_t *ldb, double *beta, double **c, std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemm_batch", " for USM"); + throw unimplemented("blas", "gemm_batch", " for USM using double"); } sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, @@ -705,7 +720,7 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, std::complex *beta, std::complex **c, std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemm_batch", " for USM"); + throw unimplemented("blas", "gemm_batch", " for USM using complex"); } sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, @@ -715,7 +730,7 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, std::complex *beta, std::complex **c, std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemm_batch", " for USM"); + throw unimplemented("blas", "gemm_batch", " for USM using complex"); } sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, @@ -724,7 +739,16 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, const sycl::half **b, std::int64_t *ldb, sycl::half *beta, sycl::half **c, std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemm_batch", " for USM"); +#ifdef ENABLE_PORTBLAS_HALF + if (group_count != 1) { + throw unimplemented("blas", "gemm_batch", " using group API and group_count != 1"); + } + CALL_PORTBLAS_USM_FN(::blas::_gemm_batched, queue, transa[0], transb[0], m[0], n[0], k[0], + alpha[0], a[0], lda[0], b[0], ldb[0], beta[0], c[0], ldc[0], group_size[0], + ::blas::gemm_batch_type_t::strided, dependencies); +#else + throw unimplemented("blas", "gemm_batch", " for USM using half"); +#endif } sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, @@ -733,7 +757,16 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, const sycl::half **b, std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemm_batch", " for USM"); +#ifdef ENABLE_PORTBLAS_HALF + if (group_count != 1) { + throw unimplemented("blas", "gemm_batch", " using group API and group_count != 1"); + } + CALL_PORTBLAS_USM_FN(::blas::_gemm_batched, queue, transa[0], transb[0], m[0], n[0], k[0], + alpha[0], a[0], lda[0], b[0], ldb[0], beta[0], c[0], ldc[0], group_size[0], + ::blas::gemm_batch_type_t::strided, dependencies); +#else + throw unimplemented("blas", "gemm_batch", " for USM using half"); +#endif } sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, @@ -742,7 +775,7 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, const std::int8_t **b, std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemm_batch", " for USM"); + throw unimplemented("blas", "gemm_batch", " for USM using int8"); } sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, @@ -751,7 +784,7 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, const std::int8_t **b, std::int64_t *ldb, float *beta, std::int32_t **c, std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemm_batch", " for USM"); + throw unimplemented("blas", "gemm_batch", " for USM using int8"); } sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, @@ -785,7 +818,7 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, std::int64_t ldb, std::int64_t stride_b, std::complex beta, std::complex *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemm_batch", " for USM"); + throw unimplemented("blas", "gemm_batch", " for USM using complex"); } sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, @@ -795,7 +828,7 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, std::int64_t ldb, std::int64_t stride_b, std::complex beta, std::complex *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemm_batch", " for USM"); + throw unimplemented("blas", "gemm_batch", " for USM using complex"); } sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, @@ -805,7 +838,13 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, std::int64_t stride_b, sycl::half beta, sycl::half *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemm_batch", " for USM"); +#ifdef ENABLE_PORTBLAS_HALF + CALL_PORTBLAS_USM_FN(::blas::_gemm_strided_batched, queue, transa, transb, m, n, k, alpha, a, + lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size, + dependencies); +#else + throw unimplemented("blas", "gemm_batch", " for USM using half"); +#endif } sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, @@ -815,7 +854,13 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, std::int64_t stride_b, float beta, float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemm_batch", " for USM"); +#ifdef ENABLE_PORTBLAS_HALF + CALL_PORTBLAS_USM_FN(::blas::_gemm_strided_batched, queue, transa, transb, m, n, k, alpha, a, + lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size, + dependencies); +#else + throw unimplemented("blas", "gemm_batch", " for USM using half"); +#endif } sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, @@ -825,7 +870,7 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, std::int64_t stride_b, float beta, float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemm_batch", " for USM"); + throw unimplemented("blas", "gemm_batch", " for USM using int8"); } sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, @@ -835,7 +880,7 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, std::int64_t stride_b, float beta, std::int32_t *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemm_batch", " for USM"); + throw unimplemented("blas", "gemm_batch", " for USM using int8"); } sycl::event trsm_batch(sycl::queue &queue, oneapi::mkl::side left_right, diff --git a/src/blas/backends/portblas/portblas_level3_half.cpp b/src/blas/backends/portblas/portblas_level3_half.cpp index 0e42528fa..b3c2a0837 100644 --- a/src/blas/backends/portblas/portblas_level3_half.cpp +++ b/src/blas/backends/portblas/portblas_level3_half.cpp @@ -23,6 +23,7 @@ #include #endif +#include "portblas_common.hpp" #include "oneapi/mkl/exceptions.hpp" #include "oneapi/mkl/blas/detail/portblas/onemkl_blas_portblas.hpp" @@ -32,19 +33,33 @@ namespace blas { namespace portblas { namespace column_major { +constexpr bool is_column_major() { + return true; +} + // BUFFER void gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, sycl::half alpha, sycl::buffer &a, std::int64_t lda, sycl::buffer &b, std::int64_t ldb, sycl::half beta, sycl::buffer &c, std::int64_t ldc) { +#ifdef ENABLE_PORTBLAS_HALF + CALL_PORTBLAS_FN(::blas::_gemm, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, + ldc); +#else throw unimplemented("blas", "gemm", " half"); +#endif } void gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, std::int64_t lda, sycl::buffer &b, std::int64_t ldb, float beta, sycl::buffer &c, std::int64_t ldc) { +#ifdef ENABLE_PORTBLAS_HALF + CALL_PORTBLAS_FN(::blas::_gemm, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, + ldc); +#else throw unimplemented("blas", "gemm", " for different argument data types"); +#endif } // USM @@ -53,31 +68,56 @@ sycl::event gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl: const sycl::half *a, std::int64_t lda, const sycl::half *b, std::int64_t ldb, sycl::half beta, sycl::half *c, std::int64_t ldc, const std::vector &dependencies) { +#ifdef ENABLE_PORTBLAS_HALF + CALL_PORTBLAS_USM_FN(::blas::_gemm, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, + c, ldc, dependencies); +#else throw unimplemented("blas", "gemm", " for USM"); +#endif } sycl::event gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const sycl::half *a, std::int64_t lda, const sycl::half *b, std::int64_t ldb, float beta, float *c, std::int64_t ldc, const std::vector &dependencies) { +#ifdef ENABLE_PORTBLAS_HALF + CALL_PORTBLAS_USM_FN(::blas::_gemm, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, + c, ldc, dependencies); +#else throw unimplemented("blas", "gemm", " for USM"); +#endif } } // namespace column_major + namespace row_major { +constexpr bool is_column_major() { + return false; +} + // BUFFER void gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, sycl::half alpha, sycl::buffer &a, std::int64_t lda, sycl::buffer &b, std::int64_t ldb, sycl::half beta, sycl::buffer &c, std::int64_t ldc) { +#ifdef ENABLE_PORTBLAS_HALF + CALL_PORTBLAS_FN(::blas::_gemm, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, + ldc); +#else throw unimplemented("blas", "gemm", " half"); +#endif } void gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, std::int64_t lda, sycl::buffer &b, std::int64_t ldb, float beta, sycl::buffer &c, std::int64_t ldc) { +#ifdef ENABLE_PORTBLAS_HALF + CALL_PORTBLAS_FN(::blas::_gemm, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, + ldc); +#else throw unimplemented("blas", "gemm", " for different argument data types"); +#endif } // USM @@ -86,14 +126,24 @@ sycl::event gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl: const sycl::half *a, std::int64_t lda, const sycl::half *b, std::int64_t ldb, sycl::half beta, sycl::half *c, std::int64_t ldc, const std::vector &dependencies) { +#ifdef ENABLE_PORTBLAS_HALF + CALL_PORTBLAS_USM_FN(::blas::_gemm, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, + c, ldc, dependencies); +#else throw unimplemented("blas", "gemm", " for USM"); +#endif } sycl::event gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const sycl::half *a, std::int64_t lda, const sycl::half *b, std::int64_t ldb, float beta, float *c, std::int64_t ldc, const std::vector &dependencies) { +#ifdef ENABLE_PORTBLAS_HALF + CALL_PORTBLAS_USM_FN(::blas::_gemm, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, + c, ldc, dependencies); +#else throw unimplemented("blas", "gemm", " for USM"); +#endif } } // namespace row_major diff --git a/src/config.hpp.in b/src/config.hpp.in index 5698abf9b..0c23e04da 100644 --- a/src/config.hpp.in +++ b/src/config.hpp.in @@ -32,6 +32,7 @@ #cmakedefine ENABLE_PORTBLAS_BACKEND_INTEL_CPU #cmakedefine ENABLE_PORTBLAS_BACKEND_INTEL_GPU #cmakedefine ENABLE_PORTBLAS_BACKEND_NVIDIA_GPU +#cmakedefine ENABLE_PORTBLAS_HALF #cmakedefine ENABLE_PORTFFT_BACKEND #cmakedefine ENABLE_ROCBLAS_BACKEND #cmakedefine ENABLE_ROCFFT_BACKEND diff --git a/tests/unit_tests/blas/include/test_common.hpp b/tests/unit_tests/blas/include/test_common.hpp index 5d607991e..8974d39c6 100644 --- a/tests/unit_tests/blas/include/test_common.hpp +++ b/tests/unit_tests/blas/include/test_common.hpp @@ -120,7 +120,10 @@ struct ref_type_info { // Random initialization. template static fp rand_scalar() { +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wimplicit-const-int-float-conversion" return fp(std::rand()) / fp(RAND_MAX) - fp(0.5); +#pragma clang diagnostic pop } template static std::complex rand_complex_scalar() {