From 52fced158f05b8f1b6a0447dbd5061e99cad0769 Mon Sep 17 00:00:00 2001 From: Aidan Date: Tue, 2 Apr 2024 11:00:42 +0100 Subject: [PATCH 01/30] Add new interface --- include/oneapi/mkl/blas.hxx | 111 ++++++++++ .../mkl/blas/detail/blas_ct_backends.hxx | 75 +++++++ .../oneapi/mkl/blas/detail/blas_loader.hxx | 68 ++++++ .../mkl/blas/detail/onemkl_blas_backends.hxx | 77 +++++++ src/blas/backends/backend_wrappers.cxx | 9 + src/blas/blas_loader.cpp | 198 ++++++++++++++++++ src/blas/function_table.hpp | 116 ++++++++++ 7 files changed, 654 insertions(+) diff --git a/include/oneapi/mkl/blas.hxx b/include/oneapi/mkl/blas.hxx index 5a703fea2..374585912 100644 --- a/include/oneapi/mkl/blas.hxx +++ b/include/oneapi/mkl/blas.hxx @@ -382,6 +382,39 @@ static inline void gemm_batch(sycl::queue &queue, transpose transa, transpose tr stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size); } +static inline void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, lda, + stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size); +} + +static inline void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, lda, + stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size); +} + +static inline void 
gemm_batch(sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { + detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, lda, + stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size); +} + static inline void gemm_bias(sycl::queue &queue, transpose transa, transpose transb, offset offsetc, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, std::int64_t lda, @@ -2246,6 +2279,45 @@ static inline sycl::event gemm_batch(sycl::queue &queue, transpose *transa, return done; } +static inline sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, + std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const sycl::half **a, std::int64_t *lda, + const sycl::half **b, std::int64_t *ldb, float *beta, + float **c, std::int64_t *ldc, std::int64_t group_count, + std::int64_t *group_size, + const std::vector &dependencies = {}) { + auto done = + detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, lda, b, + ldb, beta, c, ldc, group_count, group_size, dependencies); + return done; +} + +static inline sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, + std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, + float **c, std::int64_t *ldc, std::int64_t group_count, + std::int64_t *group_size, + const std::vector &dependencies = {}) { + auto done = + detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, lda, b, + ldb, beta, c, ldc, group_count, group_size, dependencies); + return done; +} + +static inline 
sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, + std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, + std::int32_t **c, std::int64_t *ldc, std::int64_t group_count, + std::int64_t *group_size, + const std::vector &dependencies = {}) { + auto done = + detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, lda, b, + ldb, beta, c, ldc, group_count, group_size, dependencies); + return done; +} + static inline sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, @@ -2312,6 +2384,45 @@ static inline sycl::event gemm_batch(sycl::queue &queue, transpose transa, trans return done; } +static inline sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + const sycl::half *a, std::int64_t lda, std::int64_t stride_a, + const sycl::half *b, std::int64_t ldb, std::int64_t stride_b, + float beta, float *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}) { + auto done = detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, + lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, + batch_size, dependencies); + return done; +} + +static inline sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, + float beta, float *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}) { + auto done = detail::gemm_batch(get_device_id(queue), queue, transa, 
transb, m, n, k, alpha, a, + lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, + batch_size, dependencies); + return done; +} + +static inline sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, + float beta, std::int32_t *c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies = {}) { + auto done = detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, + lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, + batch_size, dependencies); + return done; +} + static inline sycl::event gemmt(sycl::queue &queue, uplo upper_lower, transpose transa, transpose transb, std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, const float *b, diff --git a/include/oneapi/mkl/blas/detail/blas_ct_backends.hxx b/include/oneapi/mkl/blas/detail/blas_ct_backends.hxx index 784eeafee..afebb93c3 100644 --- a/include/oneapi/mkl/blas/detail/blas_ct_backends.hxx +++ b/include/oneapi/mkl/blas/detail/blas_ct_backends.hxx @@ -464,6 +464,30 @@ static inline void gemm_batch(backend_selector selector, trans sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); +static inline void gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size); + +static inline void gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t 
stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size); + +static inline void gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size); + static inline void spmv(backend_selector selector, uplo upper_lower, std::int64_t n, float alpha, sycl::buffer &a, sycl::buffer &x, std::int64_t incx, float beta, @@ -1870,6 +1894,30 @@ static inline sycl::event gemm_batch(backend_selector selector std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies = {}); +static inline sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, + std::int64_t *k, float *alpha, const sycl::half **a, + std::int64_t *lda, const sycl::half **b, std::int64_t *ldb, + float *beta, float **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies = {}); + +static inline sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, + std::int64_t *k, float *alpha, const std::int8_t **a, + std::int64_t *lda, const std::int8_t **b, std::int64_t *ldb, + float *beta, float **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies = {}); + +static inline sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, + std::int64_t *k, float *alpha, const std::int8_t **a, + std::int64_t *lda, const std::int8_t **b, std::int64_t *ldb, + float *beta, std::int32_t **c, std::int64_t 
*ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies = {}); + static inline sycl::event gemm_batch(backend_selector selector, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, @@ -1911,6 +1959,33 @@ static inline sycl::event gemm_batch( sycl::half beta, sycl::half *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies = {}); +static inline sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, const sycl::half *a, + std::int64_t lda, std::int64_t stride_a, const sycl::half *b, + std::int64_t ldb, std::int64_t stride_b, float beta, float *c, + std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}); + +static inline sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, const std::int8_t *a, + std::int64_t lda, std::int64_t stride_a, const std::int8_t *b, + std::int64_t ldb, std::int64_t stride_b, float beta, float *c, + std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}); + +static inline sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, const std::int8_t *a, + std::int64_t lda, std::int64_t stride_a, const std::int8_t *b, + std::int64_t ldb, std::int64_t stride_b, float beta, + std::int32_t *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}); + static inline sycl::event spmv(backend_selector selector, uplo upper_lower, std::int64_t n, float alpha, const float *a, const float *x, std::int64_t incx, float beta, float *y, std::int64_t incy, diff --git 
a/include/oneapi/mkl/blas/detail/blas_loader.hxx b/include/oneapi/mkl/blas/detail/blas_loader.hxx index d964d0024..98d93b2ad 100644 --- a/include/oneapi/mkl/blas/detail/blas_loader.hxx +++ b/include/oneapi/mkl/blas/detail/blas_loader.hxx @@ -124,6 +124,27 @@ ONEMKL_EXPORT void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, tr std::int64_t stride_b, sycl::half beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); +ONEMKL_EXPORT void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size); +ONEMKL_EXPORT void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size); +ONEMKL_EXPORT void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size); ONEMKL_EXPORT void syrk(oneapi::mkl::device libkey, sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, std::int64_t k, float alpha, @@ -1227,6 +1248,29 @@ ONEMKL_EXPORT sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &qu sycl::half **c, std::int64_t *ldc, std::int64_t group_count, std::int64_t 
*group_size, const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, + transpose *transa, transpose *transb, std::int64_t *m, + std::int64_t *n, std::int64_t *k, float *alpha, + const sycl::half **a, std::int64_t *lda, const sycl::half **b, + std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, + transpose *transa, transpose *transb, std::int64_t *m, + std::int64_t *n, std::int64_t *k, float *alpha, + const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, + float **c, std::int64_t *ldc, std::int64_t group_count, + std::int64_t *group_size, + const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, + transpose *transa, transpose *transb, std::int64_t *m, + std::int64_t *n, std::int64_t *k, float *alpha, + const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, + std::int32_t **c, std::int64_t *ldc, std::int64_t group_count, + std::int64_t *group_size, + const std::vector &dependencies = {}); ONEMKL_EXPORT sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, @@ -1263,6 +1307,30 @@ ONEMKL_EXPORT sycl::event gemm_batch( std::int64_t lda, std::int64_t stride_a, const sycl::half *b, std::int64_t ldb, std::int64_t stride_b, sycl::half beta, sycl::half *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, + transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, + const 
sycl::half *a, std::int64_t lda, std::int64_t stride_a, + const sycl::half *b, std::int64_t ldb, std::int64_t stride_b, + float beta, float *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, + transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, + const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, + float beta, float *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, + transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, + const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, + float beta, std::int32_t *c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies = {}); ONEMKL_EXPORT sycl::event syrk(oneapi::mkl::device libkey, sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, diff --git a/include/oneapi/mkl/blas/detail/onemkl_blas_backends.hxx b/include/oneapi/mkl/blas/detail/onemkl_blas_backends.hxx index fe81ae5aa..fbb64a6a0 100644 --- a/include/oneapi/mkl/blas/detail/onemkl_blas_backends.hxx +++ b/include/oneapi/mkl/blas/detail/onemkl_blas_backends.hxx @@ -973,6 +973,30 @@ ONEMKL_EXPORT void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); +ONEMKL_EXPORT void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, 
+ sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, float beta, sycl::buffer &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); + +ONEMKL_EXPORT void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, float beta, sycl::buffer &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); + +ONEMKL_EXPORT void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, float beta, sycl::buffer &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); + ONEMKL_EXPORT void trsm_batch(sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, oneapi::mkl::diag unit_diag, std::int64_t m, std::int64_t n, @@ -2558,6 +2582,32 @@ ONEMKL_EXPORT sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, + oneapi::mkl::transpose *transb, std::int64_t *m, + std::int64_t *n, std::int64_t *k, float *alpha, + const sycl::half **a, std::int64_t *lda, const sycl::half **b, + std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies = {}); + +ONEMKL_EXPORT sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, + oneapi::mkl::transpose *transb, std::int64_t *m, + std::int64_t *n, std::int64_t *k, float *alpha, + const std::int8_t **a, 
std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, + float **c, std::int64_t *ldc, std::int64_t group_count, + std::int64_t *group_size, + const std::vector &dependencies = {}); + +ONEMKL_EXPORT sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, + oneapi::mkl::transpose *transb, std::int64_t *m, + std::int64_t *n, std::int64_t *k, float *alpha, + const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, + std::int32_t **c, std::int64_t *ldc, std::int64_t group_count, + std::int64_t *group_size, + const std::vector &dependencies = {}); + ONEMKL_EXPORT sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, @@ -2599,6 +2649,33 @@ ONEMKL_EXPORT sycl::event gemm_batch( std::int64_t stride_b, sycl::half beta, sycl::half *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, const sycl::half *a, + std::int64_t lda, std::int64_t stride_a, const sycl::half *b, + std::int64_t ldb, std::int64_t stride_b, float beta, float *c, + std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}); + +ONEMKL_EXPORT sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, const std::int8_t *a, + std::int64_t lda, std::int64_t stride_a, const std::int8_t *b, + std::int64_t ldb, std::int64_t stride_b, float beta, float *c, + std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}); + +ONEMKL_EXPORT sycl::event gemm_batch(sycl::queue &queue, 
oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, const std::int8_t *a, + std::int64_t lda, std::int64_t stride_a, const std::int8_t *b, + std::int64_t ldb, std::int64_t stride_b, float beta, + std::int32_t *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}); + ONEMKL_EXPORT sycl::event gemmt(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t n, std::int64_t k, float alpha, const float *a, diff --git a/src/blas/backends/backend_wrappers.cxx b/src/blas/backends/backend_wrappers.cxx index 34af9cf2f..62f6ced13 100644 --- a/src/blas/backends/backend_wrappers.cxx +++ b/src/blas/backends/backend_wrappers.cxx @@ -200,6 +200,9 @@ oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, +oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, +oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, +oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, oneapi::mkl::blas::BACKEND::MAJOR::trsm_batch, oneapi::mkl::blas::BACKEND::MAJOR::trsm_batch, oneapi::mkl::blas::BACKEND::MAJOR::trsm_batch, @@ -455,6 +458,12 @@ oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, +oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, +oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, +oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, +oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, +oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, +oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, oneapi::mkl::blas::BACKEND::MAJOR::gemmt, oneapi::mkl::blas::BACKEND::MAJOR::gemmt, oneapi::mkl::blas::BACKEND::MAJOR::gemmt, diff --git a/src/blas/blas_loader.cpp 
b/src/blas/blas_loader.cpp index 490d730a7..9022900fc 100644 --- a/src/blas/blas_loader.cpp +++ b/src/blas/blas_loader.cpp @@ -1342,6 +1342,39 @@ void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa stride_c, batch_size); } +void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + function_tables[libkey].column_major_hsgemm_batch_strided_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size); +} + +void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + function_tables[libkey].column_major_isgemm_batch_strided_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size); +} + +void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { + function_tables[libkey].column_major_iigemm_batch_strided_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size); +} + void 
trsm_batch(oneapi::mkl::device libkey, sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, diag unit_diag, std::int64_t m, std::int64_t n, float alpha, sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, @@ -3405,6 +3438,39 @@ sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose group_size, dependencies); } +sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const sycl::half **a, std::int64_t *lda, const sycl::half **b, + std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + return function_tables[libkey].column_major_hsgemm_batch_group_usm_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, group_count, + group_size, dependencies); +} + +sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + return function_tables[libkey].column_major_isgemm_batch_group_usm_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, group_count, + group_size, dependencies); +} + +sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, std::int32_t **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + return 
function_tables[libkey].column_major_iigemm_batch_group_usm_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, group_count, + group_size, dependencies); +} + sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, std::int64_t stride_a, @@ -3463,6 +3529,39 @@ sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose stride_c, batch_size, dependencies); } +sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const sycl::half *a, std::int64_t lda, std::int64_t stride_a, + const sycl::half *b, std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + return function_tables[libkey].column_major_hsgemm_batch_strided_usm_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size, dependencies); +} + +sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + return function_tables[libkey].column_major_isgemm_batch_strided_usm_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size, dependencies); +} + +sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float 
alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, + std::int32_t *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + return function_tables[libkey].column_major_iigemm_batch_strided_usm_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size, dependencies); +} + sycl::event gemmt(oneapi::mkl::device libkey, sycl::queue &queue, uplo upper_lower, transpose transa, transpose transb, std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, const float *b, std::int64_t ldb, float beta, @@ -5177,6 +5276,39 @@ void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa stride_c, batch_size); } +void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + function_tables[libkey].row_major_hsgemm_batch_strided_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size); +} + +void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + function_tables[libkey].row_major_isgemm_batch_strided_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size); +} + +void 
gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { + function_tables[libkey].row_major_iigemm_batch_strided_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size); +} + void trsm_batch(oneapi::mkl::device libkey, sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, diag unit_diag, std::int64_t m, std::int64_t n, float alpha, sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, @@ -7236,6 +7368,39 @@ sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose group_size, dependencies); } +sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const sycl::half **a, std::int64_t *lda, const sycl::half **b, + std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + return function_tables[libkey].row_major_hsgemm_batch_group_usm_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, group_count, + group_size, dependencies); +} + +sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + return function_tables[libkey].row_major_isgemm_batch_group_usm_sycl( + queue, 
transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, group_count, + group_size, dependencies); +} + +sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, std::int32_t **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + return function_tables[libkey].row_major_iigemm_batch_group_usm_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, group_count, + group_size, dependencies); +} + sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, std::int64_t stride_a, @@ -7294,6 +7459,39 @@ sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose stride_c, batch_size, dependencies); } +sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const sycl::half *a, std::int64_t lda, std::int64_t stride_a, + const sycl::half *b, std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + return function_tables[libkey].row_major_hsgemm_batch_strided_usm_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size, dependencies); +} + +sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t 
stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + return function_tables[libkey].row_major_isgemm_batch_strided_usm_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size, dependencies); +} + +sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, + std::int32_t *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + return function_tables[libkey].row_major_iigemm_batch_strided_usm_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size, dependencies); +} + sycl::event gemmt(oneapi::mkl::device libkey, sycl::queue &queue, uplo upper_lower, transpose transa, transpose transb, std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, const float *b, std::int64_t ldb, float beta, diff --git a/src/blas/function_table.hpp b/src/blas/function_table.hpp index c9d640b1c..57490523e 100644 --- a/src/blas/function_table.hpp +++ b/src/blas/function_table.hpp @@ -869,6 +869,26 @@ typedef struct { std::int64_t stride_b, sycl::half beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); + void (*column_major_hsgemm_batch_strided_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, float beta, sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size); + void 
(*column_major_isgemm_batch_strided_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size); + void (*column_major_iigemm_batch_strided_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size); void (*column_major_strsm_batch_strided_sycl)( sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, oneapi::mkl::diag unit_diag, std::int64_t m, std::int64_t n, @@ -2180,6 +2200,24 @@ typedef struct { std::int64_t *lda, const sycl::half **b, std::int64_t *ldb, sycl::half *beta, sycl::half **c, std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies); + sycl::event (*column_major_hsgemm_batch_group_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose *transa, oneapi::mkl::transpose *transb, + std::int64_t *m, std::int64_t *n, std::int64_t *k, float *alpha, const sycl::half **a, + std::int64_t *lda, const sycl::half **b, std::int64_t *ldb, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies); + sycl::event (*column_major_isgemm_batch_group_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose *transa, oneapi::mkl::transpose *transb, + std::int64_t *m, std::int64_t *n, std::int64_t *k, float *alpha, const std::int8_t **a, + std::int64_t *lda, const std::int8_t **b, 
std::int64_t *ldb, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies); + sycl::event (*column_major_iigemm_batch_group_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose *transa, oneapi::mkl::transpose *transb, + std::int64_t *m, std::int64_t *n, std::int64_t *k, float *alpha, const std::int8_t **a, + std::int64_t *lda, const std::int8_t **b, std::int64_t *ldb, float *beta, std::int32_t **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies); sycl::event (*column_major_sgemm_batch_strided_usm_sycl)( sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const float *a, @@ -2213,6 +2251,24 @@ typedef struct { std::int64_t stride_b, sycl::half beta, sycl::half *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies); + sycl::event (*column_major_hsgemm_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const sycl::half *a, + std::int64_t lda, std::int64_t stride_a, const sycl::half *b, std::int64_t ldb, + std::int64_t stride_b, float beta, float *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies); + sycl::event (*column_major_isgemm_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const std::int8_t *a, + std::int64_t lda, std::int64_t stride_a, const std::int8_t *b, std::int64_t ldb, + std::int64_t stride_b, float beta, float *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies); + sycl::event (*column_major_iigemm_batch_strided_usm_sycl)( + sycl::queue 
&queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const std::int8_t *a, + std::int64_t lda, std::int64_t stride_a, const std::int8_t *b, std::int64_t ldb, + std::int64_t stride_b, float beta, std::int32_t *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies); sycl::event (*column_major_sgemmt_usm_sycl)(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t n, @@ -3269,6 +3325,30 @@ typedef struct { std::int64_t stride_b, sycl::half beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); + void (*row_major_hsgemm_batch_strided_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, float beta, sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size); + void (*row_major_isgemm_batch_strided_sycl)(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size); + void (*row_major_iigemm_batch_strided_sycl)(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size); void 
(*row_major_strsm_batch_strided_sycl)( sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, oneapi::mkl::diag unit_diag, std::int64_t m, std::int64_t n, @@ -4581,6 +4661,24 @@ typedef struct { std::int64_t *lda, const sycl::half **b, std::int64_t *ldb, sycl::half *beta, sycl::half **c, std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies); + sycl::event (*row_major_hsgemm_batch_group_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose *transa, oneapi::mkl::transpose *transb, + std::int64_t *m, std::int64_t *n, std::int64_t *k, float *alpha, const sycl::half **a, + std::int64_t *lda, const sycl::half **b, std::int64_t *ldb, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies); + sycl::event (*row_major_isgemm_batch_group_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose *transa, oneapi::mkl::transpose *transb, + std::int64_t *m, std::int64_t *n, std::int64_t *k, float *alpha, const std::int8_t **a, + std::int64_t *lda, const std::int8_t **b, std::int64_t *ldb, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies); + sycl::event (*row_major_iigemm_batch_group_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose *transa, oneapi::mkl::transpose *transb, + std::int64_t *m, std::int64_t *n, std::int64_t *k, float *alpha, const std::int8_t **a, + std::int64_t *lda, const std::int8_t **b, std::int64_t *ldb, float *beta, std::int32_t **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies); sycl::event (*row_major_sgemm_batch_strided_usm_sycl)( sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const float *a, @@ -4614,6 +4712,24 @@ typedef struct { 
std::int64_t stride_b, sycl::half beta, sycl::half *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies); + sycl::event (*row_major_hsgemm_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const sycl::half *a, + std::int64_t lda, std::int64_t stride_a, const sycl::half *b, std::int64_t ldb, + std::int64_t stride_b, float beta, float *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies); + sycl::event (*row_major_isgemm_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const std::int8_t *a, + std::int64_t lda, std::int64_t stride_a, const std::int8_t *b, std::int64_t ldb, + std::int64_t stride_b, float beta, float *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies); + sycl::event (*row_major_iigemm_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const std::int8_t *a, + std::int64_t lda, std::int64_t stride_a, const std::int8_t *b, std::int64_t ldb, + std::int64_t stride_b, float beta, std::int32_t *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies); sycl::event (*row_major_sgemmt_usm_sycl)(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t n, From 9cb61556b7b9bd5fe22e8b27178b1ee8e64baad8 Mon Sep 17 00:00:00 2001 From: Aidan Date: Tue, 2 Apr 2024 11:03:13 +0100 Subject: [PATCH 02/30] Add new dtype testing for gemm_batch --- .../blas/batch/gemm_batch_stride.cpp | 86 +++++++++----- .../blas/batch/gemm_batch_stride_usm.cpp | 93 
++++++++++----- .../unit_tests/blas/batch/gemm_batch_usm.cpp | 107 +++++++++++++----- tests/unit_tests/blas/include/test_common.hpp | 10 +- 4 files changed, 208 insertions(+), 88 deletions(-) diff --git a/tests/unit_tests/blas/batch/gemm_batch_stride.cpp b/tests/unit_tests/blas/batch/gemm_batch_stride.cpp index d194e2007..76b477181 100644 --- a/tests/unit_tests/blas/batch/gemm_batch_stride.cpp +++ b/tests/unit_tests/blas/batch/gemm_batch_stride.cpp @@ -47,13 +47,13 @@ extern std::vector devices; namespace { -template +template int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { // Prepare data. int64_t m, n, k; int64_t lda, ldb, ldc; oneapi::mkl::transpose transa, transb; - fp alpha, beta; + Ts alpha, beta; int64_t i, tmp; batch_size = 1 + std::rand() % 20; @@ -63,10 +63,10 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { lda = std::max(m, k); ldb = std::max(n, k); ldc = std::max(m, n); - alpha = rand_scalar(); - beta = rand_scalar(); + alpha = rand_scalar(); + beta = rand_scalar(); - if ((std::is_same::value) || (std::is_same::value)) { + if ((std::is_same::value) || (std::is_same::value)) { transa = (oneapi::mkl::transpose)(std::rand() % 2); transb = (oneapi::mkl::transpose)(std::rand() % 2); } @@ -99,8 +99,12 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { default: break; } - vector> A(stride_a * batch_size), B(stride_b * batch_size); - vector> C(stride_c * batch_size), C_ref(stride_c * batch_size); + vector> A(stride_a * batch_size); + vector> B(stride_b * batch_size); + vector> C(stride_c * batch_size), + C_cast_ref(stride_c * batch_size); + vector> A_ref(stride_a * batch_size), B_ref(stride_b * batch_size), + C_ref(stride_c * batch_size); for (i = 0; i < batch_size; i++) { rand_matrix(A.data() + stride_a * i, layout, transa, m, k, lda); @@ -108,10 +112,15 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { rand_matrix(C.data() + stride_c * i, layout, 
oneapi::mkl::transpose::nontrans, m, n, ldc); } - C_ref = C; + for (size_t i = 0; i < A.size(); ++i) + A_ref[i] = A[i]; + for (size_t i = 0; i < B.size(); ++i) + B_ref[i] = B[i]; + for (size_t i = 0; i < C.size(); ++i) + C_ref[i] = C[i]; // Call reference GEMM_BATCH_STRIDE. - using fp_ref = typename ref_type_info::type; + using fp_ref = typename ref_type_info::type; int m_ref = (int)m; int n_ref = (int)n; int k_ref = (int)k; @@ -121,12 +130,13 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { int batch_size_ref = (int)batch_size; for (i = 0; i < batch_size_ref; i++) { - ::gemm( - convert_to_cblas_layout(layout), convert_to_cblas_trans(transa), - convert_to_cblas_trans(transb), (const int *)&m_ref, (const int *)&n_ref, - (const int *)&k_ref, (const fp_ref *)&alpha, (const fp_ref *)(A.data() + stride_a * i), - (const int *)&lda_ref, (const fp_ref *)(B.data() + stride_b * i), (const int *)&ldb_ref, - (const fp_ref *)&beta, (fp_ref *)(C_ref.data() + stride_c * i), (const int *)&ldc_ref); + ::gemm(convert_to_cblas_layout(layout), convert_to_cblas_trans(transa), + convert_to_cblas_trans(transb), (const int *)&m_ref, (const int *)&n_ref, + (const int *)&k_ref, (const fp_ref *)&alpha, + (const fp_ref *)(A_ref.data() + stride_a * i), (const int *)&lda_ref, + (const fp_ref *)(B_ref.data() + stride_b * i), (const int *)&ldb_ref, + (const fp_ref *)&beta, (fp_ref *)(C_ref.data() + stride_c * i), + (const int *)&ldc_ref); } // Call DPC++ GEMM_BATCH_STRIDE. 
@@ -147,9 +157,9 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { queue main_queue(*dev, exception_handler); - buffer A_buffer(A.data(), range<1>(A.size())); - buffer B_buffer(B.data(), range<1>(B.size())); - buffer C_buffer(C.data(), range<1>(C.size())); + buffer A_buffer(A.data(), range<1>(A.size())); + buffer B_buffer(B.data(), range<1>(B.size())); + buffer C_buffer(C.data(), range<1>(C.size())); try { #ifdef CALL_RT_API @@ -183,6 +193,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { default: break; } #endif + main_queue.wait_and_throw(); } catch (exception const &e) { std::cout << "Caught synchronous SYCL exception during GEMM_BATCH_STRIDE:\n" @@ -200,11 +211,14 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { } // Compare the results of reference implementation and DPC++ implementation. + constexpr int tol_scalar = std::is_same_v ? 10 : 40; + for (size_t i = 0; i < C_ref.size(); ++i) + C_cast_ref[i] = C_ref[i]; auto C_accessor = C_buffer.template get_host_access(read_only); - bool good = - check_equal_matrix(C_accessor, C_ref, oneapi::mkl::layout::col_major, stride_c * batch_size, - 1, stride_c * batch_size, 10 * k, std::cout); + bool good = check_equal_matrix(C_accessor, C_cast_ref, oneapi::mkl::layout::col_major, + stride_c * batch_size, 1, stride_c * batch_size, tol_scalar * k, + std::cout); return (int)good; } @@ -213,29 +227,49 @@ class GemmBatchStrideTests : public ::testing::TestWithParam> {}; TEST_P(GemmBatchStrideTests, RealHalfPrecision) { - EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + EXPECT_TRUEORSKIP((test( + std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); +} + +TEST_P(GemmBatchStrideTests, RealHalfRealScalarPrecision) { + EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), + std::get<1>(GetParam()), 5))); +} + +TEST_P(GemmBatchStrideTests, RealIntRealScalarPrecision) { + EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), + 
std::get<1>(GetParam()), 5))); +} + +TEST_P(GemmBatchStrideTests, RealIntPrecision) { + EXPECT_TRUEORSKIP((test( + std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } TEST_P(GemmBatchStrideTests, RealSinglePrecision) { - EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + EXPECT_TRUEORSKIP( + (test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } TEST_P(GemmBatchStrideTests, RealDoublePrecision) { CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); - EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + EXPECT_TRUEORSKIP(( + test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } TEST_P(GemmBatchStrideTests, ComplexSinglePrecision) { EXPECT_TRUEORSKIP( - test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + (test, std::complex, std::complex, std::complex>( + std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } TEST_P(GemmBatchStrideTests, ComplexDoublePrecision) { CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); EXPECT_TRUEORSKIP( - test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + (test, std::complex, std::complex, + std::complex>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } INSTANTIATE_TEST_SUITE_P(GemmBatchStrideTestSuite, GemmBatchStrideTests, diff --git a/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp b/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp index b0d8ec90b..16959a9cd 100644 --- a/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp +++ b/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp @@ -47,7 +47,7 @@ extern std::vector devices; namespace { -template +template int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { // Catch asynchronous exceptions. 
auto exception_handler = [](exception_list exceptions) { @@ -72,7 +72,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { int64_t m, n, k; int64_t lda, ldb, ldc; oneapi::mkl::transpose transa, transb; - fp alpha, beta; + Ts alpha, beta; int64_t i, tmp; @@ -83,9 +83,9 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { lda = std::max(m, k); ldb = std::max(n, k); ldc = std::max(m, n); - alpha = rand_scalar(); - beta = rand_scalar(); - if ((std::is_same::value) || (std::is_same::value)) { + alpha = rand_scalar(); + beta = rand_scalar(); + if ((std::is_same::value) || (std::is_same::value)) { transa = (oneapi::mkl::transpose)(std::rand() % 2); transb = (oneapi::mkl::transpose)(std::rand() % 2); } @@ -118,18 +118,27 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { default: break; } - auto ua = usm_allocator(cxt, *dev); - vector A(ua), B(ua), C(ua), C_ref(ua); + auto ua = usm_allocator(cxt, *dev); + auto ub = usm_allocator(cxt, *dev); + auto uc = usm_allocator(cxt, *dev); + auto us = usm_allocator(cxt, *dev); + vector A(ua); + vector B(ub); + vector C(uc), C_cast_ref(us); + vector A_ref(ua), B_ref(ub), C_ref(us); A.resize(stride_a * batch_size); B.resize(stride_b * batch_size); C.resize(stride_c * batch_size); + A_ref.resize(stride_a * batch_size); + B_ref.resize(stride_b * batch_size); C_ref.resize(stride_c * batch_size); + C_cast_ref.resize(stride_c * batch_size); - fp **a_array = (fp **)oneapi::mkl::malloc_shared(64, sizeof(fp *) * batch_size, *dev, cxt); - fp **b_array = (fp **)oneapi::mkl::malloc_shared(64, sizeof(fp *) * batch_size, *dev, cxt); - fp **c_array = (fp **)oneapi::mkl::malloc_shared(64, sizeof(fp *) * batch_size, *dev, cxt); - fp **c_ref_array = (fp **)oneapi::mkl::malloc_shared(64, sizeof(fp *) * batch_size, *dev, cxt); + Ta **a_array = (Ta **)oneapi::mkl::malloc_shared(64, sizeof(Ta *) * batch_size, *dev, cxt); + Tb **b_array = (Tb **)oneapi::mkl::malloc_shared(64, sizeof(Tb *) 
* batch_size, *dev, cxt); + Tc **c_array = (Tc **)oneapi::mkl::malloc_shared(64, sizeof(Tc *) * batch_size, *dev, cxt); + Ts **c_ref_array = (Ts **)oneapi::mkl::malloc_shared(64, sizeof(Ts *) * batch_size, *dev, cxt); if ((a_array == NULL) || (b_array == NULL) || (c_array == NULL) || (c_ref_array == NULL)) { std::cout << "Error cannot allocate arrays of pointers\n"; @@ -153,11 +162,15 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { stride_b * batch_size, 1, stride_b * batch_size); rand_matrix(C, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, stride_c * batch_size, 1, stride_c * batch_size); + copy_matrix(A, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, + stride_a * batch_size, 1, stride_a * batch_size, A_ref); + copy_matrix(B, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, + stride_b * batch_size, 1, stride_b * batch_size, B_ref); copy_matrix(C, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, stride_c * batch_size, 1, stride_c * batch_size, C_ref); // Call reference GEMM_BATCH_STRIDE. 
- using fp_ref = typename ref_type_info::type; + using fp_ref = typename ref_type_info::type; int m_ref = (int)m; int n_ref = (int)n; int k_ref = (int)k; @@ -166,12 +179,13 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { int ldc_ref = (int)ldc; int batch_size_ref = (int)batch_size; for (i = 0; i < batch_size_ref; i++) { - ::gemm( - convert_to_cblas_layout(layout), convert_to_cblas_trans(transa), - convert_to_cblas_trans(transb), (const int *)&m_ref, (const int *)&n_ref, - (const int *)&k_ref, (const fp_ref *)&alpha, (const fp_ref *)(A.data() + stride_a * i), - (const int *)&lda_ref, (const fp_ref *)(B.data() + stride_b * i), (const int *)&ldb_ref, - (const fp_ref *)&beta, (fp_ref *)(C_ref.data() + stride_c * i), (const int *)&ldc_ref); + ::gemm(convert_to_cblas_layout(layout), convert_to_cblas_trans(transa), + convert_to_cblas_trans(transb), (const int *)&m_ref, (const int *)&n_ref, + (const int *)&k_ref, (const fp_ref *)&alpha, + (const fp_ref *)(A_ref.data() + stride_a * i), (const int *)&lda_ref, + (const fp_ref *)(B_ref.data() + stride_b * i), (const int *)&ldb_ref, + (const fp_ref *)&beta, (fp_ref *)(C_ref.data() + stride_c * i), + (const int *)&ldc_ref); } // Call DPC++ GEMM_BATCH_STRIDE. @@ -191,7 +205,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { break; default: break; } - done.wait(); + done.wait_and_throw(); #else switch (layout) { case oneapi::mkl::layout::col_major: @@ -208,7 +222,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { break; default: break; } - main_queue.wait(); + main_queue.wait_and_throw(); #endif } catch (exception const &e) { @@ -231,8 +245,13 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { } // Compare the results of reference implementation and DPC++ implementation. 
- bool good = check_equal_matrix(C, C_ref, oneapi::mkl::layout::col_major, stride_c * batch_size, - 1, stride_c * batch_size, 10 * k, std::cout); + constexpr int tol_scalar = std::is_same_v ? 10 : 40; + + for (size_t i = 0; i < C_ref.size(); ++i) + C_cast_ref[i] = C_ref[i]; + bool good = + check_equal_matrix(C, C_cast_ref, oneapi::mkl::layout::col_major, stride_c * batch_size, 1, + stride_c * batch_size, tol_scalar * k, std::cout); oneapi::mkl::free_shared(a_array, cxt); oneapi::mkl::free_shared(b_array, cxt); @@ -246,29 +265,49 @@ class GemmBatchStrideUsmTests : public ::testing::TestWithParam> {}; TEST_P(GemmBatchStrideUsmTests, RealHalfPrecision) { - EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + EXPECT_TRUEORSKIP((test( + std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); +} + +TEST_P(GemmBatchStrideUsmTests, RealHalfRealScalarPrecision) { + EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), + std::get<1>(GetParam()), 5))); +} + +TEST_P(GemmBatchStrideUsmTests, RealIntRealScalarPrecision) { + EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), + std::get<1>(GetParam()), 5))); +} + +TEST_P(GemmBatchStrideUsmTests, RealIntRealIntPrecision) { + EXPECT_TRUEORSKIP((test( + std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } TEST_P(GemmBatchStrideUsmTests, RealSinglePrecision) { - EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + EXPECT_TRUEORSKIP( + (test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } TEST_P(GemmBatchStrideUsmTests, RealDoublePrecision) { CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); - EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + EXPECT_TRUEORSKIP(( + test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } TEST_P(GemmBatchStrideUsmTests, ComplexSinglePrecision) { EXPECT_TRUEORSKIP( - test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + (test, std::complex, std::complex, std::complex>( + std::get<0>(GetParam()), 
std::get<1>(GetParam()), 5))); } TEST_P(GemmBatchStrideUsmTests, ComplexDoublePrecision) { CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); EXPECT_TRUEORSKIP( - test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + (test, std::complex, std::complex, + std::complex>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } INSTANTIATE_TEST_SUITE_P(GemmBatchStrideUsmTestSuite, GemmBatchStrideUsmTests, diff --git a/tests/unit_tests/blas/batch/gemm_batch_usm.cpp b/tests/unit_tests/blas/batch/gemm_batch_usm.cpp index 58963a889..af38a10b0 100644 --- a/tests/unit_tests/blas/batch/gemm_batch_usm.cpp +++ b/tests/unit_tests/blas/batch/gemm_batch_usm.cpp @@ -47,7 +47,7 @@ extern std::vector devices; namespace { -template +template int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { // Catch asynchronous exceptions. auto exception_handler = [](exception_list exceptions) { @@ -76,8 +76,8 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { auto uatranspose = usm_allocator(cxt, *dev); vector transa(uatranspose), transb(uatranspose); - auto uafp = usm_allocator(cxt, *dev); - vector alpha(uafp), beta(uafp); + auto uaTs = usm_allocator(cxt, *dev); + vector alpha(uaTs), beta(uaTs); m.resize(group_count); n.resize(group_count); @@ -104,9 +104,9 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { lda[i] = std::max(m[i], k[i]); ldb[i] = std::max(n[i], k[i]); ldc[i] = std::max(m[i], n[i]); - alpha[i] = rand_scalar(); - beta[i] = rand_scalar(); - if ((std::is_same::value) || (std::is_same::value)) { + alpha[i] = rand_scalar(); + beta[i] = rand_scalar(); + if ((std::is_same::value) || (std::is_same::value)) { transa[i] = (oneapi::mkl::transpose)(std::rand() % 2); transb[i] = (oneapi::mkl::transpose)(std::rand() % 2); } @@ -125,12 +125,20 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { total_batch_count += group_size[i]; } - auto uafpp = usm_allocator(cxt, *dev); - vector 
a_array(uafpp), b_array(uafpp), c_array(uafpp), - c_ref_array(uafpp); + auto uaTap = usm_allocator(cxt, *dev); + auto uaTbp = usm_allocator(cxt, *dev); + auto uaTcp = usm_allocator(cxt, *dev); + auto uaTsp = usm_allocator(cxt, *dev); + vector a_array(uaTap); + vector b_array(uaTbp); + vector c_array(uaTcp), c_cast_ref_array(uaTcp); + vector a_ref_array(uaTsp), b_ref_array(uaTsp), c_ref_array(uaTsp); a_array.resize(total_batch_count); b_array.resize(total_batch_count); c_array.resize(total_batch_count); + a_ref_array.resize(total_batch_count); + b_ref_array.resize(total_batch_count); + c_cast_ref_array.resize(total_batch_count); c_ref_array.resize(total_batch_count); idx = 0; @@ -149,13 +157,19 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { default: break; } for (j = 0; j < group_size[i]; j++) { - a_array[idx] = (fp *)oneapi::mkl::malloc_shared(64, sizeof(fp) * size_a, *dev, cxt); - b_array[idx] = (fp *)oneapi::mkl::malloc_shared(64, sizeof(fp) * size_b, *dev, cxt); - c_array[idx] = (fp *)oneapi::mkl::malloc_shared(64, sizeof(fp) * size_c, *dev, cxt); - c_ref_array[idx] = (fp *)oneapi::mkl::malloc_shared(64, sizeof(fp) * size_c, *dev, cxt); + a_array[idx] = (Ta *)oneapi::mkl::malloc_shared(64, sizeof(Ta) * size_a, *dev, cxt); + b_array[idx] = (Tb *)oneapi::mkl::malloc_shared(64, sizeof(Tb) * size_b, *dev, cxt); + c_array[idx] = (Tc *)oneapi::mkl::malloc_shared(64, sizeof(Tc) * size_c, *dev, cxt); + a_ref_array[idx] = (Ts *)oneapi::mkl::malloc_shared(64, sizeof(Ts) * size_a, *dev, cxt); + b_ref_array[idx] = (Ts *)oneapi::mkl::malloc_shared(64, sizeof(Ts) * size_b, *dev, cxt); + c_cast_ref_array[idx] = + (Tc *)oneapi::mkl::malloc_shared(64, sizeof(Tc) * size_c, *dev, cxt); + c_ref_array[idx] = (Ts *)oneapi::mkl::malloc_shared(64, sizeof(Ts) * size_c, *dev, cxt); rand_matrix(a_array[idx], layout, transa[i], m[i], k[i], lda[i]); rand_matrix(b_array[idx], layout, transb[i], k[i], n[i], ldb[i]); rand_matrix(c_array[idx], layout, 
oneapi::mkl::transpose::nontrans, m[i], n[i], ldc[i]); + copy_matrix(a_array[idx], layout, transa[i], m[i], k[i], lda[i], a_ref_array[idx]); + copy_matrix(b_array[idx], layout, transb[i], k[i], n[i], ldb[i], b_ref_array[idx]); copy_matrix(c_array[idx], layout, oneapi::mkl::transpose::nontrans, m[i], n[i], ldc[i], c_ref_array[idx]); idx++; @@ -163,7 +177,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { } // Call reference GEMM_BATCH. - using fp_ref = typename ref_type_info::type; + using fp_ref = typename ref_type_info::type; int *m_ref = (int *)oneapi::mkl::aligned_alloc(64, sizeof(int) * group_count); int *n_ref = (int *)oneapi::mkl::aligned_alloc(64, sizeof(int) * group_count); int *k_ref = (int *)oneapi::mkl::aligned_alloc(64, sizeof(int) * group_count); @@ -196,6 +210,9 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { oneapi::mkl::free_shared(a_array[idx], cxt); oneapi::mkl::free_shared(b_array[idx], cxt); oneapi::mkl::free_shared(c_array[idx], cxt); + oneapi::mkl::free_shared(a_ref_array[idx], cxt); + oneapi::mkl::free_shared(b_ref_array[idx], cxt); + oneapi::mkl::free_shared(c_cast_ref_array[idx], cxt); oneapi::mkl::free_shared(c_ref_array[idx], cxt); idx++; } @@ -216,9 +233,10 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { for (j = 0; j < group_size_ref[i]; j++) { ::gemm(convert_to_cblas_layout(layout), transa_ref[i], transb_ref[i], (const int *)&m_ref[i], (const int *)&n_ref[i], (const int *)&k_ref[i], - (const fp_ref *)&alpha[i], (const fp_ref *)a_array[idx], - (const int *)&lda_ref[i], (const fp_ref *)b_array[idx], (const int *)&ldb_ref[i], - (const fp_ref *)&beta[i], (fp_ref *)c_ref_array[idx], (const int *)&ldc_ref[i]); + (const fp_ref *)&alpha[i], (const fp_ref *)a_ref_array[idx], + (const int *)&lda_ref[i], (const fp_ref *)b_ref_array[idx], + (const int *)&ldb_ref[i], (const fp_ref *)&beta[i], (fp_ref *)c_ref_array[idx], + (const int *)&ldc_ref[i]); idx++; } } @@ 
-231,37 +249,37 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::gemm_batch( main_queue, &transa[0], &transb[0], &m[0], &n[0], &k[0], &alpha[0], - (const fp **)&a_array[0], &lda[0], (const fp **)&b_array[0], &ldb[0], &beta[0], + (const Ta **)&a_array[0], &lda[0], (const Tb **)&b_array[0], &ldb[0], &beta[0], &c_array[0], &ldc[0], group_count, &group_size[0], dependencies); break; case oneapi::mkl::layout::row_major: done = oneapi::mkl::blas::row_major::gemm_batch( main_queue, &transa[0], &transb[0], &m[0], &n[0], &k[0], &alpha[0], - (const fp **)&a_array[0], &lda[0], (const fp **)&b_array[0], &ldb[0], &beta[0], + (const Ta **)&a_array[0], &lda[0], (const Tb **)&b_array[0], &ldb[0], &beta[0], &c_array[0], &ldc[0], group_count, &group_size[0], dependencies); break; default: break; } - done.wait(); + done.wait_and_throw(); #else switch (layout) { case oneapi::mkl::layout::col_major: TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::gemm_batch, &transa[0], &transb[0], &m[0], &n[0], &k[0], &alpha[0], - (const fp **)&a_array[0], &lda[0], (const fp **)&b_array[0], + (const Ta **)&a_array[0], &lda[0], (const Tb **)&b_array[0], &ldb[0], &beta[0], &c_array[0], &ldc[0], group_count, &group_size[0], dependencies); break; case oneapi::mkl::layout::row_major: TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::gemm_batch, &transa[0], &transb[0], &m[0], &n[0], &k[0], &alpha[0], - (const fp **)&a_array[0], &lda[0], (const fp **)&b_array[0], + (const Ta **)&a_array[0], &lda[0], (const Ta **)&b_array[0], &ldb[0], &beta[0], &c_array[0], &ldc[0], group_count, &group_size[0], dependencies); break; default: break; } - main_queue.wait(); + main_queue.wait_and_throw(); #endif } catch (exception const &e) { @@ -286,6 +304,9 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { oneapi::mkl::free_shared(a_array[idx], cxt); 
oneapi::mkl::free_shared(b_array[idx], cxt); oneapi::mkl::free_shared(c_array[idx], cxt); + oneapi::mkl::free_shared(a_ref_array[idx], cxt); + oneapi::mkl::free_shared(b_ref_array[idx], cxt); + oneapi::mkl::free_shared(c_cast_ref_array[idx], cxt); oneapi::mkl::free_shared(c_ref_array[idx], cxt); idx++; } @@ -299,11 +320,14 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { bool good = true; // Compare the results of reference implementation and DPC++ implementation. + constexpr int tol_scalar = std::is_same_v ? 10 : 40; idx = 0; for (i = 0; i < group_count; i++) { for (j = 0; j < group_size[i]; j++) { - good = good && check_equal_matrix(c_array[idx], c_ref_array[idx], layout, m[i], n[i], - ldc[i], 10 * k[i], std::cout); + copy_matrix(c_ref_array[idx], layout, oneapi::mkl::transpose::nontrans, m[i], n[i], + ldc[i], c_cast_ref_array[idx]); + good = good && check_equal_matrix(c_array[idx], c_cast_ref_array[idx], layout, m[i], + n[i], ldc[i], tol_scalar * k[i], std::cout); idx++; } } @@ -322,6 +346,9 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { oneapi::mkl::free_shared(a_array[idx], cxt); oneapi::mkl::free_shared(b_array[idx], cxt); oneapi::mkl::free_shared(c_array[idx], cxt); + oneapi::mkl::free_shared(a_ref_array[idx], cxt); + oneapi::mkl::free_shared(b_ref_array[idx], cxt); + oneapi::mkl::free_shared(c_cast_ref_array[idx], cxt); oneapi::mkl::free_shared(c_ref_array[idx], cxt); idx++; } @@ -334,29 +361,49 @@ class GemmBatchUsmTests : public ::testing::TestWithParam> {}; TEST_P(GemmBatchUsmTests, RealHalfPrecision) { - EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + EXPECT_TRUEORSKIP((test( + std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); +} + +TEST_P(GemmBatchUsmTests, RealHalfRealScalarPrecision) { + EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), + std::get<1>(GetParam()), 5))); +} + +TEST_P(GemmBatchUsmTests, RealIntRealScalarPrecision) { + 
EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), + std::get<1>(GetParam()), 5))); +} + +TEST_P(GemmBatchUsmTests, RealIntRealIntPrecision) { + EXPECT_TRUEORSKIP((test( + std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } TEST_P(GemmBatchUsmTests, RealSinglePrecision) { - EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + EXPECT_TRUEORSKIP( + (test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } TEST_P(GemmBatchUsmTests, RealDoublePrecision) { CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); - EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + EXPECT_TRUEORSKIP(( + test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } TEST_P(GemmBatchUsmTests, ComplexSinglePrecision) { EXPECT_TRUEORSKIP( - test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + (test, std::complex, std::complex, std::complex>( + std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } TEST_P(GemmBatchUsmTests, ComplexDoublePrecision) { CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); EXPECT_TRUEORSKIP( - test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + (test, std::complex, std::complex, + std::complex>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } INSTANTIATE_TEST_SUITE_P(GemmBatchUsmTestSuite, GemmBatchUsmTests, diff --git a/tests/unit_tests/blas/include/test_common.hpp b/tests/unit_tests/blas/include/test_common.hpp index d8c7029b1..6001d65da 100644 --- a/tests/unit_tests/blas/include/test_common.hpp +++ b/tests/unit_tests/blas/include/test_common.hpp @@ -249,21 +249,21 @@ void copy_matrix(vec_src &src, oneapi::mkl::layout layout, oneapi::mkl::transpos } } -template -void copy_matrix(fp *src, oneapi::mkl::layout layout, oneapi::mkl::transpose trans, int m, int n, - int ld, fp *dest) { +template +void copy_matrix(fp_src *src, oneapi::mkl::layout layout, oneapi::mkl::transpose trans, int m, + int n, int ld, fp_dst *dest) { if (((trans == oneapi::mkl::transpose::nontrans) && (layout 
== oneapi::mkl::layout::col_major)) || ((trans != oneapi::mkl::transpose::nontrans) && (layout == oneapi::mkl::layout::row_major))) { for (int j = 0; j < n; j++) for (int i = 0; i < m; i++) - dest[i + j * ld] = (fp)src[i + j * ld]; + dest[i + j * ld] = (fp_dst)src[i + j * ld]; } else { for (int i = 0; i < m; i++) for (int j = 0; j < n; j++) - dest[j + i * ld] = (fp)src[j + i * ld]; + dest[j + i * ld] = (fp_dst)src[j + i * ld]; } } From 1cd303ebb811daa3cf4799ce570a1f67349ad270 Mon Sep 17 00:00:00 2001 From: Aidan Date: Tue, 2 Apr 2024 11:04:12 +0100 Subject: [PATCH 03/30] Add new gemm_batch dtypes to cuBlas --- .../oneapi/mkl/blas/detail/cublas/blas_ct.hxx | 105 +++++ .../blas/detail/cublas/onemkl_blas_cublas.hxx | 58 +++ src/blas/backends/cublas/cublas_batch.cpp | 360 ++++++++++-------- src/blas/backends/cublas/cublas_helper.hpp | 50 +++ src/blas/backends/cublas/cublas_wrappers.cpp | 18 + 5 files changed, 437 insertions(+), 154 deletions(-) diff --git a/include/oneapi/mkl/blas/detail/cublas/blas_ct.hxx b/include/oneapi/mkl/blas/detail/cublas/blas_ct.hxx index 65ae5b853..9483a66c1 100644 --- a/include/oneapi/mkl/blas/detail/cublas/blas_ct.hxx +++ b/include/oneapi/mkl/blas/detail/cublas/blas_ct.hxx @@ -186,6 +186,39 @@ void gemm_batch(backend_selector selector, transpose transa, tr ldc, stride_c, batch_size); } +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + oneapi::mkl::blas::cublas::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, + ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, 
std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + oneapi::mkl::blas::cublas::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, + ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { + oneapi::mkl::blas::cublas::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, + ldc, stride_c, batch_size); +} + void syrk(backend_selector selector, uplo upper_lower, transpose trans, std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, std::int64_t lda, float beta, sycl::buffer &c, std::int64_t ldc) { @@ -2670,6 +2703,42 @@ sycl::event gemm_batch(backend_selector selector, transpose *tr return done; } +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const sycl::half **a, std::int64_t *lda, const sycl::half **b, + std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::cublas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, 
std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::cublas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, std::int32_t **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::cublas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + sycl::event gemm_batch(backend_selector selector, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, std::int64_t stride_a, @@ -2737,6 +2806,42 @@ sycl::event gemm_batch(backend_selector selector, transpose tra return done; } +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const sycl::half *a, std::int64_t lda, std::int64_t stride_a, + const sycl::half *b, std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::cublas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + 
+sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::cublas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, + std::int32_t *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::cublas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + sycl::event spmv(backend_selector selector, uplo upper_lower, std::int64_t n, float alpha, const float *a, const float *x, std::int64_t incx, float beta, float *y, std::int64_t incy, diff --git a/include/oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hxx b/include/oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hxx index f94e09426..1141eb238 100644 --- a/include/oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hxx +++ b/include/oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hxx @@ -804,6 +804,25 @@ void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, std::int sycl::half beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); +void gemm_batch(sycl::queue &queue, 
transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, float beta, sycl::buffer &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); + +void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, float beta, sycl::buffer &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); + +void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size); + void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, diag unit_diag, std::int64_t m, std::int64_t n, float alpha, sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, @@ -2040,6 +2059,24 @@ sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies = {}); +sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, std::int64_t *m, + std::int64_t *n, std::int64_t *k, float *alpha, const sycl::half **a, + std::int64_t *lda, const sycl::half **b, std::int64_t *ldb, float *beta, + float **c, std::int64_t *ldc, std::int64_t group_count, + std::int64_t *group_size, const std::vector &dependencies = {}); + +sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, std::int64_t *m, + std::int64_t *n, std::int64_t *k, float *alpha, const std::int8_t **a, + std::int64_t 
*lda, const std::int8_t **b, std::int64_t *ldb, float *beta, + float **c, std::int64_t *ldc, std::int64_t group_count, + std::int64_t *group_size, const std::vector &dependencies = {}); + +sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, std::int64_t *m, + std::int64_t *n, std::int64_t *k, float *alpha, const std::int8_t **a, + std::int64_t *lda, const std::int8_t **b, std::int64_t *ldb, float *beta, + std::int32_t **c, std::int64_t *ldc, std::int64_t group_count, + std::int64_t *group_size, const std::vector &dependencies = {}); + sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, std::int64_t stride_a, const float *b, @@ -2081,6 +2118,27 @@ sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t batch_size, const std::vector &dependencies = {}); +sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, const sycl::half *a, + std::int64_t lda, std::int64_t stride_a, const sycl::half *b, + std::int64_t ldb, std::int64_t stride_b, float beta, float *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, const std::int8_t *a, + std::int64_t lda, std::int64_t stride_a, const std::int8_t *b, + std::int64_t ldb, std::int64_t stride_b, float beta, float *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, const std::int8_t *a, + std::int64_t lda, std::int64_t stride_a, const std::int8_t *b, + 
std::int64_t ldb, std::int64_t stride_b, float beta, std::int32_t *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies = {}); + sycl::event gemmt(sycl::queue &queue, uplo upper_lower, transpose transa, transpose transb, std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, const float *b, std::int64_t ldb, float beta, float *c, std::int64_t ldc, diff --git a/src/blas/backends/cublas/cublas_batch.cpp b/src/blas/backends/cublas/cublas_batch.cpp index beefd6eeb..8106db5e9 100644 --- a/src/blas/backends/cublas/cublas_batch.cpp +++ b/src/blas/backends/cublas/cublas_batch.cpp @@ -140,16 +140,21 @@ void dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, throw unimplemented("blas", "dgmm_batch", "for column_major layout"); } -template -inline void gemm_batch(const char *func_name, Func func, sycl::queue &queue, transpose transa, - transpose transb, int64_t m, int64_t n, int64_t k, T alpha, - sycl::buffer &a, int64_t lda, int64_t stride_a, sycl::buffer &b, - int64_t ldb, int64_t stride_b, T beta, sycl::buffer &c, int64_t ldc, - int64_t stride_c, int64_t batch_size) { - using cuDataType = typename CudaEquivalentType::Type; +template +inline void gemm_batch_impl(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, int64_t k, Ts alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, + Ts beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + using cuTypeA = typename CudaEquivalentType::Type; + using cuTypeB = typename CudaEquivalentType::Type; + using cuTypeC = typename CudaEquivalentType::Type; + using cuTypeS = typename CudaEquivalentType::Type; overflow_check(m, n, k, lda, ldb, ldc, stride_a, stride_b, stride_c, batch_size); + + cublasGemmAlgo_t cublas_gemm_algo = CUBLAS_GEMM_DEFAULT; queue.submit([&](sycl::handler &cgh) { - if (!verify_support(queue, sycl::aspect::fp16)) { + if 
(!verify_support(queue, sycl::aspect::fp16)) { throw oneapi::mkl::unimplemented( "blas", "sycl::half", "half is not supported by the device or the sycl compiler"); } @@ -158,33 +163,53 @@ inline void gemm_batch(const char *func_name, Func func, sycl::queue &queue, tra auto c_acc = c.template get_access(cgh); onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); - auto a_ = sc.get_mem(a_acc); - auto b_ = sc.get_mem(b_acc); - auto c_ = sc.get_mem(c_acc); + auto a_ = sc.get_mem(a_acc); + auto b_ = sc.get_mem(b_acc); + auto c_ = sc.get_mem(c_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_operation(transa), - get_cublas_operation(transb), m, n, k, (cuDataType *)&alpha, - a_, lda, stride_a, b_, ldb, stride_b, (cuDataType *)&beta, c_, - ldc, stride_c, batch_size); + CUBLAS_ERROR_FUNC_T_SYNC( + "cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle, + get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, &alpha, a_, + get_cublas_datatype(), lda, stride_a, b_, get_cublas_datatype(), + ldb, stride_b, &beta, c_, get_cublas_datatype(), ldc, stride_c, batch_size, + get_cublas_datatype(), cublas_gemm_algo); }); }); } -#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ - void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ - int64_t k, TYPE alpha, sycl::buffer &a, int64_t lda, \ - int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, \ - TYPE beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, \ - int64_t batch_size) { \ - gemm_batch(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, transa, transb, m, n, k, alpha, a, lda, \ - stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size); \ +#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ + int64_t k, TYPE_S alpha, 
sycl::buffer &a, int64_t lda, \ + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, \ + TYPE_S beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, \ + int64_t batch_size) { \ + gemm_batch_impl(queue, transa, transb, m, n, k, alpha, a, \ + lda, stride_a, b, ldb, stride_b, beta, c, \ + ldc, stride_c, batch_size); \ + } + +GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, sycl::half, sycl::half, sycl::half) +GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, sycl::half, float, float) +GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, float, float) +GEMM_STRIDED_BATCH_LAUNCHER(float, float, float, float) +GEMM_STRIDED_BATCH_LAUNCHER(double, double, double, double) +GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, std::complex, + std::complex) +GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, std::complex, + std::complex) + +#undef GEMM_STRIDED_BATCH_LAUNCHER + +#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ + int64_t k, TYPE_S alpha, sycl::buffer &a, int64_t lda, \ + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, \ + TYPE_S beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, \ + int64_t batch_size) { \ + throw unimplemented("blas", "gemm_batch", "for unimplmented dtypes"); \ } -GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, cublasHgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER(float, cublasSgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER(double, cublasDgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER(std::complex, cublasCgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER(std::complex, cublasZgemmStridedBatched) +GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, std::int32_t, float) #undef GEMM_STRIDED_BATCH_LAUNCHER @@ -553,17 +578,23 @@ sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, int64_t throw unimplemented("blas", "dgmm_batch", "for column_major layout"); } -template -inline sycl::event 
gemm_batch(const char *func_name, Func func, sycl::queue &queue, - transpose transa, transpose transb, int64_t m, int64_t n, int64_t k, - T alpha, const T *a, int64_t lda, int64_t stride_a, const T *b, - int64_t ldb, int64_t stride_b, T beta, T *c, int64_t ldc, - int64_t stride_c, int64_t batch_size, - const std::vector &dependencies) { - using cuDataType = typename CudaEquivalentType::Type; +template +inline sycl::event gemm_batch_strided_usm_impl(sycl::queue &queue, transpose transa, + transpose transb, int64_t m, int64_t n, int64_t k, + Ts alpha, const Ta *a, int64_t lda, int64_t stride_a, + const Tb *b, int64_t ldb, int64_t stride_b, Ts beta, + Tc *c, int64_t ldc, int64_t stride_c, + int64_t batch_size, + const std::vector &dependencies) { + using cuTypeA = typename CudaEquivalentType::Type; + using cuTypeB = typename CudaEquivalentType::Type; + using cuTypeC = typename CudaEquivalentType::Type; + using cuTypeS = typename CudaEquivalentType::Type; overflow_check(m, n, k, lda, ldb, ldc, stride_a, stride_b, stride_c, batch_size); + + cublasGemmAlgo_t cublas_gemm_algo = CUBLAS_GEMM_DEFAULT; auto done = queue.submit([&](sycl::handler &cgh) { - if (!verify_support(queue, sycl::aspect::fp16)) { + if (!verify_support(queue, sycl::aspect::fp16)) { throw oneapi::mkl::unimplemented( "blas", "sycl::half", "half is not supported by the device or the sycl compiler"); } @@ -573,50 +604,71 @@ inline sycl::event gemm_batch(const char *func_name, Func func, sycl::queue &que } onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); - auto a_ = reinterpret_cast(a); - auto b_ = reinterpret_cast(b); - auto c_ = reinterpret_cast(c); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_operation(transa), - get_cublas_operation(transb), m, n, k, (cuDataType *)&alpha, - a_, lda, stride_a, b_, ldb, stride_b, (cuDataType *)&beta, c_, - ldc, stride_c, batch_size); + CUBLAS_ERROR_FUNC_T_SYNC( + 
"cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle, + get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, &alpha, a, + get_cublas_datatype(), lda, stride_a, b, get_cublas_datatype(), + ldb, stride_b, &beta, c, get_cublas_datatype(), ldc, stride_c, batch_size, + get_cublas_datatype(), cublas_gemm_algo); }); }); return done; } -#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE, CUBLAS_ROUTINE) \ - sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ - int64_t n, int64_t k, TYPE alpha, const TYPE *a, int64_t lda, \ - int64_t stride_a, const TYPE *b, int64_t ldb, int64_t stride_b, \ - TYPE beta, TYPE *c, int64_t ldc, int64_t stride_c, int64_t batch_size, \ - const std::vector &dependencies) { \ - return gemm_batch(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, transa, transb, m, n, k, alpha, \ - a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size, \ - dependencies); \ +#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ + int64_t n, int64_t k, TYPE_S alpha, const TYPE_A *a, int64_t lda, \ + int64_t stride_a, const TYPE_B *b, int64_t ldb, int64_t stride_b, \ + TYPE_S beta, TYPE_C *c, int64_t ldc, int64_t stride_c, \ + int64_t batch_size, const std::vector &dependencies) { \ + return gemm_batch_strided_usm_impl(queue, transa, transb, m, n, k, alpha, a, lda, \ + stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, \ + batch_size, dependencies); \ } -GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, cublasHgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(float, cublasSgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(double, cublasDgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, cublasCgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, cublasZgemmStridedBatched) +GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, sycl::half) 
+GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) +GEMM_STRIDED_BATCH_LAUNCHER_USM(float, float, float, float) +GEMM_STRIDED_BATCH_LAUNCHER_USM(double, double, double, double) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) #undef GEMM_STRIDED_BATCH_LAUNCHER_USM -template -inline sycl::event gemm_batch(const char *func_name, Func func, sycl::queue &queue, - transpose *transa, transpose *transb, int64_t *m, int64_t *n, - int64_t *k, T *alpha, const T **a, int64_t *lda, const T **b, - int64_t *ldb, T *beta, T **c, int64_t *ldc, int64_t group_count, - int64_t *group_size, const std::vector &dependencies) { - using cuDataType = typename CudaEquivalentType::Type; +#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ + int64_t n, int64_t k, TYPE_S alpha, const TYPE_A *a, int64_t lda, \ + int64_t stride_a, const TYPE_B *b, int64_t ldb, int64_t stride_b, \ + TYPE_S beta, TYPE_C *c, int64_t ldc, int64_t stride_c, \ + int64_t batch_size, const std::vector &dependencies) { \ + throw unimplemented("blas", "gemm_batch", "for unimplmented dtypes"); \ + } + +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float) + +#undef GEMM_STRIDED_BATCH_LAUNCHER_USM + +template +inline sycl::event gemm_batch_usm_impl(sycl::queue &queue, transpose *transa, transpose *transb, + int64_t *m, int64_t *n, int64_t *k, Ts *alpha, const Ta **a, + int64_t *lda, const Tb **b, int64_t *ldb, Ts *beta, Tc **c, + int64_t *ldc, int64_t group_count, int64_t *group_size, + const std::vector &dependencies) { + using cuTypeA = typename CudaEquivalentType::Type; + using cuTypeB = typename CudaEquivalentType::Type; + using cuTypeC = typename 
CudaEquivalentType::Type; + using cuTypeS = typename CudaEquivalentType::Type; for (int64_t i = 0; i < group_count; i++) { overflow_check(m[i], n[i], k[i], lda[i], ldb[i], ldc[i], group_size[i]); } + + cublasGemmAlgo_t cublas_gemm_algo = CUBLAS_GEMM_DEFAULT; auto done = queue.submit([&](sycl::handler &cgh) { - if (!verify_support(queue, sycl::aspect::fp16)) { + if (!verify_support(queue, sycl::aspect::fp16)) { throw oneapi::mkl::unimplemented( "blas", "sycl::half", "half is not supported by the device or the sycl compiler"); } @@ -629,14 +681,14 @@ inline sycl::event gemm_batch(const char *func_name, Func func, sycl::queue &que int64_t offset = 0; cublasStatus_t err; for (int64_t i = 0; i < group_count; i++) { - auto **a_ = reinterpret_cast(a); - auto **b_ = reinterpret_cast(b); - auto **c_ = reinterpret_cast(c); CUBLAS_ERROR_FUNC_T_SYNC( - func_name, func, err, handle, get_cublas_operation(transa[i]), - get_cublas_operation(transb[i]), (int)m[i], (int)n[i], (int)k[i], - (cuDataType *)&alpha[i], a_ + offset, (int)lda[i], b_ + offset, (int)ldb[i], - (cuDataType *)&beta[i], c_ + offset, (int)ldc[i], (int)group_size[i]); + "cublasGemmBatchedEx", cublasGemmBatchedEx, err, handle, + get_cublas_operation(transa[i]), get_cublas_operation(transb[i]), (int)m[i], + (int)n[i], (int)k[i], &alpha[i], (const void *const *)(a + offset), + get_cublas_datatype(), (int)lda[i], (const void *const *)(b + offset), + get_cublas_datatype(), (int)ldb[i], &beta[i], + (void *const *)(c + offset), get_cublas_datatype(), (int)ldc[i], + (int)group_size[i], get_cublas_datatype(), cublas_gemm_algo); offset += group_size[i]; } }); @@ -644,21 +696,38 @@ inline sycl::event gemm_batch(const char *func_name, Func func, sycl::queue &que return done; } -#define GEMM_BATCH_LAUNCHER_USM(TYPE, CUBLAS_ROUTINE) \ - sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ - int64_t *n, int64_t *k, TYPE *alpha, const TYPE **a, int64_t *lda, \ - const TYPE **b, int64_t 
*ldb, TYPE *beta, TYPE **c, int64_t *ldc, \ - int64_t group_count, int64_t *group_size, \ - const std::vector &dependencies) { \ - return gemm_batch(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, transa, transb, m, n, k, alpha, \ - a, lda, b, ldb, beta, c, ldc, group_count, group_size, dependencies); \ +#define GEMM_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ + int64_t *n, int64_t *k, TYPE_S *alpha, const TYPE_A **a, int64_t *lda, \ + const TYPE_B **b, int64_t *ldb, TYPE_S *beta, TYPE_C **c, int64_t *ldc, \ + int64_t group_count, int64_t *group_size, \ + const std::vector &dependencies) { \ + return gemm_batch_usm_impl(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, \ + ldc, group_count, group_size, dependencies); \ } -GEMM_BATCH_LAUNCHER_USM(sycl::half, cublasHgemmBatched) -GEMM_BATCH_LAUNCHER_USM(float, cublasSgemmBatched) -GEMM_BATCH_LAUNCHER_USM(double, cublasDgemmBatched) -GEMM_BATCH_LAUNCHER_USM(std::complex, cublasCgemmBatched) -GEMM_BATCH_LAUNCHER_USM(std::complex, cublasZgemmBatched) +GEMM_BATCH_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, sycl::half) +GEMM_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float) +GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) +GEMM_BATCH_LAUNCHER_USM(float, float, float, float) +GEMM_BATCH_LAUNCHER_USM(double, double, double, double) +GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) +GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) + +#undef GEMM_BATCH_LAUNCHER_USM + +#define GEMM_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ + int64_t *n, int64_t *k, TYPE_S *alpha, const TYPE_A **a, int64_t *lda, \ + const TYPE_B **b, int64_t *ldb, TYPE_S *beta, TYPE_C **c, int64_t *ldc, \ + int64_t group_count, int64_t *group_size, \ + 
const std::vector &dependencies) { \ + throw unimplemented("blas", "gemm_batch", "for unimplmented dtypes"); \ + } + +GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float) #undef GEMM_BATCH_LAUNCHER_USM @@ -1066,30 +1135,25 @@ void dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, throw unimplemented("blas", "dgmm_batch", "for row_major layout"); } -template -inline void gemm_batch(const char *func_name, Func func, sycl::queue &queue, transpose transa, - transpose transb, int64_t m, int64_t n, int64_t k, T alpha, - sycl::buffer &a, int64_t lda, int64_t stride_a, sycl::buffer &b, - int64_t ldb, int64_t stride_b, T beta, sycl::buffer &c, int64_t ldc, - int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "gemm_batch", "for row_major layout"); -} - -#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ - void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ - int64_t k, TYPE alpha, sycl::buffer &a, int64_t lda, \ - int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, \ - TYPE beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, \ - int64_t batch_size) { \ - gemm_batch(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, transa, transb, m, n, k, alpha, a, lda, \ - stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size); \ +#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ + int64_t k, TYPE_S alpha, sycl::buffer &a, int64_t lda, \ + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, \ + TYPE_S beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, \ + int64_t batch_size) { \ + throw unimplemented("blas", "gemm_batch", "for row_major layout"); \ } -GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, cublasHgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER(float, cublasSgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER(double, 
cublasDgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER(std::complex, cublasCgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER(std::complex, cublasZgemmStridedBatched) +GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, sycl::half, sycl::half, sycl::half) +GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, sycl::half, float, float) +GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, float, float) +GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, std::int32_t, float) +GEMM_STRIDED_BATCH_LAUNCHER(float, float, float, float) +GEMM_STRIDED_BATCH_LAUNCHER(double, double, double, double) +GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, std::complex, + std::complex) +GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, std::complex, + std::complex) #undef GEMM_STRIDED_BATCH_LAUNCHER @@ -1458,59 +1522,47 @@ sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, int64_t throw unimplemented("blas", "dgmm_batch", "for row_major layout"); } -template -inline sycl::event gemm_batch(const char *func_name, Func func, sycl::queue &queue, - transpose transa, transpose transb, int64_t m, int64_t n, int64_t k, - T alpha, const T *a, int64_t lda, int64_t stride_a, const T *b, - int64_t ldb, int64_t stride_b, T beta, T *c, int64_t ldc, - int64_t stride_c, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "gemm_batch", "for row_major layout"); -} - -#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE, CUBLAS_ROUTINE) \ - sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ - int64_t n, int64_t k, TYPE alpha, const TYPE *a, int64_t lda, \ - int64_t stride_a, const TYPE *b, int64_t ldb, int64_t stride_b, \ - TYPE beta, TYPE *c, int64_t ldc, int64_t stride_c, int64_t batch_size, \ - const std::vector &dependencies) { \ - return gemm_batch(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, transa, transb, m, n, k, alpha, \ - a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size, \ - dependencies); \ 
+#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ + int64_t n, int64_t k, TYPE_S alpha, const TYPE_A *a, int64_t lda, \ + int64_t stride_a, const TYPE_B *b, int64_t ldb, int64_t stride_b, \ + TYPE_S beta, TYPE_C *c, int64_t ldc, int64_t stride_c, \ + int64_t batch_size, const std::vector &dependencies) { \ + throw unimplemented("blas", "gemm_batch", "for row_major layout"); \ } -GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, cublasHgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(float, cublasSgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(double, cublasDgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, cublasCgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, cublasZgemmStridedBatched) +GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, sycl::half) +GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float) +GEMM_STRIDED_BATCH_LAUNCHER_USM(float, float, float, float) +GEMM_STRIDED_BATCH_LAUNCHER_USM(double, double, double, double) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) #undef GEMM_STRIDED_BATCH_LAUNCHER_USM -template -inline sycl::event gemm_batch(const char *func_name, Func func, sycl::queue &queue, - transpose *transa, transpose *transb, int64_t *m, int64_t *n, - int64_t *k, T *alpha, const T **a, int64_t *lda, const T **b, - int64_t *ldb, T *beta, T **c, int64_t *ldc, int64_t group_count, - int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemm_batch", "for row_major layout"); -} - -#define GEMM_BATCH_LAUNCHER_USM(TYPE, CUBLAS_ROUTINE) \ - sycl::event 
gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ - int64_t *n, int64_t *k, TYPE *alpha, const TYPE **a, int64_t *lda, \ - const TYPE **b, int64_t *ldb, TYPE *beta, TYPE **c, int64_t *ldc, \ - int64_t group_count, int64_t *group_size, \ - const std::vector &dependencies) { \ - return gemm_batch(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, transa, transb, m, n, k, alpha, \ - a, lda, b, ldb, beta, c, ldc, group_count, group_size, dependencies); \ +#define GEMM_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ + int64_t *n, int64_t *k, TYPE_S *alpha, const TYPE_A **a, int64_t *lda, \ + const TYPE_B **b, int64_t *ldb, TYPE_S *beta, TYPE_C **c, int64_t *ldc, \ + int64_t group_count, int64_t *group_size, \ + const std::vector &dependencies) { \ + throw unimplemented("blas", "gemm_batch", "for row_major layout"); \ } -GEMM_BATCH_LAUNCHER_USM(sycl::half, cublasHgemmBatched) -GEMM_BATCH_LAUNCHER_USM(float, cublasSgemmBatched) -GEMM_BATCH_LAUNCHER_USM(double, cublasDgemmBatched) -GEMM_BATCH_LAUNCHER_USM(std::complex, cublasCgemmBatched) -GEMM_BATCH_LAUNCHER_USM(std::complex, cublasZgemmBatched) +GEMM_BATCH_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, sycl::half) +GEMM_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float) +GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) +GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float) +GEMM_BATCH_LAUNCHER_USM(float, float, float, float) +GEMM_BATCH_LAUNCHER_USM(double, double, double, double) +GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) +GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) #undef GEMM_BATCH_LAUNCHER_USM diff --git a/src/blas/backends/cublas/cublas_helper.hpp b/src/blas/backends/cublas/cublas_helper.hpp index 0ee9930e3..783b8c58f 100644 --- a/src/blas/backends/cublas/cublas_helper.hpp +++ 
b/src/blas/backends/cublas/cublas_helper.hpp @@ -231,6 +231,56 @@ inline cublasSideMode_t get_cublas_side_mode(oneapi::mkl::side lr) { } } +template +inline cudaDataType_t get_cublas_datatype() { + static_assert(false); +} + +template <> +inline cudaDataType_t get_cublas_datatype<__half>() { + return CUDA_R_16F; +} + +template <> +inline cudaDataType_t get_cublas_datatype() { + return CUDA_R_32F; +} + +template <> +inline cudaDataType_t get_cublas_datatype() { + return CUDA_R_64F; +} + +template <> +inline cudaDataType_t get_cublas_datatype() { + return CUDA_C_32F; +} + +template <> +inline cudaDataType_t get_cublas_datatype() { + return CUDA_C_64F; +} + +template <> +inline cudaDataType_t get_cublas_datatype() { + return CUDA_R_8I; +} + +template <> +inline cudaDataType_t get_cublas_datatype() { + return CUDA_R_8U; +} + +template <> +inline cudaDataType_t get_cublas_datatype() { + return CUDA_R_32I; +} + +template <> +inline cudaDataType_t get_cublas_datatype() { + return CUDA_R_8U; +} + /*converting std::complex to cuComplex*/ /*converting sycl::half to __half*/ template diff --git a/src/blas/backends/cublas/cublas_wrappers.cpp b/src/blas/backends/cublas/cublas_wrappers.cpp index fe479e195..ee5c7239f 100644 --- a/src/blas/backends/cublas/cublas_wrappers.cpp +++ b/src/blas/backends/cublas/cublas_wrappers.cpp @@ -205,6 +205,9 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::cublas::column_major::gemm_batch, oneapi::mkl::blas::cublas::column_major::gemm_batch, oneapi::mkl::blas::cublas::column_major::gemm_batch, + oneapi::mkl::blas::cublas::column_major::gemm_batch, + oneapi::mkl::blas::cublas::column_major::gemm_batch, + oneapi::mkl::blas::cublas::column_major::gemm_batch, oneapi::mkl::blas::cublas::column_major::trsm_batch, oneapi::mkl::blas::cublas::column_major::trsm_batch, oneapi::mkl::blas::cublas::column_major::trsm_batch, @@ -460,6 +463,12 @@ extern "C" blas_function_table_t mkl_blas_table = { 
oneapi::mkl::blas::cublas::column_major::gemm_batch, oneapi::mkl::blas::cublas::column_major::gemm_batch, oneapi::mkl::blas::cublas::column_major::gemm_batch, + oneapi::mkl::blas::cublas::column_major::gemm_batch, + oneapi::mkl::blas::cublas::column_major::gemm_batch, + oneapi::mkl::blas::cublas::column_major::gemm_batch, + oneapi::mkl::blas::cublas::column_major::gemm_batch, + oneapi::mkl::blas::cublas::column_major::gemm_batch, + oneapi::mkl::blas::cublas::column_major::gemm_batch, oneapi::mkl::blas::cublas::column_major::gemmt, oneapi::mkl::blas::cublas::column_major::gemmt, oneapi::mkl::blas::cublas::column_major::gemmt, @@ -686,6 +695,9 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::cublas::row_major::gemm_batch, oneapi::mkl::blas::cublas::row_major::gemm_batch, oneapi::mkl::blas::cublas::row_major::gemm_batch, + oneapi::mkl::blas::cublas::row_major::gemm_batch, + oneapi::mkl::blas::cublas::row_major::gemm_batch, + oneapi::mkl::blas::cublas::row_major::gemm_batch, oneapi::mkl::blas::cublas::row_major::trsm_batch, oneapi::mkl::blas::cublas::row_major::trsm_batch, oneapi::mkl::blas::cublas::row_major::trsm_batch, @@ -941,6 +953,12 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::cublas::row_major::gemm_batch, oneapi::mkl::blas::cublas::row_major::gemm_batch, oneapi::mkl::blas::cublas::row_major::gemm_batch, + oneapi::mkl::blas::cublas::row_major::gemm_batch, + oneapi::mkl::blas::cublas::row_major::gemm_batch, + oneapi::mkl::blas::cublas::row_major::gemm_batch, + oneapi::mkl::blas::cublas::row_major::gemm_batch, + oneapi::mkl::blas::cublas::row_major::gemm_batch, + oneapi::mkl::blas::cublas::row_major::gemm_batch, oneapi::mkl::blas::cublas::row_major::gemmt, oneapi::mkl::blas::cublas::row_major::gemmt, oneapi::mkl::blas::cublas::row_major::gemmt, From 6fd86c8485ebfe2ee2907cf0748f891f552dc8bc Mon Sep 17 00:00:00 2001 From: Aidan Date: Tue, 2 Apr 2024 11:05:02 +0100 Subject: [PATCH 04/30] Add new gemm_batch dtypes 
to rocBlas --- .../mkl/blas/detail/rocblas/blas_ct.hxx | 99 +++++ .../detail/rocblas/onemkl_blas_rocblas.hxx | 54 +++ src/blas/backends/rocblas/rocblas_batch.cpp | 417 ++++++++++++------ src/blas/backends/rocblas/rocblas_helper.hpp | 60 +++ .../backends/rocblas/rocblas_wrappers.cpp | 18 + 5 files changed, 510 insertions(+), 138 deletions(-) diff --git a/include/oneapi/mkl/blas/detail/rocblas/blas_ct.hxx b/include/oneapi/mkl/blas/detail/rocblas/blas_ct.hxx index 32188fed7..bc86929b0 100644 --- a/include/oneapi/mkl/blas/detail/rocblas/blas_ct.hxx +++ b/include/oneapi/mkl/blas/detail/rocblas/blas_ct.hxx @@ -181,6 +181,36 @@ void gemm_batch(backend_selector selector, transpose transa, t c, ldc, stride_c, batch_size); } +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + int64_t m, int64_t n, int64_t k, float alpha, sycl::buffer &a, + int64_t lda, int64_t stride_a, sycl::buffer &b, int64_t ldb, + int64_t stride_b, float beta, sycl::buffer &c, int64_t ldc, + int64_t stride_c, int64_t batch_size) { + oneapi::mkl::blas::rocblas::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + int64_t m, int64_t n, int64_t k, float alpha, sycl::buffer &a, + int64_t lda, int64_t stride_a, sycl::buffer &b, int64_t ldb, + int64_t stride_b, float beta, sycl::buffer &c, int64_t ldc, + int64_t stride_c, int64_t batch_size) { + oneapi::mkl::blas::rocblas::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + int64_t m, int64_t n, int64_t k, float alpha, sycl::buffer &a, + int64_t lda, int64_t stride_a, sycl::buffer &b, int64_t ldb, + int64_t stride_b, float beta, sycl::buffer &c, int64_t ldc, + int64_t 
stride_c, int64_t batch_size) { + oneapi::mkl::blas::rocblas::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_size); +} + void syrk(backend_selector selector, uplo upper_lower, transpose trans, int64_t n, int64_t k, float alpha, sycl::buffer &a, int64_t lda, float beta, sycl::buffer &c, int64_t ldc) { @@ -2538,6 +2568,39 @@ sycl::event gemm_batch(backend_selector selector, transpose *t return done; } +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, int64_t *m, int64_t *n, int64_t *k, float *alpha, + const sycl::half **a, int64_t *lda, const sycl::half **b, int64_t *ldb, + float *beta, float **c, int64_t *ldc, int64_t group_count, + int64_t *group_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::rocblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, int64_t *m, int64_t *n, int64_t *k, float *alpha, + const std::int8_t **a, int64_t *lda, const std::int8_t **b, int64_t *ldb, + float *beta, float **c, int64_t *ldc, int64_t group_count, + int64_t *group_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::rocblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, int64_t *m, int64_t *n, int64_t *k, float *alpha, + const std::int8_t **a, int64_t *lda, const std::int8_t **b, int64_t *ldb, + float *beta, std::int32_t **c, int64_t *ldc, int64_t group_count, + int64_t *group_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::rocblas::MAJOR::gemm_batch( + 
selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + sycl::event gemm_batch(backend_selector selector, transpose transa, transpose transb, int64_t m, int64_t n, int64_t k, float alpha, const float *a, int64_t lda, int64_t stride_a, const float *b, int64_t ldb, @@ -2598,6 +2661,42 @@ sycl::event gemm_batch(backend_selector selector, transpose tr return done; } +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, int64_t m, int64_t n, int64_t k, float alpha, + const sycl::half *a, int64_t lda, int64_t stride_a, const sycl::half *b, + int64_t ldb, int64_t stride_b, float beta, float *c, int64_t ldc, + int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::rocblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, int64_t m, int64_t n, int64_t k, float alpha, + const std::int8_t *a, int64_t lda, int64_t stride_a, const std::int8_t *b, + int64_t ldb, int64_t stride_b, float beta, float *c, int64_t ldc, + int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::rocblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, int64_t m, int64_t n, int64_t k, float alpha, + const std::int8_t *a, int64_t lda, int64_t stride_a, const std::int8_t *b, + int64_t ldb, int64_t stride_b, float beta, std::int32_t *c, int64_t ldc, + int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { + auto done = 
oneapi::mkl::blas::rocblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + sycl::event spmv(backend_selector selector, uplo upper_lower, int64_t n, float alpha, const float *a, const float *x, int64_t incx, float beta, float *y, int64_t incy, const std::vector &dependencies) { diff --git a/include/oneapi/mkl/blas/detail/rocblas/onemkl_blas_rocblas.hxx b/include/oneapi/mkl/blas/detail/rocblas/onemkl_blas_rocblas.hxx index e4cd77c4a..70aabaaf9 100644 --- a/include/oneapi/mkl/blas/detail/rocblas/onemkl_blas_rocblas.hxx +++ b/include/oneapi/mkl/blas/detail/rocblas/onemkl_blas_rocblas.hxx @@ -744,6 +744,24 @@ void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t sycl::half beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size); +void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, + float beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, + int64_t batch_size); + +void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, + float beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, + int64_t batch_size); + +void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, + float beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, + int64_t batch_size); + void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, diag unit_diag, int64_t m, int64_t n, float alpha, sycl::buffer &a, int64_t 
lda, int64_t stride_a, sycl::buffer &b, int64_t ldb, @@ -1848,6 +1866,24 @@ sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, sycl::half **c, int64_t *ldc, int64_t group_count, int64_t *group_size, const std::vector &dependencies = {}); +sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, + int64_t *n, int64_t *k, float *alpha, const sycl::half **a, int64_t *lda, + const sycl::half **b, int64_t *ldb, float *beta, float **c, int64_t *ldc, + int64_t group_count, int64_t *group_size, + const std::vector &dependencies = {}); + +sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, + int64_t *n, int64_t *k, float *alpha, const std::int8_t **a, int64_t *lda, + const std::int8_t **b, int64_t *ldb, float *beta, float **c, int64_t *ldc, + int64_t group_count, int64_t *group_size, + const std::vector &dependencies = {}); + +sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, + int64_t *n, int64_t *k, float *alpha, const std::int8_t **a, int64_t *lda, + const std::int8_t **b, int64_t *ldb, float *beta, std::int32_t **c, + int64_t *ldc, int64_t group_count, int64_t *group_size, + const std::vector &dependencies = {}); + sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, int64_t k, float alpha, const float *a, int64_t lda, int64_t stride_a, const float *b, int64_t ldb, int64_t stride_b, float beta, float *c, @@ -1880,6 +1916,24 @@ sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, i sycl::half beta, sycl::half *c, int64_t ldc, int64_t stride_c, int64_t batch_size, const std::vector &dependencies = {}); +sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, const sycl::half *a, int64_t lda, int64_t stride_a, + const sycl::half *b, int64_t ldb, int64_t stride_b, float beta, float 
*c, + int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, const std::int8_t *a, int64_t lda, int64_t stride_a, + const std::int8_t *b, int64_t ldb, int64_t stride_b, float beta, float *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, const std::int8_t *a, int64_t lda, int64_t stride_a, + const std::int8_t *b, int64_t ldb, int64_t stride_b, float beta, + std::int32_t *c, int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies = {}); + sycl::event gemmt(sycl::queue &queue, uplo upper_lower, transpose transa, transpose transb, int64_t n, int64_t k, float alpha, const float *a, int64_t lda, const float *b, int64_t ldb, float beta, float *c, int64_t ldc, diff --git a/src/blas/backends/rocblas/rocblas_batch.cpp b/src/blas/backends/rocblas/rocblas_batch.cpp index 9a0a1be28..09cd01ce3 100644 --- a/src/blas/backends/rocblas/rocblas_batch.cpp +++ b/src/blas/backends/rocblas/rocblas_batch.cpp @@ -227,14 +227,20 @@ DGMM_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_zdgmm_strided_batched) #undef DGMM_STRIDED_BATCH_LAUNCHER -template -inline void gemm_batch(Func func, sycl::queue &queue, transpose transa, transpose transb, int64_t m, - int64_t n, int64_t k, T alpha, sycl::buffer &a, int64_t lda, - int64_t stridea, sycl::buffer &b, int64_t ldb, int64_t strideb, T beta, - sycl::buffer &c, int64_t ldc, int64_t stridec, int64_t batch_size) { - using rocDataType = typename RocEquivalentType::Type; +template +inline void gemm_batch_impl(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, int64_t k, Ts alpha, sycl::buffer &a, int64_t lda, + int64_t stridea, sycl::buffer &b, int64_t ldb, 
int64_t strideb, + Ts beta, sycl::buffer &c, int64_t ldc, int64_t stridec, + int64_t batch_size) { + using rocTypeA = typename RocEquivalentType::Type; + using rocTypeB = typename RocEquivalentType::Type; + using rocTypeC = typename RocEquivalentType::Type; + using rocTypeS = typename RocEquivalentType::Type; overflow_check(m, n, k, lda, ldb, ldc, stridea, strideb, stridec, batch_size); + int32_t solution_index = 0; + rocblas_gemm_flags flags = rocblas_gemm_flags_none; queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); auto b_acc = b.template get_access(cgh); @@ -242,32 +248,55 @@ inline void gemm_batch(Func func, sycl::queue &queue, transpose transa, transpos onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); - auto a_ = sc.get_mem(a_acc); - auto b_ = sc.get_mem(b_acc); - auto c_ = sc.get_mem(c_acc); + auto a_ = sc.get_mem(a_acc); + auto b_ = sc.get_mem(b_acc); + auto c_ = sc.get_mem(c_acc); + rocblas_status err; - ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_operation(transa), - get_rocblas_operation(transb), m, n, k, (rocDataType *)&alpha, - a_, lda, stridea, b_, ldb, strideb, (rocDataType *)&beta, c_, - ldc, stridec, batch_size); + ROCBLAS_ERROR_FUNC_SYNC(rocblas_gemm_strided_batched_ex, err, handle, + get_rocblas_operation(transa), get_rocblas_operation(transb), m, + n, k, &alpha, a_, get_rocblas_datatype(), lda, + stridea, b_, get_rocblas_datatype(), ldb, strideb, + &beta, c_, get_rocblas_datatype(), ldc, stridec, c_, + get_rocblas_datatype(), ldc, stridec, batch_size, + get_rocblas_datatype(), rocblas_gemm_algo_standard, + solution_index, flags); }); }); } -#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ - void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ - int64_t k, TYPE alpha, sycl::buffer &a, int64_t lda, int64_t stridea, \ - sycl::buffer &b, int64_t ldb, int64_t strideb, TYPE beta, \ - 
sycl::buffer &c, int64_t ldc, int64_t stridec, int64_t batch_size) { \ - gemm_batch(ROCBLAS_ROUTINE, queue, transa, transb, m, n, k, alpha, a, lda, stridea, b, \ - ldb, strideb, beta, c, ldc, stridec, batch_size); \ +#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ + int64_t k, TYPE_S alpha, sycl::buffer &a, int64_t lda, \ + int64_t stridea, sycl::buffer &b, int64_t ldb, int64_t strideb, \ + TYPE_S beta, sycl::buffer &c, int64_t ldc, int64_t stridec, \ + int64_t batch_size) { \ + gemm_batch_impl(queue, transa, transb, m, n, k, alpha, a, lda, stridea, b, ldb, strideb, \ + beta, c, ldc, stridec, batch_size); \ } -GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, rocblas_hgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER(float, rocblas_sgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER(double, rocblas_dgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_cgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_zgemm_strided_batched) +GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, sycl::half, sycl::half, sycl::half) +GEMM_STRIDED_BATCH_LAUNCHER(float, float, float, float) +GEMM_STRIDED_BATCH_LAUNCHER(double, double, double, double) +GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, std::complex, + std::complex) +GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, std::complex, + std::complex) +GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, sycl::half, float, float) + +#undef GEMM_STRIDED_BATCH_LAUNCHER + +#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ + int64_t k, TYPE_S alpha, sycl::buffer &a, int64_t lda, \ + int64_t stridea, sycl::buffer &b, int64_t ldb, int64_t strideb, \ + TYPE_S beta, sycl::buffer &c, int64_t ldc, int64_t stridec, \ + int64_t batch_size) { \ + throw unimplemented("blas", "gemm_batch", 
"for data type combination"); \ + } + +GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, float, float) +GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, std::int32_t, float) #undef GEMM_STRIDED_BATCH_LAUNCHER @@ -816,63 +845,97 @@ DGMM_BATCH_LAUNCHER_USM(std::complex, rocblas_zdgmm_batched) #undef DGMM_BATCH_LAUNCHER -template -inline sycl::event gemm_batch(Func func, sycl::queue &queue, transpose transa, transpose transb, - int64_t m, int64_t n, int64_t k, T alpha, const T *a, int64_t lda, - int64_t stridea, const T *b, int64_t ldb, int64_t strideb, T beta, - T *c, int64_t ldc, int64_t stridec, int64_t batch_size, - const std::vector &dependencies) { - using rocDataType = typename RocEquivalentType::Type; +template +inline sycl::event gemm_batch_strided_usm_impl(sycl::queue &queue, transpose transa, + transpose transb, int64_t m, int64_t n, int64_t k, + Ts alpha, const Ta *a, int64_t lda, int64_t stridea, + const Tb *b, int64_t ldb, int64_t strideb, Ts beta, + Tc *c, int64_t ldc, int64_t stridec, + int64_t batch_size, + const std::vector &dependencies) { + using rocTypeA = typename RocEquivalentType::Type; + using rocTypeB = typename RocEquivalentType::Type; + using rocTypeC = typename RocEquivalentType::Type; + using rocTypeS = typename RocEquivalentType::Type; overflow_check(m, n, k, lda, ldb, ldc, stridea, strideb, stridec, batch_size); + int32_t solution_index = 0; + rocblas_gemm_flags flags = rocblas_gemm_flags_none; auto done = queue.submit([&](sycl::handler &cgh) { cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); - auto a_ = reinterpret_cast(a); - auto b_ = reinterpret_cast(b); - auto c_ = reinterpret_cast(c); + auto a_ = reinterpret_cast(a); + auto b_ = reinterpret_cast(b); + auto c_ = reinterpret_cast(c); rocblas_status err; - ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_operation(transa), - get_rocblas_operation(transb), m, n, k, (rocDataType 
*)&alpha, - a_, lda, stridea, b_, ldb, strideb, (rocDataType *)&beta, c_, - ldc, stridec, batch_size); + ROCBLAS_ERROR_FUNC_SYNC(rocblas_gemm_strided_batched_ex, err, handle, + get_rocblas_operation(transa), get_rocblas_operation(transb), m, + n, k, &alpha, a_, get_rocblas_datatype(), lda, + stridea, b_, get_rocblas_datatype(), ldb, strideb, + &beta, c_, get_rocblas_datatype(), ldc, stridec, c_, + get_rocblas_datatype(), ldc, stridec, batch_size, + get_rocblas_datatype(), rocblas_gemm_algo_standard, + solution_index, flags); }); }); return done; } -#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ +#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ - int64_t n, int64_t k, TYPE alpha, const TYPE *a, int64_t lda, \ - int64_t stridea, const TYPE *b, int64_t ldb, int64_t strideb, \ - TYPE beta, TYPE *c, int64_t ldc, int64_t stridec, int64_t batch_size, \ - const std::vector &dependencies) { \ - return gemm_batch(ROCBLAS_ROUTINE, queue, transa, transb, m, n, k, alpha, a, lda, stridea, \ - b, ldb, strideb, beta, c, ldc, stridec, batch_size, dependencies); \ + int64_t n, int64_t k, TYPE_S alpha, const TYPE_A *a, int64_t lda, \ + int64_t stridea, const TYPE_B *b, int64_t ldb, int64_t strideb, \ + TYPE_S beta, TYPE_C *c, int64_t ldc, int64_t stridec, \ + int64_t batch_size, const std::vector &dependencies) { \ + return gemm_batch_strided_usm_impl(queue, transa, transb, m, n, k, alpha, a, lda, stridea, \ + b, ldb, strideb, beta, c, ldc, stridec, batch_size, \ + dependencies); \ } -GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, rocblas_hgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(float, rocblas_sgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(double, rocblas_dgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_cgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_zgemm_strided_batched) 
+GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, sycl::half) +GEMM_STRIDED_BATCH_LAUNCHER_USM(float, float, float, float) +GEMM_STRIDED_BATCH_LAUNCHER_USM(double, double, double, double) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) +GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float) #undef GEMM_STRIDED_BATCH_LAUNCHER_USM -template -inline sycl::event gemm_batch(Func func, sycl::queue &queue, transpose *transa, transpose *transb, - int64_t *m, int64_t *n, int64_t *k, T *alpha, const T **a, - int64_t *lda, const T **b, int64_t *ldb, T *beta, T **c, int64_t *ldc, - int64_t group_count, int64_t *group_size, - const std::vector &dependencies) { - using rocDataType = typename RocEquivalentType::Type; +#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ + int64_t n, int64_t k, TYPE_S alpha, const TYPE_A *a, int64_t lda, \ + int64_t stridea, const TYPE_B *b, int64_t ldb, int64_t strideb, \ + TYPE_S beta, TYPE_C *c, int64_t ldc, int64_t stridec, \ + int64_t batch_size, const std::vector &dependencies) { \ + throw unimplemented("blas", "gemm_batch", "for data type combination"); \ + } + +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float) + +#undef GEMM_STRIDED_BATCH_LAUNCHER_USM + +template +inline sycl::event gemm_batch_usm_impl(sycl::queue &queue, transpose *transa, transpose *transb, + int64_t *m, int64_t *n, int64_t *k, Ts *alpha, const Ta **a, + int64_t *lda, const Tb **b, int64_t *ldb, Ts *beta, Tc **c, + int64_t *ldc, int64_t group_count, int64_t *group_size, + const std::vector &dependencies) { + using rocTypeA = typename RocEquivalentType::Type; + using rocTypeB = typename 
RocEquivalentType::Type; + using rocTypeC = typename RocEquivalentType::Type; + using rocTypeS = typename RocEquivalentType::Type; for (int64_t i = 0; i < group_count; i++) { overflow_check(m[i], n[i], k[i], lda[i], ldb[i], ldc[i], group_size[i]); } + int32_t solution_index = 0; + rocblas_gemm_flags flags = rocblas_gemm_flags_none; auto done = queue.submit([&](sycl::handler &cgh) { cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { @@ -881,14 +944,18 @@ inline sycl::event gemm_batch(Func func, sycl::queue &queue, transpose *transa, int64_t offset = 0; rocblas_status err; for (int64_t i = 0; i < group_count; i++) { - auto **a_ = reinterpret_cast(a); - auto **b_ = reinterpret_cast(b); - auto **c_ = reinterpret_cast(c); + auto **a_ = reinterpret_cast(a); + auto **b_ = reinterpret_cast(b); + auto **c_ = reinterpret_cast(c); ROCBLAS_ERROR_FUNC_SYNC( - func, err, handle, get_rocblas_operation(transa[i]), - get_rocblas_operation(transb[i]), (int)m[i], (int)n[i], (int)k[i], - (rocDataType *)&alpha[i], a_ + offset, (int)lda[i], b_ + offset, (int)ldb[i], - (rocDataType *)&beta[i], c_ + offset, (int)ldc[i], (int)group_size[i]); + rocblas_gemm_batched_ex, err, handle, get_rocblas_operation(transa[i]), + get_rocblas_operation(transb[i]), (int)m[i], (int)n[i], (int)k[i], &alpha[i], + a_ + offset, get_rocblas_datatype(), (int)lda[i], b_ + offset, + get_rocblas_datatype(), (int)ldb[i], &beta[i], c_ + offset, + get_rocblas_datatype(), (int)ldc[i], c_ + offset, + get_rocblas_datatype(), (int)ldc[i], (int)group_size[i], + get_rocblas_datatype(), rocblas_gemm_algo_standard, solution_index, + flags); offset += group_size[i]; } }); @@ -897,21 +964,38 @@ inline sycl::event gemm_batch(Func func, sycl::queue &queue, transpose *transa, return done; } -#define GEMM_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ - sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ - int64_t *n, int64_t *k, TYPE 
*alpha, const TYPE **a, int64_t *lda, \ - const TYPE **b, int64_t *ldb, TYPE *beta, TYPE **c, int64_t *ldc, \ - int64_t group_count, int64_t *group_size, \ - const std::vector &dependencies) { \ - return gemm_batch(ROCBLAS_ROUTINE, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, \ - beta, c, ldc, group_count, group_size, dependencies); \ +#define GEMM_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ + int64_t *n, int64_t *k, TYPE_S *alpha, const TYPE_A **a, int64_t *lda, \ + const TYPE_B **b, int64_t *ldb, TYPE_S *beta, TYPE_C **c, int64_t *ldc, \ + int64_t group_count, int64_t *group_size, \ + const std::vector &dependencies) { \ + return gemm_batch_usm_impl(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, \ + ldc, group_count, group_size, dependencies); \ } -GEMM_BATCH_LAUNCHER_USM(sycl::half, rocblas_hgemm_batched) -GEMM_BATCH_LAUNCHER_USM(float, rocblas_sgemm_batched) -GEMM_BATCH_LAUNCHER_USM(double, rocblas_dgemm_batched) -GEMM_BATCH_LAUNCHER_USM(std::complex, rocblas_cgemm_batched) -GEMM_BATCH_LAUNCHER_USM(std::complex, rocblas_zgemm_batched) +GEMM_BATCH_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, sycl::half) +GEMM_BATCH_LAUNCHER_USM(float, float, float, float) +GEMM_BATCH_LAUNCHER_USM(double, double, double, double) +GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) +GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) +GEMM_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float) + +#undef GEMM_BATCH_LAUNCHER_USM + +#define GEMM_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ + int64_t *n, int64_t *k, TYPE_S *alpha, const TYPE_A **a, int64_t *lda, \ + const TYPE_B **b, int64_t *ldb, TYPE_S *beta, TYPE_C **c, int64_t *ldc, \ + int64_t group_count, int64_t *group_size, \ + 
const std::vector &dependencies) { \ + throw unimplemented("blas", "gemm_batch", "for data type combination"); \ + } + +GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) +GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float) #undef GEMM_BATCH_LAUNCHER_USM @@ -1442,32 +1526,52 @@ DGMM_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_zdgmm_strided_batched) #undef DGMM_STRIDED_BATCH_LAUNCHER -template -inline void gemm_batch(Func func, sycl::queue &queue, transpose transa, transpose transb, int64_t m, - int64_t n, int64_t k, T alpha, sycl::buffer &a, int64_t lda, - int64_t stridea, sycl::buffer &b, int64_t ldb, int64_t strideb, T beta, - sycl::buffer &c, int64_t ldc, int64_t stridec, int64_t batch_size) { +template +inline void gemm_batch_impl(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, int64_t k, Ts alpha, sycl::buffer &a, int64_t lda, + int64_t stridea, sycl::buffer &b, int64_t ldb, int64_t strideb, + Ts beta, sycl::buffer &c, int64_t ldc, int64_t stridec, + int64_t batch_size) { auto new_transa = transb; auto new_transb = transa; - column_major::gemm_batch(func, queue, new_transa, new_transb, n, m, k, alpha, b, ldb, strideb, - a, lda, stridea, beta, c, ldc, stridec, batch_size); + column_major::gemm_batch(queue, new_transa, new_transb, n, m, k, alpha, b, ldb, strideb, a, lda, + stridea, beta, c, ldc, stridec, batch_size); } -#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ - void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ - int64_t k, TYPE alpha, sycl::buffer &a, int64_t lda, int64_t stridea, \ - sycl::buffer &b, int64_t ldb, int64_t strideb, TYPE beta, \ - sycl::buffer &c, int64_t ldc, int64_t stridec, int64_t batch_size) { \ - gemm_batch(ROCBLAS_ROUTINE, queue, transa, transb, m, n, k, alpha, a, lda, stridea, b, \ - ldb, strideb, beta, c, ldc, stridec, batch_size); \ +#undef GEMM_STRIDED_BATCH_LAUNCHER +#define 
GEMM_STRIDED_BATCH_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ + int64_t k, TYPE_S alpha, sycl::buffer &a, int64_t lda, \ + int64_t stridea, sycl::buffer &b, int64_t ldb, int64_t strideb, \ + TYPE_S beta, sycl::buffer &c, int64_t ldc, int64_t stridec, \ + int64_t batch_size) { \ + gemm_batch_impl(queue, transa, transb, m, n, k, alpha, a, lda, stridea, b, ldb, strideb, \ + beta, c, ldc, stridec, batch_size); \ + } + +GEMM_STRIDED_BATCH_LAUNCHER(float, float, float, float) +GEMM_STRIDED_BATCH_LAUNCHER(double, double, double, double) +GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, std::complex, + std::complex) +GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, std::complex, + std::complex) +GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, sycl::half, sycl::half, sycl::half) +GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, sycl::half, float, float) + +#undef GEMM_STRIDED_BATCH_LAUNCHER + +#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ + int64_t k, TYPE_S alpha, sycl::buffer &a, int64_t lda, \ + int64_t stridea, sycl::buffer &b, int64_t ldb, int64_t strideb, \ + TYPE_S beta, sycl::buffer &c, int64_t ldc, int64_t stridec, \ + int64_t batch_size) { \ + throw unimplemented("blas", "gemm_batch", "for data type combination"); \ } -GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, rocblas_hgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER(float, rocblas_sgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER(double, rocblas_dgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_cgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_zgemm_strided_batched) +GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, float, float) +GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, std::int32_t, float) #undef GEMM_STRIDED_BATCH_LAUNCHER @@ 
-1936,67 +2040,104 @@ DGMM_BATCH_LAUNCHER_USM(std::complex, rocblas_zdgmm_batched) #undef DGMM_BATCH_LAUNCHER -template -inline sycl::event gemm_batch(Func func, sycl::queue &queue, transpose transa, transpose transb, - int64_t m, int64_t n, int64_t k, T alpha, const T *a, int64_t lda, - int64_t stridea, const T *b, int64_t ldb, int64_t strideb, T beta, - T *c, int64_t ldc, int64_t stridec, int64_t batch_size, - const std::vector &dependencies) { +template +inline sycl::event gemm_batch_strided_usm_impl(sycl::queue &queue, transpose transa, + transpose transb, int64_t m, int64_t n, int64_t k, + Ts alpha, const Ta *a, int64_t lda, int64_t stridea, + const Tb *b, int64_t ldb, int64_t strideb, Ts beta, + Tc *c, int64_t ldc, int64_t stridec, + int64_t batch_size, + const std::vector &dependencies) { auto new_transa = transb; auto new_transb = transa; - return column_major::gemm_batch(func, queue, new_transa, new_transb, n, m, k, alpha, b, ldb, - strideb, a, lda, stridea, beta, c, ldc, stridec, batch_size, + return column_major::gemm_batch(queue, new_transa, new_transb, n, m, k, alpha, b, ldb, strideb, + a, lda, stridea, beta, c, ldc, stridec, batch_size, dependencies); } -#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ +#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ - int64_t n, int64_t k, TYPE alpha, const TYPE *a, int64_t lda, \ - int64_t stridea, const TYPE *b, int64_t ldb, int64_t strideb, \ - TYPE beta, TYPE *c, int64_t ldc, int64_t stridec, int64_t batch_size, \ - const std::vector &dependencies) { \ - return gemm_batch(ROCBLAS_ROUTINE, queue, transa, transb, m, n, k, alpha, a, lda, stridea, \ - b, ldb, strideb, beta, c, ldc, stridec, batch_size, dependencies); \ + int64_t n, int64_t k, TYPE_S alpha, const TYPE_A *a, int64_t lda, \ + int64_t stridea, const TYPE_B *b, int64_t ldb, int64_t strideb, \ + TYPE_S beta, TYPE_C *c, 
int64_t ldc, int64_t stridec, \ + int64_t batch_size, const std::vector &dependencies) { \ + return gemm_batch_strided_usm_impl(queue, transa, transb, m, n, k, alpha, a, lda, stridea, \ + b, ldb, strideb, beta, c, ldc, stridec, batch_size, \ + dependencies); \ } -GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, rocblas_hgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(float, rocblas_sgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(double, rocblas_dgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_cgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_zgemm_strided_batched) +GEMM_STRIDED_BATCH_LAUNCHER_USM(float, float, float, float) +GEMM_STRIDED_BATCH_LAUNCHER_USM(double, double, double, double) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) +GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, sycl::half) +GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float) #undef GEMM_STRIDED_BATCH_LAUNCHER_USM -template -inline sycl::event gemm_batch(Func func, sycl::queue &queue, transpose *transa, transpose *transb, - int64_t *m, int64_t *n, int64_t *k, T *alpha, const T **a, - int64_t *lda, const T **b, int64_t *ldb, T *beta, T **c, int64_t *ldc, - int64_t group_count, int64_t *group_size, - const std::vector &dependencies) { +#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ + int64_t n, int64_t k, TYPE_S alpha, const TYPE_A *a, int64_t lda, \ + int64_t stridea, const TYPE_B *b, int64_t ldb, int64_t strideb, \ + TYPE_S beta, TYPE_C *c, int64_t ldc, int64_t stridec, \ + int64_t batch_size, const std::vector &dependencies) { \ + throw unimplemented("blas", "gemm_batch", "for data type combination"); \ + } + 
+GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float) + +#undef GEMM_STRIDED_BATCH_LAUNCHER_USM + +template +inline sycl::event gemm_batch_usm_impl(sycl::queue &queue, transpose *transa, transpose *transb, + int64_t *m, int64_t *n, int64_t *k, Ts *alpha, const Ta **a, + int64_t *lda, const Tb **b, int64_t *ldb, Ts *beta, Tc **c, + int64_t *ldc, int64_t group_count, int64_t *group_size, + const std::vector &dependencies) { for (int64_t i = 0; i < group_count; i++) { std::swap(transa[i], transb[i]); } - return column_major::gemm_batch(func, queue, transa, transb, n, m, k, alpha, b, ldb, a, lda, - beta, c, ldc, group_count, group_size, dependencies); + return column_major::gemm_batch(queue, transa, transb, n, m, k, alpha, b, ldb, a, lda, beta, c, + ldc, group_count, group_size, dependencies); } -#define GEMM_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ - sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ - int64_t *n, int64_t *k, TYPE *alpha, const TYPE **a, int64_t *lda, \ - const TYPE **b, int64_t *ldb, TYPE *beta, TYPE **c, int64_t *ldc, \ - int64_t group_count, int64_t *group_size, \ - const std::vector &dependencies) { \ - return gemm_batch(ROCBLAS_ROUTINE, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, \ - beta, c, ldc, group_count, group_size, dependencies); \ +#define GEMM_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ + int64_t *n, int64_t *k, TYPE_S *alpha, const TYPE_A **a, int64_t *lda, \ + const TYPE_B **b, int64_t *ldb, TYPE_S *beta, TYPE_C **c, int64_t *ldc, \ + int64_t group_count, int64_t *group_size, \ + const std::vector &dependencies) { \ + return gemm_batch_usm_impl(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, \ + ldc, group_count, group_size, dependencies); \ + } + 
+GEMM_BATCH_LAUNCHER_USM(float, float, float, float) +GEMM_BATCH_LAUNCHER_USM(double, double, double, double) +GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) +GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) +GEMM_BATCH_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, sycl::half) +GEMM_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float) + +#undef GEMM_BATCH_LAUNCHER_USM + +#define GEMM_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ + int64_t *n, int64_t *k, TYPE_S *alpha, const TYPE_A **a, int64_t *lda, \ + const TYPE_B **b, int64_t *ldb, TYPE_S *beta, TYPE_C **c, int64_t *ldc, \ + int64_t group_count, int64_t *group_size, \ + const std::vector &dependencies) { \ + throw unimplemented("blas", "gemm_batch", "for data type combination"); \ } -GEMM_BATCH_LAUNCHER_USM(sycl::half, rocblas_hgemm_batched) -GEMM_BATCH_LAUNCHER_USM(float, rocblas_sgemm_batched) -GEMM_BATCH_LAUNCHER_USM(double, rocblas_dgemm_batched) -GEMM_BATCH_LAUNCHER_USM(std::complex, rocblas_cgemm_batched) -GEMM_BATCH_LAUNCHER_USM(std::complex, rocblas_zgemm_batched) +GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) +GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float) #undef GEMM_BATCH_LAUNCHER_USM diff --git a/src/blas/backends/rocblas/rocblas_helper.hpp b/src/blas/backends/rocblas/rocblas_helper.hpp index eeeb5a11c..ad3544200 100644 --- a/src/blas/backends/rocblas/rocblas_helper.hpp +++ b/src/blas/backends/rocblas/rocblas_helper.hpp @@ -205,6 +205,66 @@ inline rocblas_side get_rocblas_side_mode(oneapi::mkl::side lr) { } } +template +inline rocblas_datatype get_rocblas_datatype() { + static_assert(false); +} + +template <> +inline rocblas_datatype get_rocblas_datatype() { + return rocblas_datatype_f16_r; +} + +template <> +inline rocblas_datatype get_rocblas_datatype() { + return 
rocblas_datatype_f32_r; +} + +template <> +inline rocblas_datatype get_rocblas_datatype() { + return rocblas_datatype_f64_r; +} + +template <> +inline rocblas_datatype get_rocblas_datatype() { + return rocblas_datatype_f32_c; +} + +template <> +inline rocblas_datatype get_rocblas_datatype() { + return rocblas_datatype_f64_c; +} + +template <> +inline rocblas_datatype get_rocblas_datatype() { + return rocblas_datatype_i8_r; +} + +template <> +inline rocblas_datatype get_rocblas_datatype() { + return rocblas_datatype_u8_r; +} + +template <> +inline rocblas_datatype get_rocblas_datatype() { + return rocblas_datatype_i32_r; +} + +template <> +inline rocblas_datatype get_rocblas_datatype() { + return rocblas_datatype_u32_r; +} + +template <> +inline rocblas_datatype get_rocblas_datatype() { + return rocblas_datatype_bf16_r; +} + +template <> +inline rocblas_datatype get_rocblas_datatype>() { + return rocblas_datatype_bf16_c; +} + /*converting std::complex to roc__complex sycl::half to rocblas_half*/ template diff --git a/src/blas/backends/rocblas/rocblas_wrappers.cpp b/src/blas/backends/rocblas/rocblas_wrappers.cpp index 87fc78b86..ce4c92da5 100644 --- a/src/blas/backends/rocblas/rocblas_wrappers.cpp +++ b/src/blas/backends/rocblas/rocblas_wrappers.cpp @@ -207,6 +207,9 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::rocblas::column_major::gemm_batch, oneapi::mkl::blas::rocblas::column_major::gemm_batch, oneapi::mkl::blas::rocblas::column_major::gemm_batch, + oneapi::mkl::blas::rocblas::column_major::gemm_batch, + oneapi::mkl::blas::rocblas::column_major::gemm_batch, + oneapi::mkl::blas::rocblas::column_major::gemm_batch, oneapi::mkl::blas::rocblas::column_major::trsm_batch, oneapi::mkl::blas::rocblas::column_major::trsm_batch, oneapi::mkl::blas::rocblas::column_major::trsm_batch, @@ -462,6 +465,12 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::rocblas::column_major::gemm_batch, 
oneapi::mkl::blas::rocblas::column_major::gemm_batch, oneapi::mkl::blas::rocblas::column_major::gemm_batch, + oneapi::mkl::blas::rocblas::column_major::gemm_batch, + oneapi::mkl::blas::rocblas::column_major::gemm_batch, + oneapi::mkl::blas::rocblas::column_major::gemm_batch, + oneapi::mkl::blas::rocblas::column_major::gemm_batch, + oneapi::mkl::blas::rocblas::column_major::gemm_batch, + oneapi::mkl::blas::rocblas::column_major::gemm_batch, oneapi::mkl::blas::rocblas::column_major::gemmt, oneapi::mkl::blas::rocblas::column_major::gemmt, oneapi::mkl::blas::rocblas::column_major::gemmt, @@ -688,6 +697,9 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::rocblas::row_major::gemm_batch, oneapi::mkl::blas::rocblas::row_major::gemm_batch, oneapi::mkl::blas::rocblas::row_major::gemm_batch, + oneapi::mkl::blas::rocblas::row_major::gemm_batch, + oneapi::mkl::blas::rocblas::row_major::gemm_batch, + oneapi::mkl::blas::rocblas::row_major::gemm_batch, oneapi::mkl::blas::rocblas::row_major::trsm_batch, oneapi::mkl::blas::rocblas::row_major::trsm_batch, oneapi::mkl::blas::rocblas::row_major::trsm_batch, @@ -943,6 +955,12 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::rocblas::row_major::gemm_batch, oneapi::mkl::blas::rocblas::row_major::gemm_batch, oneapi::mkl::blas::rocblas::row_major::gemm_batch, + oneapi::mkl::blas::rocblas::row_major::gemm_batch, + oneapi::mkl::blas::rocblas::row_major::gemm_batch, + oneapi::mkl::blas::rocblas::row_major::gemm_batch, + oneapi::mkl::blas::rocblas::row_major::gemm_batch, + oneapi::mkl::blas::rocblas::row_major::gemm_batch, + oneapi::mkl::blas::rocblas::row_major::gemm_batch, oneapi::mkl::blas::rocblas::row_major::gemmt, oneapi::mkl::blas::rocblas::row_major::gemmt, oneapi::mkl::blas::rocblas::row_major::gemmt, From 02af09c22bfe29786d62157e485f29f401e79bc9 Mon Sep 17 00:00:00 2001 From: Aidan Date: Tue, 2 Apr 2024 11:16:44 +0100 Subject: [PATCH 05/30] Add new gemm_batch dtypes to mklcpu/gpu --- 
.../oneapi/mkl/blas/detail/mklcpu/blas_ct.hxx | 105 ++++++++++++++++++ .../oneapi/mkl/blas/detail/mklgpu/blas_ct.hxx | 105 ++++++++++++++++++ src/blas/backends/mkl_common/mkl_batch.cxx | 81 ++++++++++++++ 3 files changed, 291 insertions(+) diff --git a/include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hxx b/include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hxx index 004a4c11c..1724bf5c7 100644 --- a/include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hxx +++ b/include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hxx @@ -188,6 +188,39 @@ void gemm_batch(backend_selector selector, transpose transa, tr ldc, stride_c, batch_size); } +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + oneapi::mkl::blas::mklcpu::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, + ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + oneapi::mkl::blas::mklcpu::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, + ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer 
&c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { + oneapi::mkl::blas::mklcpu::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, + ldc, stride_c, batch_size); +} + void syrk(backend_selector selector, uplo upper_lower, transpose trans, std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, std::int64_t lda, float beta, sycl::buffer &c, std::int64_t ldc) { @@ -2672,6 +2705,42 @@ sycl::event gemm_batch(backend_selector selector, transpose *tr return done; } +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const sycl::half **a, std::int64_t *lda, const sycl::half **b, + std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklcpu::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklcpu::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, std::int32_t 
**c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklcpu::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + sycl::event gemm_batch(backend_selector selector, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, std::int64_t stride_a, @@ -2739,6 +2808,42 @@ sycl::event gemm_batch(backend_selector selector, transpose tra return done; } +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const sycl::half *a, std::int64_t lda, std::int64_t stride_a, + const sycl::half *b, std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklcpu::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklcpu::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, 
std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, + std::int32_t *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklcpu::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + sycl::event spmv(backend_selector selector, uplo upper_lower, std::int64_t n, float alpha, const float *a, const float *x, std::int64_t incx, float beta, float *y, std::int64_t incy, diff --git a/include/oneapi/mkl/blas/detail/mklgpu/blas_ct.hxx b/include/oneapi/mkl/blas/detail/mklgpu/blas_ct.hxx index d365a39c4..c69257e9c 100644 --- a/include/oneapi/mkl/blas/detail/mklgpu/blas_ct.hxx +++ b/include/oneapi/mkl/blas/detail/mklgpu/blas_ct.hxx @@ -188,6 +188,39 @@ void gemm_batch(backend_selector selector, transpose transa, tr ldc, stride_c, batch_size); } +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + oneapi::mkl::blas::mklgpu::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, + ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t 
batch_size) { + oneapi::mkl::blas::mklgpu::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, + ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { + oneapi::mkl::blas::mklgpu::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, + ldc, stride_c, batch_size); +} + void syrk(backend_selector selector, uplo upper_lower, transpose trans, std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, std::int64_t lda, float beta, sycl::buffer &c, std::int64_t ldc) { @@ -2622,6 +2655,42 @@ sycl::event gemm_batch(backend_selector selector, transpose *tr return done; } +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const sycl::half **a, std::int64_t *lda, const sycl::half **b, + std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklgpu::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector 
&dependencies) { + auto done = oneapi::mkl::blas::mklgpu::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, std::int32_t **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklgpu::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + sycl::event gemm_batch(backend_selector selector, transpose *transa, transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, float *alpha, const float **a, std::int64_t *lda, const float **b, @@ -2685,6 +2754,42 @@ sycl::event gemm_batch(backend_selector selector, transpose tra return done; } +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const sycl::half *a, std::int64_t lda, std::int64_t stride_a, + const sycl::half *b, std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklgpu::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, 
std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklgpu::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, + std::int32_t *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklgpu::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + sycl::event gemm_batch(backend_selector selector, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, std::int64_t stride_a, diff --git a/src/blas/backends/mkl_common/mkl_batch.cxx b/src/blas/backends/mkl_common/mkl_batch.cxx index 0a204d5b7..412d3c990 100644 --- a/src/blas/backends/mkl_common/mkl_batch.cxx +++ b/src/blas/backends/mkl_common/mkl_batch.cxx @@ -182,6 +182,33 @@ void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t stride_b, beta, c, ldc, stride_c, batch_size); } +void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, + float beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + blas_major::gemm_batch(queue, transa, transb, m, n, k, alpha, a, lda, 
stride_a, b, ldb, + stride_b, beta, c, ldc, stride_c, batch_size); +} + +void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, + float beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + blas_major::gemm_batch(queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, + stride_b, beta, c, ldc, stride_c, batch_size); +} + +void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, + float beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + blas_major::gemm_batch(queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, + stride_b, beta, c, ldc, stride_c, batch_size); +} + void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, diag unit_diag, int64_t m, int64_t n, float alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, sycl::buffer &b, int64_t ldb, @@ -642,6 +669,33 @@ sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, i stride_b, beta, c, ldc, stride_c, batch_size, dependencies); } +sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, const sycl::half *a, int64_t lda, int64_t stride_a, + const sycl::half *b, int64_t ldb, int64_t stride_b, float beta, float *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { + return blas_major::gemm_batch(queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, + stride_b, beta, c, ldc, stride_c, batch_size, dependencies); +} + +sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, const 
std::int8_t *a, int64_t lda, int64_t stride_a, + const std::int8_t *b, int64_t ldb, int64_t stride_b, float beta, float *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { + return blas_major::gemm_batch(queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, + stride_b, beta, c, ldc, stride_c, batch_size, dependencies); +} + +sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, const std::int8_t *a, int64_t lda, int64_t stride_a, + const std::int8_t *b, int64_t ldb, int64_t stride_b, float beta, + std::int32_t *c, int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { + return blas_major::gemm_batch(queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, + stride_b, beta, c, ldc, stride_c, batch_size, dependencies); +} + sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, int64_t *n, int64_t *k, float *alpha, const float **a, int64_t *lda, const float **b, int64_t *ldb, float *beta, float **c, int64_t *ldc, @@ -689,6 +743,33 @@ sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, ldc, group_count, groupsize, dependencies); } +sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, + int64_t *n, int64_t *k, float *alpha, const sycl::half **a, int64_t *lda, + const sycl::half **b, int64_t *ldb, float *beta, float **c, int64_t *ldc, + int64_t group_count, int64_t *groupsize, + const std::vector &dependencies) { + return blas_major::gemm_batch(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, + ldc, group_count, groupsize, dependencies); +} + +sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, + int64_t *n, int64_t *k, float *alpha, const std::int8_t **a, int64_t *lda, + const std::int8_t **b, int64_t *ldb, float *beta, float **c, int64_t *ldc, 
+ int64_t group_count, int64_t *groupsize, + const std::vector &dependencies) { + return blas_major::gemm_batch(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, + ldc, group_count, groupsize, dependencies); +} + +sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, + int64_t *n, int64_t *k, float *alpha, const std::int8_t **a, int64_t *lda, + const std::int8_t **b, int64_t *ldb, float *beta, std::int32_t **c, + int64_t *ldc, int64_t group_count, int64_t *groupsize, + const std::vector &dependencies) { + return blas_major::gemm_batch(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, + ldc, group_count, groupsize, dependencies); +} + sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, diag unit_diag, int64_t m, int64_t n, float alpha, const float *a, int64_t lda, int64_t stride_a, float *b, int64_t ldb, int64_t stride_b, From 76d93d43b800aa69bcd130e484ae846b603c3e01 Mon Sep 17 00:00:00 2001 From: Aidan Date: Tue, 2 Apr 2024 11:17:41 +0100 Subject: [PATCH 06/30] Add gemm_batch dtypes (unimplemented) --- .../mkl/blas/detail/portblas/blas_ct.hxx | 105 ++++++++++++++++++ src/blas/backends/portblas/portblas_batch.cxx | 84 ++++++++++++++ 2 files changed, 189 insertions(+) diff --git a/include/oneapi/mkl/blas/detail/portblas/blas_ct.hxx b/include/oneapi/mkl/blas/detail/portblas/blas_ct.hxx index 2a092a61b..8a66ed707 100644 --- a/include/oneapi/mkl/blas/detail/portblas/blas_ct.hxx +++ b/include/oneapi/mkl/blas/detail/portblas/blas_ct.hxx @@ -187,6 +187,39 @@ void gemm_batch(backend_selector selector, transpose transa, c, ldc, stride_c, batch_size); } +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, 
std::int64_t stride_c, + std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_size); +} + void syrk(backend_selector selector, uplo upper_lower, transpose trans, std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, std::int64_t lda, float beta, sycl::buffer &c, std::int64_t ldc) { @@ -2576,6 +2609,42 @@ sycl::event gemm_batch(backend_selector selector, transpose * return done; } +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const sycl::half **a, std::int64_t *lda, const sycl::half **b, + std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector 
&dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, std::int32_t **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + sycl::event gemm_batch(backend_selector selector, transpose *transa, transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, float *alpha, const float **a, std::int64_t *lda, const float **b, @@ -2638,6 +2707,42 @@ sycl::event gemm_batch(backend_selector selector, transpose t return done; } +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const sycl::half *a, std::int64_t lda, std::int64_t stride_a, + const sycl::half *b, std::int64_t ldb, std::int64_t stride_b, float 
beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, + std::int32_t *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + sycl::event gemm_batch(backend_selector selector, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, std::int64_t stride_a, diff --git a/src/blas/backends/portblas/portblas_batch.cxx b/src/blas/backends/portblas/portblas_batch.cxx index 581fcd2e5..cf49f0dea 100644 --- 
a/src/blas/backends/portblas/portblas_batch.cxx +++ b/src/blas/backends/portblas/portblas_batch.cxx @@ -213,6 +213,33 @@ void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl:: throw unimplemented("blas", "gemm_batch", " for complex"); } +void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + throw unimplemented("blas", "gemm_batch", " for unsupporeted dtype"); +} + +void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + throw unimplemented("blas", "gemm_batch", " for unsupporeted dtype"); +} + +void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { + throw unimplemented("blas", "gemm_batch", " for unsupporeted dtype"); +} + void trsm_batch(sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, oneapi::mkl::diag unit_diag, std::int64_t m, std::int64_t n, float alpha, sycl::buffer &a, std::int64_t lda, @@ -700,6 +727,33 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, throw 
unimplemented("blas", "gemm_batch", " for USM"); } +sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, + oneapi::mkl::transpose *transb, std::int64_t *m, std::int64_t *n, + std::int64_t *k, float *alpha, const sycl::half **a, std::int64_t *lda, + const sycl::half **b, std::int64_t *ldb, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + throw unimplemented("blas", "gemm_batch", " for USM"); +} + +sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, + oneapi::mkl::transpose *transb, std::int64_t *m, std::int64_t *n, + std::int64_t *k, float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + throw unimplemented("blas", "gemm_batch", " for USM"); +} + +sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, + oneapi::mkl::transpose *transb, std::int64_t *m, std::int64_t *n, + std::int64_t *k, float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, std::int32_t **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + throw unimplemented("blas", "gemm_batch", " for USM"); +} + sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, @@ -754,6 +808,36 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, throw unimplemented("blas", "gemm_batch", " for USM"); } +sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, const sycl::half *a, std::int64_t lda, + 
std::int64_t stride_a, const sycl::half *b, std::int64_t ldb, + std::int64_t stride_b, float beta, float *c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "gemm_batch", " for USM"); +} + +sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, const std::int8_t *a, std::int64_t lda, + std::int64_t stride_a, const std::int8_t *b, std::int64_t ldb, + std::int64_t stride_b, float beta, float *c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "gemm_batch", " for USM"); +} + +sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, const std::int8_t *a, std::int64_t lda, + std::int64_t stride_a, const std::int8_t *b, std::int64_t ldb, + std::int64_t stride_b, float beta, std::int32_t *c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "gemm_batch", " for USM"); +} + sycl::event trsm_batch(sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, oneapi::mkl::diag unit_diag, std::int64_t m, std::int64_t n, float alpha, From 07a727bd634817df8d77867b1674a61302d152f7 Mon Sep 17 00:00:00 2001 From: Aidan Date: Tue, 2 Apr 2024 11:18:14 +0100 Subject: [PATCH 07/30] Add gemm_batch dtypes to netlib (unimplemented) --- .../oneapi/mkl/blas/detail/netlib/blas_ct.hxx | 105 ++++++++++++++++ src/blas/backends/netlib/netlib_batch.cxx | 117 ++++++++++++++++++ 2 files changed, 222 insertions(+) diff --git a/include/oneapi/mkl/blas/detail/netlib/blas_ct.hxx b/include/oneapi/mkl/blas/detail/netlib/blas_ct.hxx index fe5b56b48..404d79ae0 100644 --- 
a/include/oneapi/mkl/blas/detail/netlib/blas_ct.hxx +++ b/include/oneapi/mkl/blas/detail/netlib/blas_ct.hxx @@ -188,6 +188,39 @@ void gemm_batch(backend_selector selector, transpose transa, tr ldc, stride_c, batch_size); } +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + oneapi::mkl::blas::netlib::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, + ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + oneapi::mkl::blas::netlib::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, + ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { + oneapi::mkl::blas::netlib::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, + ldc, stride_c, batch_size); +} + void syrk(backend_selector selector, uplo upper_lower, transpose trans, std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, 
std::int64_t lda, float beta, sycl::buffer &c, std::int64_t ldc) { @@ -2672,6 +2705,42 @@ sycl::event gemm_batch(backend_selector selector, transpose *tr return done; } +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const sycl::half **a, std::int64_t *lda, const sycl::half **b, + std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::netlib::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::netlib::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, std::int32_t **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::netlib::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + sycl::event gemm_batch(backend_selector selector, transpose transa, 
transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, std::int64_t stride_a, @@ -2739,6 +2808,42 @@ sycl::event gemm_batch(backend_selector selector, transpose tra return done; } +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const sycl::half *a, std::int64_t lda, std::int64_t stride_a, + const sycl::half *b, std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::netlib::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::netlib::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, + std::int32_t *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::netlib::MAJOR::gemm_batch( + selector.get_queue(), 
transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + sycl::event spmv(backend_selector selector, uplo upper_lower, std::int64_t n, float alpha, const float *a, const float *x, std::int64_t incx, float beta, float *y, std::int64_t incy, diff --git a/src/blas/backends/netlib/netlib_batch.cxx b/src/blas/backends/netlib/netlib_batch.cxx index a029a60bc..7a2839dd4 100644 --- a/src/blas/backends/netlib/netlib_batch.cxx +++ b/src/blas/backends/netlib/netlib_batch.cxx @@ -279,6 +279,45 @@ void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t #endif } +void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, + float beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "gemm_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "gemm_batch", "for row_major layout"); +#endif +} + +void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, + float beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "gemm_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "gemm_batch", "for row_major layout"); +#endif +} + +void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, + float beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { +#ifdef 
COLUMN_MAJOR + throw unimplemented("blas", "gemm_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "gemm_batch", "for row_major layout"); +#endif +} + void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, diag unit_diag, int64_t m, int64_t n, float alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, sycl::buffer &b, int64_t ldb, @@ -983,6 +1022,45 @@ sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, #endif } +sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, + int64_t *n, int64_t *k, float *alpha, const sycl::half **a, int64_t *lda, + const sycl::half **b, int64_t *ldb, float *beta, float **c, int64_t *ldc, + int64_t group_count, int64_t *groupsize, + const std::vector &dependencies) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "gemm_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "gemm_batch", "for row_major layout"); +#endif +} + +sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, + int64_t *n, int64_t *k, float *alpha, const std::int8_t **a, int64_t *lda, + const std::int8_t **b, int64_t *ldb, float *beta, float **c, int64_t *ldc, + int64_t group_count, int64_t *groupsize, + const std::vector &dependencies) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "gemm_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "gemm_batch", "for row_major layout"); +#endif +} + +sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, + int64_t *n, int64_t *k, float *alpha, const std::int8_t **a, int64_t *lda, + const std::int8_t **b, int64_t *ldb, float *beta, std::int32_t **c, + int64_t *ldc, int64_t group_count, int64_t *groupsize, + const std::vector &dependencies) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "gemm_batch", "for column_major 
layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "gemm_batch", "for row_major layout"); +#endif +} + sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, int64_t k, float alpha, const float *a, int64_t lda, int64_t stride_a, const float *b, int64_t ldb, int64_t stride_b, @@ -1052,6 +1130,45 @@ sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, i #endif } +sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, const sycl::half *a, int64_t lda, int64_t stride_a, + const sycl::half *b, int64_t ldb, int64_t stride_b, float beta, float *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "gemm_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "gemm_batch", "for row_major layout"); +#endif +} + +sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, const std::int8_t *a, int64_t lda, int64_t stride_a, + const std::int8_t *b, int64_t ldb, int64_t stride_b, float beta, float *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "gemm_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "gemm_batch", "for row_major layout"); +#endif +} + +sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, const std::int8_t *a, int64_t lda, int64_t stride_a, + const std::int8_t *b, int64_t ldb, int64_t stride_b, float beta, + std::int32_t *c, int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "gemm_batch", "for column_major layout"); 
+#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "gemm_batch", "for row_major layout"); +#endif +} + sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, diag unit_diag, int64_t m, int64_t n, float alpha, const float *a, int64_t lda, int64_t stride_a, float *b, int64_t ldb, From 58a667a3448bd9bb33b9f99a00f8154b28aae103 Mon Sep 17 00:00:00 2001 From: Aidan Date: Mon, 8 Apr 2024 12:00:45 +0100 Subject: [PATCH 08/30] Fix typo --- src/blas/backends/cublas/cublas_helper.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/blas/backends/cublas/cublas_helper.hpp b/src/blas/backends/cublas/cublas_helper.hpp index 783b8c58f..1380fed4f 100644 --- a/src/blas/backends/cublas/cublas_helper.hpp +++ b/src/blas/backends/cublas/cublas_helper.hpp @@ -278,7 +278,7 @@ inline cudaDataType_t get_cublas_datatype() { template <> inline cudaDataType_t get_cublas_datatype() { - return CUDA_R_8U; + return CUDA_R_32U; } /*converting std::complex to cuComplex*/ From 9aac33bc947a252216af908d3f4f72f6ec955f37 Mon Sep 17 00:00:00 2001 From: Aidan Date: Mon, 8 Apr 2024 12:01:07 +0100 Subject: [PATCH 09/30] Fix spelling --- src/blas/backends/portblas/portblas_batch.cxx | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/blas/backends/portblas/portblas_batch.cxx b/src/blas/backends/portblas/portblas_batch.cxx index cf49f0dea..28c7ee5dc 100644 --- a/src/blas/backends/portblas/portblas_batch.cxx +++ b/src/blas/backends/portblas/portblas_batch.cxx @@ -219,7 +219,7 @@ void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl:: sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { - throw unimplemented("blas", "gemm_batch", " for unsupporeted dtype"); + throw unimplemented("blas", "gemm_batch", " for unsupported dtype"); } void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, 
oneapi::mkl::transpose transb, @@ -228,7 +228,7 @@ void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl:: sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { - throw unimplemented("blas", "gemm_batch", " for unsupporeted dtype"); + throw unimplemented("blas", "gemm_batch", " for unsupported dtype"); } void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, @@ -237,7 +237,7 @@ void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl:: sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { - throw unimplemented("blas", "gemm_batch", " for unsupporeted dtype"); + throw unimplemented("blas", "gemm_batch", " for unsupported dtype"); } void trsm_batch(sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, From f468f2bb15e176da5dd1d5e6e00f8e7373d35a76 Mon Sep 17 00:00:00 2001 From: Aidan Date: Mon, 8 Apr 2024 12:01:52 +0100 Subject: [PATCH 10/30] Change naming convention --- src/blas/blas_loader.cpp | 36 ++++++++++++++++++------------------ src/blas/function_table.hpp | 36 ++++++++++++++++++------------------ 2 files changed, 36 insertions(+), 36 deletions(-) diff --git a/src/blas/blas_loader.cpp b/src/blas/blas_loader.cpp index 9022900fc..c1f1339c6 100644 --- a/src/blas/blas_loader.cpp +++ b/src/blas/blas_loader.cpp @@ -1348,7 +1348,7 @@ void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { - function_tables[libkey].column_major_hsgemm_batch_strided_sycl( + function_tables[libkey].column_major_gemm_f16f16f32_batch_strided_sycl( queue, transa, transb, m, n, k, alpha, a, lda, 
stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size); } @@ -1359,7 +1359,7 @@ void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { - function_tables[libkey].column_major_isgemm_batch_strided_sycl( + function_tables[libkey].column_major_gemm_s8s8f32_batch_strided_sycl( queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size); } @@ -1370,7 +1370,7 @@ void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { - function_tables[libkey].column_major_iigemm_batch_strided_sycl( + function_tables[libkey].column_major_gemm_s8s8s32_batch_strided_sycl( queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size); } @@ -3444,7 +3444,7 @@ sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies) { - return function_tables[libkey].column_major_hsgemm_batch_group_usm_sycl( + return function_tables[libkey].column_major_gemm_f16f16f32_batch_group_usm_sycl( queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, group_count, group_size, dependencies); } @@ -3455,7 +3455,7 @@ sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose const std::int8_t **b, std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies) { - return function_tables[libkey].column_major_isgemm_batch_group_usm_sycl( + return 
function_tables[libkey].column_major_gemm_s8s8f32_batch_group_usm_sycl( queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, group_count, group_size, dependencies); } @@ -3466,7 +3466,7 @@ sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose const std::int8_t **b, std::int64_t *ldb, float *beta, std::int32_t **c, std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies) { - return function_tables[libkey].column_major_iigemm_batch_group_usm_sycl( + return function_tables[libkey].column_major_gemm_s8s8s32_batch_group_usm_sycl( queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, group_count, group_size, dependencies); } @@ -3535,7 +3535,7 @@ sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose const sycl::half *b, std::int64_t ldb, std::int64_t stride_b, float beta, float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies) { - return function_tables[libkey].column_major_hsgemm_batch_strided_usm_sycl( + return function_tables[libkey].column_major_gemm_f16f16f32_batch_strided_usm_sycl( queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size, dependencies); } @@ -3546,7 +3546,7 @@ sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies) { - return function_tables[libkey].column_major_isgemm_batch_strided_usm_sycl( + return function_tables[libkey].column_major_gemm_s8s8f32_batch_strided_usm_sycl( queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size, dependencies); } @@ -3557,7 +3557,7 @@ sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose const 
std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, std::int32_t *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies) { - return function_tables[libkey].column_major_iigemm_batch_strided_usm_sycl( + return function_tables[libkey].column_major_gemm_s8s8s32_batch_strided_usm_sycl( queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size, dependencies); } @@ -5282,7 +5282,7 @@ void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { - function_tables[libkey].row_major_hsgemm_batch_strided_sycl( + function_tables[libkey].row_major_gemm_f16f16f32_batch_strided_sycl( queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size); } @@ -5293,7 +5293,7 @@ void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { - function_tables[libkey].row_major_isgemm_batch_strided_sycl( + function_tables[libkey].row_major_gemm_s8s8f32_batch_strided_sycl( queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size); } @@ -5304,7 +5304,7 @@ void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { - function_tables[libkey].row_major_iigemm_batch_strided_sycl( + function_tables[libkey].row_major_gemm_s8s8s32_batch_strided_sycl( queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size); } @@ -7374,7 
+7374,7 @@ sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies) { - return function_tables[libkey].row_major_hsgemm_batch_group_usm_sycl( + return function_tables[libkey].row_major_gemm_f16f16f32_batch_group_usm_sycl( queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, group_count, group_size, dependencies); } @@ -7385,7 +7385,7 @@ sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose const std::int8_t **b, std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies) { - return function_tables[libkey].row_major_isgemm_batch_group_usm_sycl( + return function_tables[libkey].row_major_gemm_s8s8f32_batch_group_usm_sycl( queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, group_count, group_size, dependencies); } @@ -7396,7 +7396,7 @@ sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose const std::int8_t **b, std::int64_t *ldb, float *beta, std::int32_t **c, std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies) { - return function_tables[libkey].row_major_iigemm_batch_group_usm_sycl( + return function_tables[libkey].row_major_gemm_s8s8s32_batch_group_usm_sycl( queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, group_count, group_size, dependencies); } @@ -7465,7 +7465,7 @@ sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose const sycl::half *b, std::int64_t ldb, std::int64_t stride_b, float beta, float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies) { - return function_tables[libkey].row_major_hsgemm_batch_strided_usm_sycl( + return 
function_tables[libkey].row_major_gemm_f16f16f32_batch_strided_usm_sycl( queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size, dependencies); } @@ -7476,7 +7476,7 @@ sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies) { - return function_tables[libkey].row_major_isgemm_batch_strided_usm_sycl( + return function_tables[libkey].row_major_gemm_s8s8f32_batch_strided_usm_sycl( queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size, dependencies); } @@ -7487,7 +7487,7 @@ sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, std::int32_t *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies) { - return function_tables[libkey].row_major_iigemm_batch_strided_usm_sycl( + return function_tables[libkey].row_major_gemm_s8s8s32_batch_strided_usm_sycl( queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size, dependencies); } diff --git a/src/blas/function_table.hpp b/src/blas/function_table.hpp index 57490523e..2e28661e7 100644 --- a/src/blas/function_table.hpp +++ b/src/blas/function_table.hpp @@ -869,20 +869,20 @@ typedef struct { std::int64_t stride_b, sycl::half beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); - void (*column_major_hsgemm_batch_strided_sycl)( + void (*column_major_gemm_f16f16f32_batch_strided_sycl)( sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, 
sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); - void (*column_major_isgemm_batch_strided_sycl)( + void (*column_major_gemm_s8s8f32_batch_strided_sycl)( sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); - void (*column_major_iigemm_batch_strided_sycl)( + void (*column_major_gemm_s8s8s32_batch_strided_sycl)( sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, @@ -2200,19 +2200,19 @@ typedef struct { std::int64_t *lda, const sycl::half **b, std::int64_t *ldb, sycl::half *beta, sycl::half **c, std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies); - sycl::event (*column_major_hsgemm_batch_group_usm_sycl)( + sycl::event (*column_major_gemm_f16f16f32_batch_group_usm_sycl)( sycl::queue &queue, oneapi::mkl::transpose *transa, oneapi::mkl::transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, float *alpha, const sycl::half **a, std::int64_t *lda, const sycl::half **b, std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies); - sycl::event (*column_major_isgemm_batch_group_usm_sycl)( + sycl::event (*column_major_gemm_s8s8f32_batch_group_usm_sycl)( sycl::queue &queue, oneapi::mkl::transpose *transa, oneapi::mkl::transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, float *alpha, const std::int8_t **a, std::int64_t *lda, const std::int8_t **b, std::int64_t 
*ldb, float *beta, float **c, std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies); - sycl::event (*column_major_iigemm_batch_group_usm_sycl)( + sycl::event (*column_major_gemm_s8s8s32_batch_group_usm_sycl)( sycl::queue &queue, oneapi::mkl::transpose *transa, oneapi::mkl::transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, float *alpha, const std::int8_t **a, std::int64_t *lda, const std::int8_t **b, std::int64_t *ldb, float *beta, std::int32_t **c, @@ -2251,19 +2251,19 @@ typedef struct { std::int64_t stride_b, sycl::half beta, sycl::half *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies); - sycl::event (*column_major_hsgemm_batch_strided_usm_sycl)( + sycl::event (*column_major_gemm_f16f16f32_batch_strided_usm_sycl)( sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const sycl::half *a, std::int64_t lda, std::int64_t stride_a, const sycl::half *b, std::int64_t ldb, std::int64_t stride_b, float beta, float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies); - sycl::event (*column_major_isgemm_batch_strided_usm_sycl)( + sycl::event (*column_major_gemm_s8s8f32_batch_strided_usm_sycl)( sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies); - sycl::event (*column_major_iigemm_batch_strided_usm_sycl)( + sycl::event (*column_major_gemm_s8s8s32_batch_strided_usm_sycl)( sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, 
std::int64_t k, float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, const std::int8_t *b, std::int64_t ldb, @@ -3325,13 +3325,13 @@ typedef struct { std::int64_t stride_b, sycl::half beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); - void (*row_major_hsgemm_batch_strided_sycl)( + void (*row_major_gemm_f16f16f32_batch_strided_sycl)( sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); - void (*row_major_isgemm_batch_strided_sycl)(sycl::queue &queue, oneapi::mkl::transpose transa, + void (*row_major_gemm_s8s8f32_batch_strided_sycl)(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, std::int64_t lda, @@ -3340,7 +3340,7 @@ typedef struct { std::int64_t stride_b, float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); - void (*row_major_iigemm_batch_strided_sycl)(sycl::queue &queue, oneapi::mkl::transpose transa, + void (*row_major_gemm_s8s8s32_batch_strided_sycl)(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, std::int64_t lda, @@ -4661,19 +4661,19 @@ typedef struct { std::int64_t *lda, const sycl::half **b, std::int64_t *ldb, sycl::half *beta, sycl::half **c, std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies); - sycl::event (*row_major_hsgemm_batch_group_usm_sycl)( + sycl::event (*row_major_gemm_f16f16f32_batch_group_usm_sycl)( sycl::queue &queue, oneapi::mkl::transpose *transa, 
oneapi::mkl::transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, float *alpha, const sycl::half **a, std::int64_t *lda, const sycl::half **b, std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies); - sycl::event (*row_major_isgemm_batch_group_usm_sycl)( + sycl::event (*row_major_gemm_s8s8f32_batch_group_usm_sycl)( sycl::queue &queue, oneapi::mkl::transpose *transa, oneapi::mkl::transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, float *alpha, const std::int8_t **a, std::int64_t *lda, const std::int8_t **b, std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies); - sycl::event (*row_major_iigemm_batch_group_usm_sycl)( + sycl::event (*row_major_gemm_s8s8s32_batch_group_usm_sycl)( sycl::queue &queue, oneapi::mkl::transpose *transa, oneapi::mkl::transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, float *alpha, const std::int8_t **a, std::int64_t *lda, const std::int8_t **b, std::int64_t *ldb, float *beta, std::int32_t **c, @@ -4712,19 +4712,19 @@ typedef struct { std::int64_t stride_b, sycl::half beta, sycl::half *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies); - sycl::event (*row_major_hsgemm_batch_strided_usm_sycl)( + sycl::event (*row_major_gemm_f16f16f32_batch_strided_usm_sycl)( sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const sycl::half *a, std::int64_t lda, std::int64_t stride_a, const sycl::half *b, std::int64_t ldb, std::int64_t stride_b, float beta, float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies); - sycl::event (*row_major_isgemm_batch_strided_usm_sycl)( + sycl::event 
(*row_major_gemm_s8s8f32_batch_strided_usm_sycl)( sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies); - sycl::event (*row_major_iigemm_batch_strided_usm_sycl)( + sycl::event (*row_major_gemm_s8s8s32_batch_strided_usm_sycl)( sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, const std::int8_t *b, std::int64_t ldb, From 6b45b0c74dc2372a08eb0485f1c17ae356d946c5 Mon Sep 17 00:00:00 2001 From: Aidan Date: Mon, 8 Apr 2024 12:02:24 +0100 Subject: [PATCH 11/30] Update tests --- .../blas/batch/gemm_batch_stride.cpp | 21 +++++++++++-------- .../blas/batch/gemm_batch_stride_usm.cpp | 12 +++++------ .../unit_tests/blas/batch/gemm_batch_usm.cpp | 9 ++++---- 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/tests/unit_tests/blas/batch/gemm_batch_stride.cpp b/tests/unit_tests/blas/batch/gemm_batch_stride.cpp index 76b477181..cac90a0b3 100644 --- a/tests/unit_tests/blas/batch/gemm_batch_stride.cpp +++ b/tests/unit_tests/blas/batch/gemm_batch_stride.cpp @@ -66,11 +66,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { alpha = rand_scalar(); beta = rand_scalar(); - if ((std::is_same::value) || (std::is_same::value)) { - transa = (oneapi::mkl::transpose)(std::rand() % 2); - transb = (oneapi::mkl::transpose)(std::rand() % 2); - } - else { + if ((std::is_same>::value) || (std::is_same>::value)) { tmp = std::rand() % 3; if (tmp == 2) transa = oneapi::mkl::transpose::conjtrans; @@ -81,6 +77,9 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) 
{ transb = oneapi::mkl::transpose::conjtrans; else transb = (oneapi::mkl::transpose)tmp; + } else { + transa = (oneapi::mkl::transpose)(std::rand() % 2); + transb = (oneapi::mkl::transpose)(std::rand() % 2); } int64_t stride_a, stride_b, stride_c; @@ -112,12 +111,15 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { rand_matrix(C.data() + stride_c * i, layout, oneapi::mkl::transpose::nontrans, m, n, ldc); } - for (size_t i = 0; i < A.size(); ++i) + for (size_t i = 0; i < A.size(); ++i) { A_ref[i] = A[i]; - for (size_t i = 0; i < B.size(); ++i) + } + for (size_t i = 0; i < B.size(); ++i) { B_ref[i] = B[i]; - for (size_t i = 0; i < C.size(); ++i) + } + for (size_t i = 0; i < C.size(); ++i) { C_ref[i] = C[i]; + } // Call reference GEMM_BATCH_STRIDE. using fp_ref = typename ref_type_info::type; @@ -213,8 +215,9 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { // Compare the results of reference implementation and DPC++ implementation. constexpr int tol_scalar = std::is_same_v ? 
10 : 40; - for (size_t i = 0; i < C_ref.size(); ++i) + for (size_t i = 0; i < C_ref.size(); ++i) { C_cast_ref[i] = C_ref[i]; + } auto C_accessor = C_buffer.template get_host_access(read_only); bool good = check_equal_matrix(C_accessor, C_cast_ref, oneapi::mkl::layout::col_major, stride_c * batch_size, 1, stride_c * batch_size, tol_scalar * k, diff --git a/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp b/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp index 16959a9cd..9436146d2 100644 --- a/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp +++ b/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp @@ -85,11 +85,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { ldc = std::max(m, n); alpha = rand_scalar(); beta = rand_scalar(); - if ((std::is_same::value) || (std::is_same::value)) { - transa = (oneapi::mkl::transpose)(std::rand() % 2); - transb = (oneapi::mkl::transpose)(std::rand() % 2); - } - else { + if ((std::is_same>::value) || (std::is_same>::value)) { tmp = std::rand() % 3; if (tmp == 2) transa = oneapi::mkl::transpose::conjtrans; @@ -100,6 +96,9 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { transb = oneapi::mkl::transpose::conjtrans; else transb = (oneapi::mkl::transpose)tmp; + } else { + transa = (oneapi::mkl::transpose)(std::rand() % 2); + transb = (oneapi::mkl::transpose)(std::rand() % 2); } int64_t stride_a, stride_b, stride_c; @@ -247,8 +246,9 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { // Compare the results of reference implementation and DPC++ implementation. constexpr int tol_scalar = std::is_same_v ? 
10 : 40; - for (size_t i = 0; i < C_ref.size(); ++i) + for (size_t i = 0; i < C_ref.size(); ++i) { C_cast_ref[i] = C_ref[i]; + } bool good = check_equal_matrix(C, C_cast_ref, oneapi::mkl::layout::col_major, stride_c * batch_size, 1, stride_c * batch_size, tol_scalar * k, std::cout); diff --git a/tests/unit_tests/blas/batch/gemm_batch_usm.cpp b/tests/unit_tests/blas/batch/gemm_batch_usm.cpp index af38a10b0..b6e2ad5e0 100644 --- a/tests/unit_tests/blas/batch/gemm_batch_usm.cpp +++ b/tests/unit_tests/blas/batch/gemm_batch_usm.cpp @@ -106,11 +106,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { ldc[i] = std::max(m[i], n[i]); alpha[i] = rand_scalar(); beta[i] = rand_scalar(); - if ((std::is_same::value) || (std::is_same::value)) { - transa[i] = (oneapi::mkl::transpose)(std::rand() % 2); - transb[i] = (oneapi::mkl::transpose)(std::rand() % 2); - } - else { + if ((std::is_same>::value) || (std::is_same>::value)) { tmp = std::rand() % 3; if (tmp == 2) transa[i] = oneapi::mkl::transpose::conjtrans; @@ -121,6 +117,9 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { transb[i] = oneapi::mkl::transpose::conjtrans; else transb[i] = (oneapi::mkl::transpose)tmp; + } else { + transa[i] = (oneapi::mkl::transpose)(std::rand() % 2); + transb[i] = (oneapi::mkl::transpose)(std::rand() % 2); } total_batch_count += group_size[i]; } From 0a198be68a5efbadaa9407ad5b4c9d7f75e1b22d Mon Sep 17 00:00:00 2001 From: Aidan Date: Mon, 8 Apr 2024 16:37:52 +0100 Subject: [PATCH 12/30] Add more descriptive throw --- src/blas/backends/cublas/cublas_batch.cpp | 15 ++++++++-- src/blas/backends/cublas/cublas_helper.hpp | 1 + src/blas/backends/rocblas/rocblas_batch.cpp | 30 ++++++++++++++++---- src/blas/backends/rocblas/rocblas_helper.hpp | 1 + 4 files changed, 38 insertions(+), 9 deletions(-) diff --git a/src/blas/backends/cublas/cublas_batch.cpp b/src/blas/backends/cublas/cublas_batch.cpp index 8106db5e9..f92d7bd0e 100644 --- 
a/src/blas/backends/cublas/cublas_batch.cpp +++ b/src/blas/backends/cublas/cublas_batch.cpp @@ -206,7 +206,10 @@ GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, std::com int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, \ TYPE_S beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, \ int64_t batch_size) { \ - throw unimplemented("blas", "gemm_batch", "for unimplmented dtypes"); \ + throw unimplemented("blas", "gemm_batch", \ + std::string("for dtype unimplemented dtype combination <") + \ + dtype_string() + "," + dtype_string() + "," + \ + dtype_string() + "," + dtype_string() + ">"); \ } GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, std::int32_t, float) @@ -645,7 +648,10 @@ GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, std: int64_t stride_a, const TYPE_B *b, int64_t ldb, int64_t stride_b, \ TYPE_S beta, TYPE_C *c, int64_t ldc, int64_t stride_c, \ int64_t batch_size, const std::vector &dependencies) { \ - throw unimplemented("blas", "gemm_batch", "for unimplmented dtypes"); \ + throw unimplemented("blas", "gemm_batch", \ + std::string("for dtype unimplemented dtype combination <") + \ + dtype_string() + "," + dtype_string() + "," + \ + dtype_string() + "," + dtype_string() + ">"); \ } GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float) @@ -724,7 +730,10 @@ GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex const TYPE_B **b, int64_t *ldb, TYPE_S *beta, TYPE_C **c, int64_t *ldc, \ int64_t group_count, int64_t *group_size, \ const std::vector &dependencies) { \ - throw unimplemented("blas", "gemm_batch", "for unimplmented dtypes"); \ + throw unimplemented("blas", "gemm_batch", \ + std::string("for dtype unimplemented dtype combination <") + \ + dtype_string() + "," + dtype_string() + "," + \ + dtype_string() + "," + dtype_string() + ">"); \ } GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float) diff --git a/src/blas/backends/cublas/cublas_helper.hpp 
b/src/blas/backends/cublas/cublas_helper.hpp index 1380fed4f..380534309 100644 --- a/src/blas/backends/cublas/cublas_helper.hpp +++ b/src/blas/backends/cublas/cublas_helper.hpp @@ -35,6 +35,7 @@ #include "oneapi/mkl/types.hpp" #include "runtime_support_helper.hpp" +#include "error_helper.hpp" namespace oneapi { namespace mkl { diff --git a/src/blas/backends/rocblas/rocblas_batch.cpp b/src/blas/backends/rocblas/rocblas_batch.cpp index 09cd01ce3..5fa103055 100644 --- a/src/blas/backends/rocblas/rocblas_batch.cpp +++ b/src/blas/backends/rocblas/rocblas_batch.cpp @@ -292,7 +292,10 @@ GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, sycl::half, float, float) int64_t stridea, sycl::buffer &b, int64_t ldb, int64_t strideb, \ TYPE_S beta, sycl::buffer &c, int64_t ldc, int64_t stridec, \ int64_t batch_size) { \ - throw unimplemented("blas", "gemm_batch", "for data type combination"); \ + throw unimplemented("blas", "gemm_batch", \ + std::string("for dtype unimplemented dtype combination <") + \ + dtype_string() + "," + dtype_string() + "," + \ + dtype_string() + "," + dtype_string() + ">"); \ } GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, float, float) @@ -912,7 +915,10 @@ GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float) int64_t stridea, const TYPE_B *b, int64_t ldb, int64_t strideb, \ TYPE_S beta, TYPE_C *c, int64_t ldc, int64_t stridec, \ int64_t batch_size, const std::vector &dependencies) { \ - throw unimplemented("blas", "gemm_batch", "for data type combination"); \ + throw unimplemented("blas", "gemm_batch", \ + std::string("for dtype unimplemented dtype combination <") + \ + dtype_string() + "," + dtype_string() + "," + \ + dtype_string() + "," + dtype_string() + ">"); \ } GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) @@ -991,7 +997,10 @@ GEMM_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float) const TYPE_B **b, int64_t *ldb, TYPE_S *beta, TYPE_C **c, int64_t *ldc, \ int64_t group_count, int64_t *group_size, \ const 
std::vector &dependencies) { \ - throw unimplemented("blas", "gemm_batch", "for data type combination"); \ + throw unimplemented("blas", "gemm_batch", \ + std::string("for dtype unimplemented dtype combination <") + \ + dtype_string() + "," + dtype_string() + "," + \ + dtype_string() + "," + dtype_string() + ">"); \ } GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) @@ -1567,7 +1576,10 @@ GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, sycl::half, float, float) int64_t stridea, sycl::buffer &b, int64_t ldb, int64_t strideb, \ TYPE_S beta, sycl::buffer &c, int64_t ldc, int64_t stridec, \ int64_t batch_size) { \ - throw unimplemented("blas", "gemm_batch", "for data type combination"); \ + throw unimplemented("blas", "gemm_batch", \ + std::string("for dtype unimplemented dtype combination <") + \ + dtype_string() + "," + dtype_string() + "," + \ + dtype_string() + "," + dtype_string() + ">"); \ } GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, float, float) @@ -2084,7 +2096,10 @@ GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float) int64_t stridea, const TYPE_B *b, int64_t ldb, int64_t strideb, \ TYPE_S beta, TYPE_C *c, int64_t ldc, int64_t stridec, \ int64_t batch_size, const std::vector &dependencies) { \ - throw unimplemented("blas", "gemm_batch", "for data type combination"); \ + throw unimplemented("blas", "gemm_batch", \ + std::string("for dtype unimplemented dtype combination <") + \ + dtype_string() + "," + dtype_string() + "," + \ + dtype_string() + "," + dtype_string() + ">"); \ } GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) @@ -2133,7 +2148,10 @@ GEMM_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float) const TYPE_B **b, int64_t *ldb, TYPE_S *beta, TYPE_C **c, int64_t *ldc, \ int64_t group_count, int64_t *group_size, \ const std::vector &dependencies) { \ - throw unimplemented("blas", "gemm_batch", "for data type combination"); \ + throw unimplemented("blas", "gemm_batch", \ + std::string("for 
dtype unimplemented dtype combination <") + \ + dtype_string() + "," + dtype_string() + "," + \ + dtype_string() + "," + dtype_string() + ">"); \ } GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) diff --git a/src/blas/backends/rocblas/rocblas_helper.hpp b/src/blas/backends/rocblas/rocblas_helper.hpp index ad3544200..cbc883973 100644 --- a/src/blas/backends/rocblas/rocblas_helper.hpp +++ b/src/blas/backends/rocblas/rocblas_helper.hpp @@ -31,6 +31,7 @@ #include #include "oneapi/mkl/types.hpp" #include +#include "error_helper.hpp" namespace oneapi { namespace mkl { From f9d2b85a0a94505b589fed23f028e76eff118fbb Mon Sep 17 00:00:00 2001 From: Aidan Date: Mon, 8 Apr 2024 16:38:05 +0100 Subject: [PATCH 13/30] Clang-foramt --- src/blas/function_table.hpp | 32 ++++++++----------- .../blas/batch/gemm_batch_stride.cpp | 6 ++-- .../blas/batch/gemm_batch_stride_usm.cpp | 6 ++-- .../unit_tests/blas/batch/gemm_batch_usm.cpp | 6 ++-- 4 files changed, 26 insertions(+), 24 deletions(-) diff --git a/src/blas/function_table.hpp b/src/blas/function_table.hpp index 2e28661e7..a242fd0c0 100644 --- a/src/blas/function_table.hpp +++ b/src/blas/function_table.hpp @@ -3331,24 +3331,20 @@ typedef struct { std::int64_t lda, std::int64_t stride_a, sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); - void (*row_major_gemm_s8s8f32_batch_strided_sycl)(sycl::queue &queue, oneapi::mkl::transpose transa, - oneapi::mkl::transpose transb, std::int64_t m, - std::int64_t n, std::int64_t k, float alpha, - sycl::buffer &a, std::int64_t lda, - std::int64_t stride_a, - sycl::buffer &b, std::int64_t ldb, - std::int64_t stride_b, float beta, - sycl::buffer &c, std::int64_t ldc, - std::int64_t stride_c, std::int64_t batch_size); - void (*row_major_gemm_s8s8s32_batch_strided_sycl)(sycl::queue &queue, oneapi::mkl::transpose transa, - oneapi::mkl::transpose transb, std::int64_t m, - std::int64_t 
n, std::int64_t k, float alpha, - sycl::buffer &a, std::int64_t lda, - std::int64_t stride_a, - sycl::buffer &b, std::int64_t ldb, - std::int64_t stride_b, float beta, - sycl::buffer &c, std::int64_t ldc, - std::int64_t stride_c, std::int64_t batch_size); + void (*row_major_gemm_s8s8f32_batch_strided_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size); + void (*row_major_gemm_s8s8s32_batch_strided_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size); void (*row_major_strsm_batch_strided_sycl)( sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, oneapi::mkl::diag unit_diag, std::int64_t m, std::int64_t n, diff --git a/tests/unit_tests/blas/batch/gemm_batch_stride.cpp b/tests/unit_tests/blas/batch/gemm_batch_stride.cpp index cac90a0b3..5982c88bf 100644 --- a/tests/unit_tests/blas/batch/gemm_batch_stride.cpp +++ b/tests/unit_tests/blas/batch/gemm_batch_stride.cpp @@ -66,7 +66,8 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { alpha = rand_scalar(); beta = rand_scalar(); - if ((std::is_same>::value) || (std::is_same>::value)) { + if ((std::is_same>::value) || + (std::is_same>::value)) { tmp = std::rand() % 3; if (tmp == 2) transa = oneapi::mkl::transpose::conjtrans; @@ -77,7 +78,8 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { 
transb = oneapi::mkl::transpose::conjtrans; else transb = (oneapi::mkl::transpose)tmp; - } else { + } + else { transa = (oneapi::mkl::transpose)(std::rand() % 2); transb = (oneapi::mkl::transpose)(std::rand() % 2); } diff --git a/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp b/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp index 9436146d2..7329fa09f 100644 --- a/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp +++ b/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp @@ -85,7 +85,8 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { ldc = std::max(m, n); alpha = rand_scalar(); beta = rand_scalar(); - if ((std::is_same>::value) || (std::is_same>::value)) { + if ((std::is_same>::value) || + (std::is_same>::value)) { tmp = std::rand() % 3; if (tmp == 2) transa = oneapi::mkl::transpose::conjtrans; @@ -96,7 +97,8 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { transb = oneapi::mkl::transpose::conjtrans; else transb = (oneapi::mkl::transpose)tmp; - } else { + } + else { transa = (oneapi::mkl::transpose)(std::rand() % 2); transb = (oneapi::mkl::transpose)(std::rand() % 2); } diff --git a/tests/unit_tests/blas/batch/gemm_batch_usm.cpp b/tests/unit_tests/blas/batch/gemm_batch_usm.cpp index b6e2ad5e0..fb694a37c 100644 --- a/tests/unit_tests/blas/batch/gemm_batch_usm.cpp +++ b/tests/unit_tests/blas/batch/gemm_batch_usm.cpp @@ -106,7 +106,8 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { ldc[i] = std::max(m[i], n[i]); alpha[i] = rand_scalar(); beta[i] = rand_scalar(); - if ((std::is_same>::value) || (std::is_same>::value)) { + if ((std::is_same>::value) || + (std::is_same>::value)) { tmp = std::rand() % 3; if (tmp == 2) transa[i] = oneapi::mkl::transpose::conjtrans; @@ -117,7 +118,8 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { transb[i] = oneapi::mkl::transpose::conjtrans; else transb[i] = (oneapi::mkl::transpose)tmp; - } else { + } + else { 
transa[i] = (oneapi::mkl::transpose)(std::rand() % 2); transb[i] = (oneapi::mkl::transpose)(std::rand() % 2); } From a4d7916ec394144a980e48ea9d68fd03ef2f4098 Mon Sep 17 00:00:00 2001 From: Aidan Date: Tue, 9 Apr 2024 14:02:12 +0100 Subject: [PATCH 14/30] Fix allocator msitake --- tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp b/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp index 7329fa09f..b68916e5a 100644 --- a/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp +++ b/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp @@ -125,8 +125,8 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { auto us = usm_allocator(cxt, *dev); vector A(ua); vector B(ub); - vector C(uc), C_cast_ref(us); - vector A_ref(ua), B_ref(ub), C_ref(us); + vector C(uc), C_cast_ref(uc); + vector A_ref(us), B_ref(us), C_ref(us); A.resize(stride_a * batch_size); B.resize(stride_b * batch_size); From f0e3004c1c032f983ee29347a9b5a9c0d615583b Mon Sep 17 00:00:00 2001 From: Aidan Date: Thu, 11 Apr 2024 14:02:25 +0100 Subject: [PATCH 15/30] Add check matrix instantiation --- .../blas/batch/gemm_batch_stride.cpp | 58 +++++++++++++++++-- .../blas/batch/gemm_batch_stride_usm.cpp | 50 ++++++++++++++-- .../unit_tests/blas/batch/gemm_batch_usm.cpp | 38 +++++++++++- 3 files changed, 135 insertions(+), 11 deletions(-) diff --git a/tests/unit_tests/blas/batch/gemm_batch_stride.cpp b/tests/unit_tests/blas/batch/gemm_batch_stride.cpp index 5982c88bf..391c68a05 100644 --- a/tests/unit_tests/blas/batch/gemm_batch_stride.cpp +++ b/tests/unit_tests/blas/batch/gemm_batch_stride.cpp @@ -45,6 +45,54 @@ using std::vector; extern std::vector devices; +template +typename std::enable_if::value, bool>::type check_equal_int(fp x, fp x_ref, + int error_mag) { + return (std::abs(x - x_ref) <= 1); +} + +template +struct acc_type { + typedef host_accessor type; +}; + 
+template +using acc_type_t = typename acc_type::type; + +template +struct vec_type { + typedef vector> type; +}; + +template +using vec_type_t = typename vec_type::type; + +// Specialized check for Tc=int32_t and Ts=float as small differences in the reference become large after rounding +template <> +bool check_equal_matrix, vec_type_t>(acc_type_t &M, + vec_type_t &M_ref, + oneapi::mkl::layout layout, int m, + int n, int ld, int error_mag, + std::ostream &out) { + bool good = true; + int idx, count = 0; + for (int j = 0; j < n; j++) { + for (int i = 0; i < m; i++) { + idx = (layout == oneapi::mkl::layout::col_major) ? i + j * ld : j + i * ld; + if (!check_equal_int(M[idx], M_ref[idx], error_mag)) { + out << "Difference in entry (" << i << ',' << j << "): DPC++ " << M[idx] + << " vs. Reference " << M_ref[idx] << std::endl; + good = false; + count++; + if (count > MAX_NUM_PRINT) + return good; + } + } + } + + return good; +} + namespace { template @@ -215,15 +263,17 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { } // Compare the results of reference implementation and DPC++ implementation. - constexpr int tol_scalar = std::is_same_v ? 10 : 40; + int tol_scalar = std::is_same_v ? 
10 : 60; + if (main_queue.get_device().is_cpu()) + tol_scalar = 100; for (size_t i = 0; i < C_ref.size(); ++i) { C_cast_ref[i] = C_ref[i]; } auto C_accessor = C_buffer.template get_host_access(read_only); - bool good = check_equal_matrix(C_accessor, C_cast_ref, oneapi::mkl::layout::col_major, - stride_c * batch_size, 1, stride_c * batch_size, tol_scalar * k, - std::cout); + bool good = check_equal_matrix, vec_type_t>( + C_accessor, C_cast_ref, oneapi::mkl::layout::col_major, stride_c * batch_size, 1, + stride_c * batch_size, tol_scalar * k, std::cout); return (int)good; } diff --git a/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp b/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp index b68916e5a..7b946aefa 100644 --- a/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp +++ b/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp @@ -45,6 +45,46 @@ using std::vector; extern std::vector devices; +template +typename std::enable_if::value, bool>::type check_equal_int(fp x, fp x_ref, + int error_mag) { + return (std::abs(x - x_ref) <= 1); +} + +template +struct vec_type { + typedef vector> type; +}; + +template +using vec_type_t = typename vec_type::type; + +// Specialized check for Tc=int32_t and Ts=float as small differences in the reference become large after rounding +template <> +bool check_equal_matrix, vec_type_t>(vec_type_t &M, + vec_type_t &M_ref, + oneapi::mkl::layout layout, int m, + int n, int ld, int error_mag, + std::ostream &out) { + bool good = true; + int idx, count = 0; + for (int j = 0; j < n; j++) { + for (int i = 0; i < m; i++) { + idx = (layout == oneapi::mkl::layout::col_major) ? i + j * ld : j + i * ld; + if (!check_equal_int(M[idx], M_ref[idx], error_mag)) { + out << "Difference in entry (" << i << ',' << j << "): DPC++ " << M[idx] + << " vs. 
Reference " << M_ref[idx] << std::endl; + good = false; + count++; + if (count > MAX_NUM_PRINT) + return good; + } + } + } + + return good; +} + namespace { template @@ -246,14 +286,16 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { } // Compare the results of reference implementation and DPC++ implementation. - constexpr int tol_scalar = std::is_same_v ? 10 : 40; + int tol_scalar = std::is_same_v ? 10 : 60; + if (main_queue.get_device().is_cpu()) + tol_scalar = 100; for (size_t i = 0; i < C_ref.size(); ++i) { C_cast_ref[i] = C_ref[i]; } - bool good = - check_equal_matrix(C, C_cast_ref, oneapi::mkl::layout::col_major, stride_c * batch_size, 1, - stride_c * batch_size, tol_scalar * k, std::cout); + bool good = check_equal_matrix, vec_type_t>( + C, C_cast_ref, oneapi::mkl::layout::col_major, stride_c * batch_size, 1, + stride_c * batch_size, tol_scalar * k, std::cout); oneapi::mkl::free_shared(a_array, cxt); oneapi::mkl::free_shared(b_array, cxt); diff --git a/tests/unit_tests/blas/batch/gemm_batch_usm.cpp b/tests/unit_tests/blas/batch/gemm_batch_usm.cpp index fb694a37c..99e376a9f 100644 --- a/tests/unit_tests/blas/batch/gemm_batch_usm.cpp +++ b/tests/unit_tests/blas/batch/gemm_batch_usm.cpp @@ -45,6 +45,35 @@ using std::vector; extern std::vector devices; +template +typename std::enable_if::value, bool>::type check_equal_int(fp x, fp x_ref, + int error_mag) { + return (std::abs(x - x_ref) <= 1); +} + +// Specialized check for Tc=int32_t and Ts=float as small differences in the reference become large after rounding +template <> +bool check_equal_matrix(const int32_t *M, const int32_t *M_ref, oneapi::mkl::layout layout, + int m, int n, int ld, int error_mag, std::ostream &out) { + bool good = true; + int idx, count = 0; + for (int j = 0; j < n; j++) { + for (int i = 0; i < m; i++) { + idx = (layout == oneapi::mkl::layout::col_major) ? 
i + j * ld : j + i * ld; + if (!check_equal_int(M[idx], M_ref[idx], error_mag)) { + out << "Difference in entry (" << i << ',' << j << "): DPC++ " << M[idx] + << " vs. Reference " << M_ref[idx] << std::endl; + good = false; + count++; + if (count > MAX_NUM_PRINT) + return good; + } + } + } + + return good; +} + namespace { template @@ -321,14 +350,17 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { bool good = true; // Compare the results of reference implementation and DPC++ implementation. - constexpr int tol_scalar = std::is_same_v ? 10 : 40; + int tol_scalar = std::is_same_v ? 10 : 60; + if (main_queue.get_device().is_cpu()) + tol_scalar = 100; + idx = 0; for (i = 0; i < group_count; i++) { for (j = 0; j < group_size[i]; j++) { copy_matrix(c_ref_array[idx], layout, oneapi::mkl::transpose::nontrans, m[i], n[i], ldc[i], c_cast_ref_array[idx]); - good = good && check_equal_matrix(c_array[idx], c_cast_ref_array[idx], layout, m[i], - n[i], ldc[i], tol_scalar * k[i], std::cout); + good = good && check_equal_matrix(c_array[idx], c_cast_ref_array[idx], layout, m[i], + n[i], ldc[i], tol_scalar * k[i], std::cout); idx++; } } From e83fa1eb12f074657fb3e9a4a056667ecaa5f905 Mon Sep 17 00:00:00 2001 From: Aidan Date: Mon, 13 May 2024 16:32:20 +0100 Subject: [PATCH 16/30] Add src/include to rocBlas include path --- src/blas/backends/rocblas/CMakeLists.txt | 1 + src/include/error_helper.hpp | 46 ++++++++++++++++++++++++ 2 files changed, 47 insertions(+) create mode 100644 src/include/error_helper.hpp diff --git a/src/blas/backends/rocblas/CMakeLists.txt b/src/blas/backends/rocblas/CMakeLists.txt index 3a71eda1c..76dc126ad 100644 --- a/src/blas/backends/rocblas/CMakeLists.txt +++ b/src/blas/backends/rocblas/CMakeLists.txt @@ -39,6 +39,7 @@ add_dependencies(onemkl_backend_libs_blas ${LIB_NAME}) target_include_directories(${LIB_OBJ} PRIVATE ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/src/include ${PROJECT_SOURCE_DIR}/src 
${PROJECT_BINARY_DIR}/bin ${ONEMKL_GENERATED_INCLUDE_PATH} diff --git a/src/include/error_helper.hpp b/src/include/error_helper.hpp new file mode 100644 index 000000000..136fc5616 --- /dev/null +++ b/src/include/error_helper.hpp @@ -0,0 +1,46 @@ +/******************************************************************************* +* Copyright 2020-2021 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _ONEMKL_ERROR_HELPER_HPP_ +#define _ONEMKL_ERROR_HELPER_HPP_ + +#if __has_include() +#include +#else +#include +#endif + +template +inline const std::string dtype_string(); +template <> +inline const std::string dtype_string() {return "float";} +template <> +inline const std::string dtype_string() {return "double";} +template <> +inline const std::string dtype_string() {return "half";} +template <> +inline const std::string dtype_string>() {return "complex";} +template <> +inline const std::string dtype_string>() {return "complex";} +template <> +inline const std::string dtype_string() {return "int32";} +template <> +inline const std::string dtype_string() {return "int8";} + +#endif //_ONEMKL_ERROR_HELPER_HPP_ From 2f5fa024003c0a50f40c6e7b9572c3d9da4956e3 Mon Sep 17 00:00:00 2001 From: Aidan Date: Mon, 13 May 2024 16:33:18 +0100 Subject: [PATCH 17/30] Increase tolerancing for int8_t inputs --- 
tests/unit_tests/blas/batch/gemm_batch_usm.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/unit_tests/blas/batch/gemm_batch_usm.cpp b/tests/unit_tests/blas/batch/gemm_batch_usm.cpp index 99e376a9f..0fb8e3025 100644 --- a/tests/unit_tests/blas/batch/gemm_batch_usm.cpp +++ b/tests/unit_tests/blas/batch/gemm_batch_usm.cpp @@ -350,9 +350,11 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { bool good = true; // Compare the results of reference implementation and DPC++ implementation. - int tol_scalar = std::is_same_v ? 10 : 60; - if (main_queue.get_device().is_cpu()) - tol_scalar = 100; + int tol_scalar = 10; + // Scale the tolerance for when we generate int8_t, as input range is [-128, 127] + // rather than [-1,1] + if (std::is_same_v && std::is_same_v) + tol_scalar *= 256; idx = 0; for (i = 0; i < group_count; i++) { From bcc3a2926e5bfe682d15f2d25f4836cc710f55ae Mon Sep 17 00:00:00 2001 From: Aidan Date: Mon, 13 May 2024 16:46:27 +0100 Subject: [PATCH 18/30] Change test names --- tests/unit_tests/blas/batch/gemm_batch_stride.cpp | 6 +++--- tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp | 6 +++--- tests/unit_tests/blas/batch/gemm_batch_usm.cpp | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/unit_tests/blas/batch/gemm_batch_stride.cpp b/tests/unit_tests/blas/batch/gemm_batch_stride.cpp index 391c68a05..6e23a4aa0 100644 --- a/tests/unit_tests/blas/batch/gemm_batch_stride.cpp +++ b/tests/unit_tests/blas/batch/gemm_batch_stride.cpp @@ -286,17 +286,17 @@ TEST_P(GemmBatchStrideTests, RealHalfPrecision) { std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } -TEST_P(GemmBatchStrideTests, RealHalfRealScalarPrecision) { +TEST_P(GemmBatchStrideTests, HalfHalfFloatPrecision) { EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } -TEST_P(GemmBatchStrideTests, RealIntRealScalarPrecision) { +TEST_P(GemmBatchStrideTests, Int8Int8SinglePrecision) { 
EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } -TEST_P(GemmBatchStrideTests, RealIntPrecision) { +TEST_P(GemmBatchStrideTests, Int8Int8Int32Precision) { EXPECT_TRUEORSKIP((test( std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } diff --git a/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp b/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp index 7b946aefa..84e465144 100644 --- a/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp +++ b/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp @@ -313,17 +313,17 @@ TEST_P(GemmBatchStrideUsmTests, RealHalfPrecision) { std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } -TEST_P(GemmBatchStrideUsmTests, RealHalfRealScalarPrecision) { +TEST_P(GemmBatchStrideUsmTests, HalfHalfFloatPrecision) { EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } -TEST_P(GemmBatchStrideUsmTests, RealIntRealScalarPrecision) { +TEST_P(GemmBatchStrideUsmTests, Int8Int8SinglePrecision) { EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } -TEST_P(GemmBatchStrideUsmTests, RealIntRealIntPrecision) { +TEST_P(GemmBatchStrideUsmTests, Int8Int8Int32Precision) { EXPECT_TRUEORSKIP((test( std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } diff --git a/tests/unit_tests/blas/batch/gemm_batch_usm.cpp b/tests/unit_tests/blas/batch/gemm_batch_usm.cpp index 0fb8e3025..cba7da024 100644 --- a/tests/unit_tests/blas/batch/gemm_batch_usm.cpp +++ b/tests/unit_tests/blas/batch/gemm_batch_usm.cpp @@ -400,17 +400,17 @@ TEST_P(GemmBatchUsmTests, RealHalfPrecision) { std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } -TEST_P(GemmBatchUsmTests, RealHalfRealScalarPrecision) { +TEST_P(GemmBatchUsmTests, HalfHalfFloatPrecision) { EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } -TEST_P(GemmBatchUsmTests, RealIntRealScalarPrecision) { +TEST_P(GemmBatchUsmTests, Int8Int8SinglePrecision) { 
EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } -TEST_P(GemmBatchUsmTests, RealIntRealIntPrecision) { +TEST_P(GemmBatchUsmTests, Int8Int8Int32Precision) { EXPECT_TRUEORSKIP((test( std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } From 5963b396e6cb24b4241eb622a25e199ea4dd5785 Mon Sep 17 00:00:00 2001 From: Aidan Date: Mon, 27 May 2024 14:25:46 +0100 Subject: [PATCH 19/30] Rename error_helper to dtype_string --- src/blas/backends/cublas/cublas_helper.hpp | 2 +- src/blas/backends/rocblas/rocblas_helper.hpp | 2 +- src/include/{error_helper.hpp => dtype_string.hpp} | 6 +----- 3 files changed, 3 insertions(+), 7 deletions(-) rename src/include/{error_helper.hpp => dtype_string.hpp} (94%) diff --git a/src/blas/backends/cublas/cublas_helper.hpp b/src/blas/backends/cublas/cublas_helper.hpp index 380534309..0fe7e7c5a 100644 --- a/src/blas/backends/cublas/cublas_helper.hpp +++ b/src/blas/backends/cublas/cublas_helper.hpp @@ -35,7 +35,7 @@ #include "oneapi/mkl/types.hpp" #include "runtime_support_helper.hpp" -#include "error_helper.hpp" +#include "dtype_string.hpp" namespace oneapi { namespace mkl { diff --git a/src/blas/backends/rocblas/rocblas_helper.hpp b/src/blas/backends/rocblas/rocblas_helper.hpp index cbc883973..ae6301a7a 100644 --- a/src/blas/backends/rocblas/rocblas_helper.hpp +++ b/src/blas/backends/rocblas/rocblas_helper.hpp @@ -31,7 +31,7 @@ #include #include "oneapi/mkl/types.hpp" #include -#include "error_helper.hpp" +#include "dtype_string.hpp" namespace oneapi { namespace mkl { diff --git a/src/include/error_helper.hpp b/src/include/dtype_string.hpp similarity index 94% rename from src/include/error_helper.hpp rename to src/include/dtype_string.hpp index 136fc5616..79787b9d0 100644 --- a/src/include/error_helper.hpp +++ b/src/include/dtype_string.hpp @@ -20,11 +20,7 @@ #ifndef _ONEMKL_ERROR_HELPER_HPP_ #define _ONEMKL_ERROR_HELPER_HPP_ -#if __has_include() -#include -#else -#include -#endif +#include template inline 
const std::string dtype_string(); From 6a3860f192d340d887e2006b2cb998608b8bcae8 Mon Sep 17 00:00:00 2001 From: Aidan Date: Fri, 31 May 2024 10:32:53 +0100 Subject: [PATCH 20/30] Disable int8, gemm_batch for cublas --- src/blas/backends/cublas/cublas_batch.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/blas/backends/cublas/cublas_batch.cpp b/src/blas/backends/cublas/cublas_batch.cpp index f92d7bd0e..009bb9541 100644 --- a/src/blas/backends/cublas/cublas_batch.cpp +++ b/src/blas/backends/cublas/cublas_batch.cpp @@ -190,7 +190,6 @@ inline void gemm_batch_impl(sycl::queue &queue, transpose transa, transpose tran GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, sycl::half, sycl::half, sycl::half) GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, sycl::half, float, float) -GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, float, float) GEMM_STRIDED_BATCH_LAUNCHER(float, float, float, float) GEMM_STRIDED_BATCH_LAUNCHER(double, double, double, double) GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, std::complex, @@ -212,6 +211,7 @@ GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, std::com dtype_string() + "," + dtype_string() + ">"); \ } +GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, float, float) GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, std::int32_t, float) #undef GEMM_STRIDED_BATCH_LAUNCHER @@ -632,7 +632,6 @@ inline sycl::event gemm_batch_strided_usm_impl(sycl::queue &queue, transpose tra GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, sycl::half) GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float) -GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) GEMM_STRIDED_BATCH_LAUNCHER_USM(float, float, float, float) GEMM_STRIDED_BATCH_LAUNCHER_USM(double, double, double, double) GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, @@ -654,6 +653,7 @@ GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, std: dtype_string() + "," + 
dtype_string() + ">"); \ } +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float) #undef GEMM_STRIDED_BATCH_LAUNCHER_USM @@ -714,7 +714,6 @@ inline sycl::event gemm_batch_usm_impl(sycl::queue &queue, transpose *transa, tr GEMM_BATCH_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, sycl::half) GEMM_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float) -GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) GEMM_BATCH_LAUNCHER_USM(float, float, float, float) GEMM_BATCH_LAUNCHER_USM(double, double, double, double) GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, @@ -736,6 +735,7 @@ GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex dtype_string() + "," + dtype_string() + ">"); \ } +GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float) #undef GEMM_BATCH_LAUNCHER_USM From a00af3ec1a63a3ce432e6662cf3123a84c61644e Mon Sep 17 00:00:00 2001 From: Aidan Date: Wed, 5 Jun 2024 10:04:07 +0100 Subject: [PATCH 21/30] Make check_equal_matrix static --- tests/unit_tests/blas/batch/gemm_batch_stride.cpp | 8 +++----- .../blas/batch/gemm_batch_stride_usm.cpp | 8 +++----- tests/unit_tests/blas/batch/gemm_batch_usm.cpp | 5 +++-- tests/unit_tests/blas/include/test_common.hpp | 14 +++++++------- 4 files changed, 16 insertions(+), 19 deletions(-) diff --git a/tests/unit_tests/blas/batch/gemm_batch_stride.cpp b/tests/unit_tests/blas/batch/gemm_batch_stride.cpp index 6e23a4aa0..7b48f74a1 100644 --- a/tests/unit_tests/blas/batch/gemm_batch_stride.cpp +++ b/tests/unit_tests/blas/batch/gemm_batch_stride.cpp @@ -69,11 +69,9 @@ using vec_type_t = typename vec_type::type; // Specialized check for Tc=int32_t and Ts=float as small differences in the reference become large after rounding template <> -bool check_equal_matrix, vec_type_t>(acc_type_t &M, - vec_type_t &M_ref, - 
oneapi::mkl::layout layout, int m, - int n, int ld, int error_mag, - std::ostream &out) { +static bool check_equal_matrix, vec_type_t>( + acc_type_t &M, vec_type_t &M_ref, oneapi::mkl::layout layout, int m, int n, + int ld, int error_mag, std::ostream &out) { bool good = true; int idx, count = 0; for (int j = 0; j < n; j++) { diff --git a/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp b/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp index 84e465144..5f3410683 100644 --- a/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp +++ b/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp @@ -61,11 +61,9 @@ using vec_type_t = typename vec_type::type; // Specialized check for Tc=int32_t and Ts=float as small differences in the reference become large after rounding template <> -bool check_equal_matrix, vec_type_t>(vec_type_t &M, - vec_type_t &M_ref, - oneapi::mkl::layout layout, int m, - int n, int ld, int error_mag, - std::ostream &out) { +static bool check_equal_matrix, vec_type_t>( + vec_type_t &M, vec_type_t &M_ref, oneapi::mkl::layout layout, int m, int n, + int ld, int error_mag, std::ostream &out) { bool good = true; int idx, count = 0; for (int j = 0; j < n; j++) { diff --git a/tests/unit_tests/blas/batch/gemm_batch_usm.cpp b/tests/unit_tests/blas/batch/gemm_batch_usm.cpp index cba7da024..cbe062b42 100644 --- a/tests/unit_tests/blas/batch/gemm_batch_usm.cpp +++ b/tests/unit_tests/blas/batch/gemm_batch_usm.cpp @@ -53,8 +53,9 @@ typename std::enable_if::value, bool>::type check_equal_int // Specialized check for Tc=int32_t and Ts=float as small differences in the reference become large after rounding template <> -bool check_equal_matrix(const int32_t *M, const int32_t *M_ref, oneapi::mkl::layout layout, - int m, int n, int ld, int error_mag, std::ostream &out) { +static bool check_equal_matrix(const int32_t *M, const int32_t *M_ref, + oneapi::mkl::layout layout, int m, int n, int ld, + int error_mag, std::ostream &out) { bool good = true; int idx, 
count = 0; for (int j = 0; j < n; j++) { diff --git a/tests/unit_tests/blas/include/test_common.hpp b/tests/unit_tests/blas/include/test_common.hpp index 6001d65da..c90ffd813 100644 --- a/tests/unit_tests/blas/include/test_common.hpp +++ b/tests/unit_tests/blas/include/test_common.hpp @@ -564,8 +564,8 @@ bool check_equal_trsv_vector(vec1 &v, vec2 &v_ref, int n, int inc, int error_mag } template -bool check_equal_matrix(acc1 &M, acc2 &M_ref, oneapi::mkl::layout layout, int m, int n, int ld, - int error_mag, std::ostream &out) { +static bool check_equal_matrix(acc1 &M, acc2 &M_ref, oneapi::mkl::layout layout, int m, int n, + int ld, int error_mag, std::ostream &out) { bool good = true; int idx, count = 0; for (int j = 0; j < n; j++) { @@ -586,8 +586,8 @@ bool check_equal_matrix(acc1 &M, acc2 &M_ref, oneapi::mkl::layout layout, int m, } template -bool check_equal_matrix(const fp *M, const fp *M_ref, oneapi::mkl::layout layout, int m, int n, - int ld, int error_mag, std::ostream &out) { +static bool check_equal_matrix(const fp *M, const fp *M_ref, oneapi::mkl::layout layout, int m, + int n, int ld, int error_mag, std::ostream &out) { bool good = true; int idx, count = 0; for (int j = 0; j < n; j++) { @@ -608,9 +608,9 @@ bool check_equal_matrix(const fp *M, const fp *M_ref, oneapi::mkl::layout layout } template -bool check_equal_matrix(acc1 &M, acc2 &M_ref, oneapi::mkl::layout layout, - oneapi::mkl::uplo upper_lower, int m, int n, int ld, int error_mag, - std::ostream &out) { +static bool check_equal_matrix(acc1 &M, acc2 &M_ref, oneapi::mkl::layout layout, + oneapi::mkl::uplo upper_lower, int m, int n, int ld, int error_mag, + std::ostream &out) { bool good = true; int idx, count = 0; for (int j = 0; j < n; j++) { From 6588324116cdbe1c765c3f057d7adf20263d9551 Mon Sep 17 00:00:00 2001 From: Aidan Date: Wed, 5 Jun 2024 10:05:19 +0100 Subject: [PATCH 22/30] Format dtype_string --- src/include/dtype_string.hpp | 28 +++++++++++++++++++++------- 1 file changed, 21 
insertions(+), 7 deletions(-) diff --git a/src/include/dtype_string.hpp b/src/include/dtype_string.hpp index 79787b9d0..6f2a87feb 100644 --- a/src/include/dtype_string.hpp +++ b/src/include/dtype_string.hpp @@ -25,18 +25,32 @@ template inline const std::string dtype_string(); template <> -inline const std::string dtype_string() {return "float";} +inline const std::string dtype_string() { + return "float"; +} template <> -inline const std::string dtype_string() {return "double";} +inline const std::string dtype_string() { + return "double"; +} template <> -inline const std::string dtype_string() {return "half";} +inline const std::string dtype_string() { + return "half"; +} template <> -inline const std::string dtype_string>() {return "complex";} +inline const std::string dtype_string>() { + return "complex"; +} template <> -inline const std::string dtype_string>() {return "complex";} +inline const std::string dtype_string>() { + return "complex"; +} template <> -inline const std::string dtype_string() {return "int32";} +inline const std::string dtype_string() { + return "int32"; +} template <> -inline const std::string dtype_string() {return "int8";} +inline const std::string dtype_string() { + return "int8"; +} #endif //_ONEMKL_ERROR_HELPER_HPP_ From 36419060d0bae8db30880d21d9e76e760ac810c3 Mon Sep 17 00:00:00 2001 From: Aidan Date: Wed, 5 Jun 2024 10:18:29 +0100 Subject: [PATCH 23/30] Set gemm_batch int8, float to unimplemented --- src/blas/backends/mkl_common/mkl_batch.cxx | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/blas/backends/mkl_common/mkl_batch.cxx b/src/blas/backends/mkl_common/mkl_batch.cxx index 412d3c990..6358a3922 100644 --- a/src/blas/backends/mkl_common/mkl_batch.cxx +++ b/src/blas/backends/mkl_common/mkl_batch.cxx @@ -196,8 +196,8 @@ void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, float beta, sycl::buffer &c, 
int64_t ldc, int64_t stride_c, int64_t batch_size) { - blas_major::gemm_batch(queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, - stride_b, beta, c, ldc, stride_c, batch_size); + throw unimplemented("blas", "gemm_batch", + "unsupported dtype combination: int8_t, int8_t, float, float"); } void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, @@ -683,8 +683,8 @@ sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, i const std::int8_t *b, int64_t ldb, int64_t stride_b, float beta, float *c, int64_t ldc, int64_t stride_c, int64_t batch_size, const std::vector &dependencies) { - return blas_major::gemm_batch(queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, - stride_b, beta, c, ldc, stride_c, batch_size, dependencies); + throw unimplemented("blas", "gemm_batch", + "unsupported dtype combination: int8_t, int8_t, float, float"); } sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, @@ -757,8 +757,8 @@ sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, const std::int8_t **b, int64_t *ldb, float *beta, float **c, int64_t *ldc, int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - return blas_major::gemm_batch(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, - ldc, group_count, groupsize, dependencies); + throw unimplemented("blas", "gemm_batch", + "unsupported dtype combination: int8_t, int8_t, float, float"); } sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, From 7dee1c114db61a7f85f7f87ff8f1e5952a8bcf06 Mon Sep 17 00:00:00 2001 From: Aidan Date: Wed, 5 Jun 2024 12:11:47 +0100 Subject: [PATCH 24/30] Add exceptions header --- src/blas/backends/mklcpu/mklcpu_batch.cpp | 1 + src/blas/backends/mklgpu/mklgpu_batch.cpp | 1 + 2 files changed, 2 insertions(+) diff --git a/src/blas/backends/mklcpu/mklcpu_batch.cpp 
b/src/blas/backends/mklcpu/mklcpu_batch.cpp index 9dd231629..5ecf4cc69 100644 --- a/src/blas/backends/mklcpu/mklcpu_batch.cpp +++ b/src/blas/backends/mklcpu/mklcpu_batch.cpp @@ -25,6 +25,7 @@ #include "oneapi/mkl/blas/detail/mklcpu/onemkl_blas_mklcpu.hpp" +#include "oneapi/mkl/exceptions.hpp" #include "../mkl_common/mkl_blas_backend.hpp" namespace oneapi { diff --git a/src/blas/backends/mklgpu/mklgpu_batch.cpp b/src/blas/backends/mklgpu/mklgpu_batch.cpp index d859a3b78..bad2db82c 100644 --- a/src/blas/backends/mklgpu/mklgpu_batch.cpp +++ b/src/blas/backends/mklgpu/mklgpu_batch.cpp @@ -25,6 +25,7 @@ #include "oneapi/mkl/blas/detail/mklgpu/onemkl_blas_mklgpu.hpp" +#include "oneapi/mkl/exceptions.hpp" #include "../mkl_common/mkl_blas_backend.hpp" namespace oneapi { From c31b635a4c3a81958f3cd6afdbb8f8fcabec01a1 Mon Sep 17 00:00:00 2001 From: Aidan Date: Fri, 14 Jun 2024 16:07:32 +0100 Subject: [PATCH 25/30] Remove check_equal_matrix specialization --- .../blas/batch/gemm_batch_stride.cpp | 25 ++++++++++++------- .../blas/batch/gemm_batch_stride_usm.cpp | 24 +++++++++++------- .../unit_tests/blas/batch/gemm_batch_usm.cpp | 23 +++++++++++------ 3 files changed, 46 insertions(+), 26 deletions(-) diff --git a/tests/unit_tests/blas/batch/gemm_batch_stride.cpp b/tests/unit_tests/blas/batch/gemm_batch_stride.cpp index 7b48f74a1..f92a64569 100644 --- a/tests/unit_tests/blas/batch/gemm_batch_stride.cpp +++ b/tests/unit_tests/blas/batch/gemm_batch_stride.cpp @@ -45,6 +45,8 @@ using std::vector; extern std::vector devices; +namespace { + template typename std::enable_if::value, bool>::type check_equal_int(fp x, fp x_ref, int error_mag) { @@ -67,11 +69,10 @@ struct vec_type { template using vec_type_t = typename vec_type::type; -// Specialized check for Tc=int32_t and Ts=float as small differences in the reference become large after rounding -template <> -static bool check_equal_matrix, vec_type_t>( - acc_type_t &M, vec_type_t &M_ref, oneapi::mkl::layout layout, int m, int n, - 
int ld, int error_mag, std::ostream &out) { +// Check for int32_t and Ts=float as small differences in the reference become large after rounding +inline bool check_equal_matrix_int(acc_type_t &M, vec_type_t &M_ref, + oneapi::mkl::layout layout, int m, int n, int ld, int error_mag, + std::ostream &out) { bool good = true; int idx, count = 0; for (int j = 0; j < n; j++) { @@ -91,7 +92,13 @@ static bool check_equal_matrix, vec_type_t>( return good; } -namespace { +template +inline bool check_mat(acc_type_t &M, vec_type_t &M_ref, oneapi::mkl::layout layout, int m, + int n, int ld, int error_mag, std::ostream &out) { + if constexpr (std::is_same::value) + return check_equal_matrix_int(M, M_ref, layout, m, n, ld, error_mag, out); + return check_equal_matrix>(M, M_ref, layout, m, n, ld, error_mag, out); +} template int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { @@ -269,9 +276,9 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { C_cast_ref[i] = C_ref[i]; } auto C_accessor = C_buffer.template get_host_access(read_only); - bool good = check_equal_matrix, vec_type_t>( - C_accessor, C_cast_ref, oneapi::mkl::layout::col_major, stride_c * batch_size, 1, - stride_c * batch_size, tol_scalar * k, std::cout); + bool good = + check_mat(C_accessor, C_cast_ref, oneapi::mkl::layout::col_major, stride_c * batch_size, + 1, stride_c * batch_size, tol_scalar * k, std::cout); return (int)good; } diff --git a/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp b/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp index 5f3410683..f6946508d 100644 --- a/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp +++ b/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp @@ -45,6 +45,8 @@ using std::vector; extern std::vector devices; +namespace { + template typename std::enable_if::value, bool>::type check_equal_int(fp x, fp x_ref, int error_mag) { @@ -59,11 +61,10 @@ struct vec_type { template using vec_type_t = typename vec_type::type; -// 
Specialized check for Tc=int32_t and Ts=float as small differences in the reference become large after rounding -template <> -static bool check_equal_matrix, vec_type_t>( - vec_type_t &M, vec_type_t &M_ref, oneapi::mkl::layout layout, int m, int n, - int ld, int error_mag, std::ostream &out) { +// Check for int32_t and Ts=float as small differences in the reference become large after rounding +inline bool check_equal_matrix_int(vec_type_t &M, vec_type_t &M_ref, + oneapi::mkl::layout layout, int m, int n, int ld, int error_mag, + std::ostream &out) { bool good = true; int idx, count = 0; for (int j = 0; j < n; j++) { @@ -83,7 +84,13 @@ static bool check_equal_matrix, vec_type_t>( return good; } -namespace { +template +inline bool check_mat(vec_type_t &M, vec_type_t &M_ref, oneapi::mkl::layout layout, int m, + int n, int ld, int error_mag, std::ostream &out) { + if constexpr (std::is_same::value) + return check_equal_matrix_int(M, M_ref, layout, m, n, ld, error_mag, out); + return check_equal_matrix>(M, M_ref, layout, m, n, ld, error_mag, out); +} template int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { @@ -291,9 +298,8 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { for (size_t i = 0; i < C_ref.size(); ++i) { C_cast_ref[i] = C_ref[i]; } - bool good = check_equal_matrix, vec_type_t>( - C, C_cast_ref, oneapi::mkl::layout::col_major, stride_c * batch_size, 1, - stride_c * batch_size, tol_scalar * k, std::cout); + bool good = check_mat(C, C_cast_ref, oneapi::mkl::layout::col_major, stride_c * batch_size, + 1, stride_c * batch_size, tol_scalar * k, std::cout); oneapi::mkl::free_shared(a_array, cxt); oneapi::mkl::free_shared(b_array, cxt); diff --git a/tests/unit_tests/blas/batch/gemm_batch_usm.cpp b/tests/unit_tests/blas/batch/gemm_batch_usm.cpp index cbe062b42..7ecbf1b77 100644 --- a/tests/unit_tests/blas/batch/gemm_batch_usm.cpp +++ b/tests/unit_tests/blas/batch/gemm_batch_usm.cpp @@ -45,17 +45,18 @@ using std::vector; 
extern std::vector devices; +namespace { + template typename std::enable_if::value, bool>::type check_equal_int(fp x, fp x_ref, int error_mag) { return (std::abs(x - x_ref) <= 1); } -// Specialized check for Tc=int32_t and Ts=float as small differences in the reference become large after rounding -template <> -static bool check_equal_matrix(const int32_t *M, const int32_t *M_ref, - oneapi::mkl::layout layout, int m, int n, int ld, - int error_mag, std::ostream &out) { +// Check for int32_t and Ts=float as small differences in the reference become large after rounding +static bool check_equal_matrix_int(const int32_t *M, const int32_t *M_ref, + oneapi::mkl::layout layout, int m, int n, int ld, int error_mag, + std::ostream &out) { bool good = true; int idx, count = 0; for (int j = 0; j < n; j++) { @@ -75,7 +76,13 @@ static bool check_equal_matrix(const int32_t *M, const int32_t *M_ref, return good; } -namespace { +template +inline bool check_mat(const T *M, const T *M_ref, oneapi::mkl::layout layout, int m, int n, int ld, + int error_mag, std::ostream &out) { + if constexpr (std::is_same::value) + return check_equal_matrix_int(M, M_ref, layout, m, n, ld, error_mag, out); + return check_equal_matrix(M, M_ref, layout, m, n, ld, error_mag, out); +} template int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { @@ -362,8 +369,8 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { for (j = 0; j < group_size[i]; j++) { copy_matrix(c_ref_array[idx], layout, oneapi::mkl::transpose::nontrans, m[i], n[i], ldc[i], c_cast_ref_array[idx]); - good = good && check_equal_matrix(c_array[idx], c_cast_ref_array[idx], layout, m[i], - n[i], ldc[i], tol_scalar * k[i], std::cout); + good = good && check_mat(c_array[idx], c_cast_ref_array[idx], layout, m[i], n[i], + ldc[i], tol_scalar * k[i], std::cout); idx++; } } From 28e098a1af4e442096dff9a5e044aa7679ffb522 Mon Sep 17 00:00:00 2001 From: Aidan Date: Fri, 14 Jun 2024 16:09:40 +0100 Subject: 
[PATCH 26/30] Remove static/inline --- tests/unit_tests/blas/batch/gemm_batch_stride.cpp | 10 +++++----- tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp | 10 +++++----- tests/unit_tests/blas/batch/gemm_batch_usm.cpp | 9 ++++----- 3 files changed, 14 insertions(+), 15 deletions(-) diff --git a/tests/unit_tests/blas/batch/gemm_batch_stride.cpp b/tests/unit_tests/blas/batch/gemm_batch_stride.cpp index f92a64569..5459e086f 100644 --- a/tests/unit_tests/blas/batch/gemm_batch_stride.cpp +++ b/tests/unit_tests/blas/batch/gemm_batch_stride.cpp @@ -70,9 +70,9 @@ template using vec_type_t = typename vec_type::type; // Check for int32_t and Ts=float as small differences in the reference become large after rounding -inline bool check_equal_matrix_int(acc_type_t &M, vec_type_t &M_ref, - oneapi::mkl::layout layout, int m, int n, int ld, int error_mag, - std::ostream &out) { +bool check_equal_matrix_int(acc_type_t &M, vec_type_t &M_ref, + oneapi::mkl::layout layout, int m, int n, int ld, int error_mag, + std::ostream &out) { bool good = true; int idx, count = 0; for (int j = 0; j < n; j++) { @@ -93,8 +93,8 @@ inline bool check_equal_matrix_int(acc_type_t &M, vec_type_t & } template -inline bool check_mat(acc_type_t &M, vec_type_t &M_ref, oneapi::mkl::layout layout, int m, - int n, int ld, int error_mag, std::ostream &out) { +bool check_mat(acc_type_t &M, vec_type_t &M_ref, oneapi::mkl::layout layout, int m, int n, + int ld, int error_mag, std::ostream &out) { if constexpr (std::is_same::value) return check_equal_matrix_int(M, M_ref, layout, m, n, ld, error_mag, out); return check_equal_matrix>(M, M_ref, layout, m, n, ld, error_mag, out); diff --git a/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp b/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp index f6946508d..e9e57904a 100644 --- a/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp +++ b/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp @@ -62,9 +62,9 @@ template using vec_type_t = typename 
vec_type::type; // Check for int32_t and Ts=float as small differences in the reference become large after rounding -inline bool check_equal_matrix_int(vec_type_t &M, vec_type_t &M_ref, - oneapi::mkl::layout layout, int m, int n, int ld, int error_mag, - std::ostream &out) { +bool check_equal_matrix_int(vec_type_t &M, vec_type_t &M_ref, + oneapi::mkl::layout layout, int m, int n, int ld, int error_mag, + std::ostream &out) { bool good = true; int idx, count = 0; for (int j = 0; j < n; j++) { @@ -85,8 +85,8 @@ inline bool check_equal_matrix_int(vec_type_t &M, vec_type_t & } template -inline bool check_mat(vec_type_t &M, vec_type_t &M_ref, oneapi::mkl::layout layout, int m, - int n, int ld, int error_mag, std::ostream &out) { +bool check_mat(vec_type_t &M, vec_type_t &M_ref, oneapi::mkl::layout layout, int m, int n, + int ld, int error_mag, std::ostream &out) { if constexpr (std::is_same::value) return check_equal_matrix_int(M, M_ref, layout, m, n, ld, error_mag, out); return check_equal_matrix>(M, M_ref, layout, m, n, ld, error_mag, out); diff --git a/tests/unit_tests/blas/batch/gemm_batch_usm.cpp b/tests/unit_tests/blas/batch/gemm_batch_usm.cpp index 7ecbf1b77..d2ffac186 100644 --- a/tests/unit_tests/blas/batch/gemm_batch_usm.cpp +++ b/tests/unit_tests/blas/batch/gemm_batch_usm.cpp @@ -54,9 +54,8 @@ typename std::enable_if::value, bool>::type check_equal_int } // Check for int32_t and Ts=float as small differences in the reference become large after rounding -static bool check_equal_matrix_int(const int32_t *M, const int32_t *M_ref, - oneapi::mkl::layout layout, int m, int n, int ld, int error_mag, - std::ostream &out) { +bool check_equal_matrix_int(const int32_t *M, const int32_t *M_ref, oneapi::mkl::layout layout, + int m, int n, int ld, int error_mag, std::ostream &out) { bool good = true; int idx, count = 0; for (int j = 0; j < n; j++) { @@ -77,8 +76,8 @@ static bool check_equal_matrix_int(const int32_t *M, const int32_t *M_ref, } template -inline bool 
check_mat(const T *M, const T *M_ref, oneapi::mkl::layout layout, int m, int n, int ld, - int error_mag, std::ostream &out) { +bool check_mat(const T *M, const T *M_ref, oneapi::mkl::layout layout, int m, int n, int ld, + int error_mag, std::ostream &out) { if constexpr (std::is_same::value) return check_equal_matrix_int(M, M_ref, layout, m, n, ld, error_mag, out); return check_equal_matrix(M, M_ref, layout, m, n, ld, error_mag, out); From b7a650f7f3dcca8c423b3018e19924a7232589af Mon Sep 17 00:00:00 2001 From: Aidan Date: Tue, 18 Jun 2024 16:47:13 +0100 Subject: [PATCH 27/30] refactor almost_equal_matrix check --- .../blas/batch/gemm_batch_stride.cpp | 59 +------------------ .../blas/batch/gemm_batch_stride_usm.cpp | 50 +--------------- .../unit_tests/blas/batch/gemm_batch_usm.cpp | 41 +------------ tests/unit_tests/blas/include/test_common.hpp | 53 +++++++++++++++++ 4 files changed, 62 insertions(+), 141 deletions(-) diff --git a/tests/unit_tests/blas/batch/gemm_batch_stride.cpp b/tests/unit_tests/blas/batch/gemm_batch_stride.cpp index 5459e086f..92b1a8d3e 100644 --- a/tests/unit_tests/blas/batch/gemm_batch_stride.cpp +++ b/tests/unit_tests/blas/batch/gemm_batch_stride.cpp @@ -47,59 +47,6 @@ extern std::vector devices; namespace { -template -typename std::enable_if::value, bool>::type check_equal_int(fp x, fp x_ref, - int error_mag) { - return (std::abs(x - x_ref) <= 1); -} - -template -struct acc_type { - typedef host_accessor type; -}; - -template -using acc_type_t = typename acc_type::type; - -template -struct vec_type { - typedef vector> type; -}; - -template -using vec_type_t = typename vec_type::type; - -// Check for int32_t and Ts=float as small differences in the reference become large after rounding -bool check_equal_matrix_int(acc_type_t &M, vec_type_t &M_ref, - oneapi::mkl::layout layout, int m, int n, int ld, int error_mag, - std::ostream &out) { - bool good = true; - int idx, count = 0; - for (int j = 0; j < n; j++) { - for (int i = 0; i < m; i++) { - 
idx = (layout == oneapi::mkl::layout::col_major) ? i + j * ld : j + i * ld; - if (!check_equal_int(M[idx], M_ref[idx], error_mag)) { - out << "Difference in entry (" << i << ',' << j << "): DPC++ " << M[idx] - << " vs. Reference " << M_ref[idx] << std::endl; - good = false; - count++; - if (count > MAX_NUM_PRINT) - return good; - } - } - } - - return good; -} - -template -bool check_mat(acc_type_t &M, vec_type_t &M_ref, oneapi::mkl::layout layout, int m, int n, - int ld, int error_mag, std::ostream &out) { - if constexpr (std::is_same::value) - return check_equal_matrix_int(M, M_ref, layout, m, n, ld, error_mag, out); - return check_equal_matrix>(M, M_ref, layout, m, n, ld, error_mag, out); -} - template int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { // Prepare data. @@ -276,9 +223,9 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { C_cast_ref[i] = C_ref[i]; } auto C_accessor = C_buffer.template get_host_access(read_only); - bool good = - check_mat(C_accessor, C_cast_ref, oneapi::mkl::layout::col_major, stride_c * batch_size, - 1, stride_c * batch_size, tol_scalar * k, std::cout); + bool good = check_almost_equal_matrix(C_accessor, C_cast_ref, oneapi::mkl::layout::col_major, + stride_c * batch_size, 1, stride_c * batch_size, + tol_scalar * k, std::cout); return (int)good; } diff --git a/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp b/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp index e9e57904a..27a32ed3a 100644 --- a/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp +++ b/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp @@ -47,51 +47,6 @@ extern std::vector devices; namespace { -template -typename std::enable_if::value, bool>::type check_equal_int(fp x, fp x_ref, - int error_mag) { - return (std::abs(x - x_ref) <= 1); -} - -template -struct vec_type { - typedef vector> type; -}; - -template -using vec_type_t = typename vec_type::type; - -// Check for int32_t and Ts=float as small 
differences in the reference become large after rounding -bool check_equal_matrix_int(vec_type_t &M, vec_type_t &M_ref, - oneapi::mkl::layout layout, int m, int n, int ld, int error_mag, - std::ostream &out) { - bool good = true; - int idx, count = 0; - for (int j = 0; j < n; j++) { - for (int i = 0; i < m; i++) { - idx = (layout == oneapi::mkl::layout::col_major) ? i + j * ld : j + i * ld; - if (!check_equal_int(M[idx], M_ref[idx], error_mag)) { - out << "Difference in entry (" << i << ',' << j << "): DPC++ " << M[idx] - << " vs. Reference " << M_ref[idx] << std::endl; - good = false; - count++; - if (count > MAX_NUM_PRINT) - return good; - } - } - } - - return good; -} - -template -bool check_mat(vec_type_t &M, vec_type_t &M_ref, oneapi::mkl::layout layout, int m, int n, - int ld, int error_mag, std::ostream &out) { - if constexpr (std::is_same::value) - return check_equal_matrix_int(M, M_ref, layout, m, n, ld, error_mag, out); - return check_equal_matrix>(M, M_ref, layout, m, n, ld, error_mag, out); -} - template int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { // Catch asynchronous exceptions. 
@@ -298,8 +253,9 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { for (size_t i = 0; i < C_ref.size(); ++i) { C_cast_ref[i] = C_ref[i]; } - bool good = check_mat(C, C_cast_ref, oneapi::mkl::layout::col_major, stride_c * batch_size, - 1, stride_c * batch_size, tol_scalar * k, std::cout); + bool good = check_almost_equal_matrix(C, C_cast_ref, oneapi::mkl::layout::col_major, + stride_c * batch_size, 1, stride_c * batch_size, + tol_scalar * k, std::cout); oneapi::mkl::free_shared(a_array, cxt); oneapi::mkl::free_shared(b_array, cxt); diff --git a/tests/unit_tests/blas/batch/gemm_batch_usm.cpp b/tests/unit_tests/blas/batch/gemm_batch_usm.cpp index d2ffac186..37dcb8ae6 100644 --- a/tests/unit_tests/blas/batch/gemm_batch_usm.cpp +++ b/tests/unit_tests/blas/batch/gemm_batch_usm.cpp @@ -47,42 +47,6 @@ extern std::vector devices; namespace { -template -typename std::enable_if::value, bool>::type check_equal_int(fp x, fp x_ref, - int error_mag) { - return (std::abs(x - x_ref) <= 1); -} - -// Check for int32_t and Ts=float as small differences in the reference become large after rounding -bool check_equal_matrix_int(const int32_t *M, const int32_t *M_ref, oneapi::mkl::layout layout, - int m, int n, int ld, int error_mag, std::ostream &out) { - bool good = true; - int idx, count = 0; - for (int j = 0; j < n; j++) { - for (int i = 0; i < m; i++) { - idx = (layout == oneapi::mkl::layout::col_major) ? i + j * ld : j + i * ld; - if (!check_equal_int(M[idx], M_ref[idx], error_mag)) { - out << "Difference in entry (" << i << ',' << j << "): DPC++ " << M[idx] - << " vs. 
Reference " << M_ref[idx] << std::endl; - good = false; - count++; - if (count > MAX_NUM_PRINT) - return good; - } - } - } - - return good; -} - -template -bool check_mat(const T *M, const T *M_ref, oneapi::mkl::layout layout, int m, int n, int ld, - int error_mag, std::ostream &out) { - if constexpr (std::is_same::value) - return check_equal_matrix_int(M, M_ref, layout, m, n, ld, error_mag, out); - return check_equal_matrix(M, M_ref, layout, m, n, ld, error_mag, out); -} - template int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { // Catch asynchronous exceptions. @@ -368,8 +332,9 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { for (j = 0; j < group_size[i]; j++) { copy_matrix(c_ref_array[idx], layout, oneapi::mkl::transpose::nontrans, m[i], n[i], ldc[i], c_cast_ref_array[idx]); - good = good && check_mat(c_array[idx], c_cast_ref_array[idx], layout, m[i], n[i], - ldc[i], tol_scalar * k[i], std::cout); + good = + good && check_almost_equal_matrix(c_array[idx], c_cast_ref_array[idx], layout, m[i], + n[i], ldc[i], tol_scalar * k[i], std::cout); idx++; } } diff --git a/tests/unit_tests/blas/include/test_common.hpp b/tests/unit_tests/blas/include/test_common.hpp index c90ffd813..9ad67ed5a 100644 --- a/tests/unit_tests/blas/include/test_common.hpp +++ b/tests/unit_tests/blas/include/test_common.hpp @@ -655,4 +655,57 @@ bool check_equal_trsm_matrix(acc1 &M, acc2 &M_ref, oneapi::mkl::layout layout, i return good; } +// Helper for using std::result_of for evalutation operator[] return type +template +struct access_index { + auto operator()(int i) { + return M[i]; + } + T *M[0]; +}; + +// Helper for checking if a matrix/vector/accessor structure returns an integral type +template +constexpr bool is_matrix_type_integral() { + return std::is_integral_v(int)>::type>; +} + +template +typename std::enable_if::value, bool>::type check_almost_equal_int( + fp x, fp x_ref, int error_mag) { + return (std::abs(x - x_ref) <= 1); +} + 
+template +bool check_almost_equal_matrix_int(Ta &M, Tb &M_ref, oneapi::mkl::layout layout, int m, int n, + int ld, int error_mag, std::ostream &out) { + static_assert(is_matrix_type_integral() && is_matrix_type_integral()); + bool good = true; + int idx, count = 0; + for (int j = 0; j < n; j++) { + for (int i = 0; i < m; i++) { + idx = (layout == oneapi::mkl::layout::col_major) ? i + j * ld : j + i * ld; + if (!check_almost_equal_int(M[idx], M_ref[idx], error_mag)) { + out << "Difference in entry (" << i << ',' << j << "): DPC++ " << M[idx] + << " vs. Reference " << M_ref[idx] << std::endl; + good = false; + count++; + if (count > MAX_NUM_PRINT) + return good; + } + } + } + + return good; +} + +template +bool check_almost_equal_matrix(Ta &M, Tb &M_ref, oneapi::mkl::layout layout, int m, int n, int ld, + int error_mag, std::ostream &out) { + // Only call if returned dtype is integral + if constexpr (is_matrix_type_integral() && is_matrix_type_integral()) + return check_almost_equal_matrix_int(M, M_ref, layout, m, n, ld, error_mag, out); + return check_equal_matrix(M, M_ref, layout, m, n, ld, error_mag, out); +} + #endif /* header guard */ From a7b21c09214ac25159d46aadf2f9810fc9c418fe Mon Sep 17 00:00:00 2001 From: Aidan Date: Tue, 18 Jun 2024 18:23:34 +0100 Subject: [PATCH 28/30] Remove static --- tests/unit_tests/blas/include/test_common.hpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/unit_tests/blas/include/test_common.hpp b/tests/unit_tests/blas/include/test_common.hpp index 9ad67ed5a..7358e4361 100644 --- a/tests/unit_tests/blas/include/test_common.hpp +++ b/tests/unit_tests/blas/include/test_common.hpp @@ -564,8 +564,8 @@ bool check_equal_trsv_vector(vec1 &v, vec2 &v_ref, int n, int inc, int error_mag } template -static bool check_equal_matrix(acc1 &M, acc2 &M_ref, oneapi::mkl::layout layout, int m, int n, - int ld, int error_mag, std::ostream &out) { +bool check_equal_matrix(acc1 &M, acc2 &M_ref, oneapi::mkl::layout 
layout, int m, int n, int ld, + int error_mag, std::ostream &out) { bool good = true; int idx, count = 0; for (int j = 0; j < n; j++) { @@ -586,8 +586,8 @@ static bool check_equal_matrix(acc1 &M, acc2 &M_ref, oneapi::mkl::layout layout, } template -static bool check_equal_matrix(const fp *M, const fp *M_ref, oneapi::mkl::layout layout, int m, - int n, int ld, int error_mag, std::ostream &out) { +bool check_equal_matrix(const fp *M, const fp *M_ref, oneapi::mkl::layout layout, int m, int n, + int ld, int error_mag, std::ostream &out) { bool good = true; int idx, count = 0; for (int j = 0; j < n; j++) { @@ -608,9 +608,9 @@ static bool check_equal_matrix(const fp *M, const fp *M_ref, oneapi::mkl::layout } template -static bool check_equal_matrix(acc1 &M, acc2 &M_ref, oneapi::mkl::layout layout, - oneapi::mkl::uplo upper_lower, int m, int n, int ld, int error_mag, - std::ostream &out) { +bool check_equal_matrix(acc1 &M, acc2 &M_ref, oneapi::mkl::layout layout, + oneapi::mkl::uplo upper_lower, int m, int n, int ld, int error_mag, + std::ostream &out) { bool good = true; int idx, count = 0; for (int j = 0; j < n; j++) { From ead997f04fea603e3c406399ecf4afb6f7308bd3 Mon Sep 17 00:00:00 2001 From: Aidan Date: Thu, 20 Jun 2024 11:33:16 +0100 Subject: [PATCH 29/30] Fix int32_t check --- tests/unit_tests/blas/include/test_common.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/unit_tests/blas/include/test_common.hpp b/tests/unit_tests/blas/include/test_common.hpp index 7358e4361..594e97e7f 100644 --- a/tests/unit_tests/blas/include/test_common.hpp +++ b/tests/unit_tests/blas/include/test_common.hpp @@ -658,16 +658,16 @@ bool check_equal_trsm_matrix(acc1 &M, acc2 &M_ref, oneapi::mkl::layout layout, i // Helper for using std::result_of for evalutation operator[] return type template struct access_index { - auto operator()(int i) { - return M[i]; + auto operator()(T M) { + return M[0]; } - T *M[0]; }; // Helper for checking if a 
matrix/vector/accessor structure returns an integral type template constexpr bool is_matrix_type_integral() { - return std::is_integral_v(int)>::type>; + return std::is_integral_v< + std::remove_reference_t(T)>::type>>; } template From 164ed6eb0b4568b506181a1b1e3f1df5805d5092 Mon Sep 17 00:00:00 2001 From: Aidan Date: Thu, 20 Jun 2024 15:14:13 +0100 Subject: [PATCH 30/30] Pass error_mag through all functions --- tests/unit_tests/blas/batch/gemm_batch_stride.cpp | 9 +++++---- .../unit_tests/blas/batch/gemm_batch_stride_usm.cpp | 9 +++++---- tests/unit_tests/blas/batch/gemm_batch_usm.cpp | 13 ++++++------- tests/unit_tests/blas/include/test_common.hpp | 2 +- 4 files changed, 17 insertions(+), 16 deletions(-) diff --git a/tests/unit_tests/blas/batch/gemm_batch_stride.cpp b/tests/unit_tests/blas/batch/gemm_batch_stride.cpp index 92b1a8d3e..12af18ec9 100644 --- a/tests/unit_tests/blas/batch/gemm_batch_stride.cpp +++ b/tests/unit_tests/blas/batch/gemm_batch_stride.cpp @@ -215,9 +215,10 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { } // Compare the results of reference implementation and DPC++ implementation. - int tol_scalar = std::is_same_v ? 
10 : 60; - if (main_queue.get_device().is_cpu()) - tol_scalar = 100; + int tol_scalar = 10; + int error_mag = tol_scalar * k; + if (std::is_same_v) + error_mag = 1; for (size_t i = 0; i < C_ref.size(); ++i) { C_cast_ref[i] = C_ref[i]; @@ -225,7 +226,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { auto C_accessor = C_buffer.template get_host_access(read_only); bool good = check_almost_equal_matrix(C_accessor, C_cast_ref, oneapi::mkl::layout::col_major, stride_c * batch_size, 1, stride_c * batch_size, - tol_scalar * k, std::cout); + error_mag, std::cout); return (int)good; } diff --git a/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp b/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp index 27a32ed3a..97f2dd086 100644 --- a/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp +++ b/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp @@ -246,16 +246,17 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { } // Compare the results of reference implementation and DPC++ implementation. - int tol_scalar = std::is_same_v ? 
10 : 60; - if (main_queue.get_device().is_cpu()) - tol_scalar = 100; + int tol_scalar = 10; + int error_mag = tol_scalar * k; + if (std::is_same_v) + error_mag = 1; for (size_t i = 0; i < C_ref.size(); ++i) { C_cast_ref[i] = C_ref[i]; } bool good = check_almost_equal_matrix(C, C_cast_ref, oneapi::mkl::layout::col_major, stride_c * batch_size, 1, stride_c * batch_size, - tol_scalar * k, std::cout); + error_mag, std::cout); oneapi::mkl::free_shared(a_array, cxt); oneapi::mkl::free_shared(b_array, cxt); diff --git a/tests/unit_tests/blas/batch/gemm_batch_usm.cpp b/tests/unit_tests/blas/batch/gemm_batch_usm.cpp index 37dcb8ae6..a651f9ae3 100644 --- a/tests/unit_tests/blas/batch/gemm_batch_usm.cpp +++ b/tests/unit_tests/blas/batch/gemm_batch_usm.cpp @@ -322,19 +322,18 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { bool good = true; // Compare the results of reference implementation and DPC++ implementation. int tol_scalar = 10; - // Scale the tolerance for when we generate int8_t, as input range is [-128, 127] - // rather than [-1,1] - if (std::is_same_v && std::is_same_v) - tol_scalar *= 256; idx = 0; for (i = 0; i < group_count; i++) { for (j = 0; j < group_size[i]; j++) { + int error_mag = tol_scalar * k[i]; + if (std::is_same_v) + error_mag = 1; + copy_matrix(c_ref_array[idx], layout, oneapi::mkl::transpose::nontrans, m[i], n[i], ldc[i], c_cast_ref_array[idx]); - good = - good && check_almost_equal_matrix(c_array[idx], c_cast_ref_array[idx], layout, m[i], - n[i], ldc[i], tol_scalar * k[i], std::cout); + good = good && check_almost_equal_matrix(c_array[idx], c_cast_ref_array[idx], layout, + m[i], n[i], ldc[i], error_mag, std::cout); idx++; } } diff --git a/tests/unit_tests/blas/include/test_common.hpp b/tests/unit_tests/blas/include/test_common.hpp index 594e97e7f..5d607991e 100644 --- a/tests/unit_tests/blas/include/test_common.hpp +++ b/tests/unit_tests/blas/include/test_common.hpp @@ -673,7 +673,7 @@ constexpr bool 
is_matrix_type_integral() { template typename std::enable_if::value, bool>::type check_almost_equal_int( fp x, fp x_ref, int error_mag) { - return (std::abs(x - x_ref) <= 1); + return (std::abs(x - x_ref) <= error_mag); } template