From e0e23d04fbcf8df8023666662a3596fb126af468 Mon Sep 17 00:00:00 2001 From: "romain.biessy" Date: Tue, 1 Oct 2024 13:35:49 +0100 Subject: [PATCH] Add support for double gemm_batch --- src/blas/backends/portblas/portblas_batch.cxx | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/blas/backends/portblas/portblas_batch.cxx b/src/blas/backends/portblas/portblas_batch.cxx index 29b31fe5b..1e11e8624 100644 --- a/src/blas/backends/portblas/portblas_batch.cxx +++ b/src/blas/backends/portblas/portblas_batch.cxx @@ -710,7 +710,12 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, const double **b, std::int64_t *ldb, double *beta, double **c, std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemm_batch", " for USM using double"); + if (group_count != 1) { + throw unimplemented("blas", "gemm_batch", " using group API and group_count != 1"); + } + CALL_PORTBLAS_USM_FN(::blas::_gemm_batched, queue, transa[0], transb[0], m[0], n[0], k[0], + alpha[0], a[0], lda[0], b[0], ldb[0], beta[0], c[0], ldc[0], group_size[0], + ::blas::gemm_batch_type_t::strided, dependencies); } sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa,