diff --git a/examples/spmm/spmm_cuda.cc b/examples/spmm/spmm_cuda.cc index 60fb969fa..2c0dbffc5 100644 --- a/examples/spmm/spmm_cuda.cc +++ b/examples/spmm/spmm_cuda.cc @@ -49,6 +49,11 @@ using namespace ttg; #include "ttg/serialization/std/pair.h" +#if defined(TTG_HAVE_LEVEL_ZERO) +#include +#include +#endif + #if defined(BLOCK_SPARSE_GEMM) && defined(BTAS_IS_USABLE) template @@ -249,7 +254,7 @@ struct DeviceTensor : public ttg::TTValue> }; using scalar_t = double; -#if defined(TTG_HAVE_CUDA) || defined(TTG_HAVE_HIPBLAS) +#if defined(TTG_HAVE_CUDA) || defined(TTG_HAVE_HIPBLAS) || defined(TTG_HAVE_LEVEL_ZERO) using blk_t = DeviceTensor>, btas::Handle::shared_ptr>>; @@ -306,7 +311,36 @@ static void device_gemm(Blk &C, const Blk &A, const Blk &B) { B.b.current_device_ptr(), B.extent(0), &beta, C.b.current_device_ptr(), C.extent(0)); } - +#elif defined(TTG_HAVE_LEVEL_ZERO) + +#if defined(DEBUG_SYNCHRONOUS) + try { +#endif /* DEBUG_SYNCHRONOUS */ + cl::sycl::event gemm_event; + gemm_event = oneapi::mkl::blas::gemm(lz_queue(), + oneapi::mkl::transpose::N, oneapi::mkl::transpose::N, + C.extent(0), C.extent(1), A.extent(1), + alpha, A.b.current_device_ptr(), A.extent(0), + B.b.current_device_ptr(), B.extent(0), + beta, C.b.current_device_ptr(), C.extent(0)); +#if defined(DEBUG_SYNCHRONOUS) + gemm_event.wait(); + } catch (const oneapi::mkl::invalid_argument &e) { + std::cerr << "OneAPI MKL BLAS GEMM throws invalid argument exception" << std::endl; + } catch (const oneapi::mkl::unsupported_device &e) { + std::cerr << "OneAPI MKL BLAS GEMM throws unsuported device exception" << std::endl; + } catch (const oneapi::mkl::host_bad_alloc &e) { + std::cerr << "OneAPI MKL BLAS GEMM throws host bad allocation exception" << std::endl; + } catch (const oneapi::mkl::device_bad_alloc &e) { + std::cerr << "OneAPI MKL BLAS GEMM throws device bad allocation exception" << std::endl; + } catch (const oneapi::mkl::unimplemented &e) { + std::cerr << "OneAPI MKL BLAS GEMM throws unimplemented exception" << std::endl; + } catch (const std::exception& e) { + std::cerr << "OneAPI MKL BLAS GEMM throws unexpected exception" << std::endl; + } catch (...) { + std::cerr << "OneAPI MKL BLAS GEMM throws unexpected exception that is also badly formatted..." << std::endl; + } +#endif /* DEBUG_SYNCHRONOUS */ #endif } @@ -734,6 +768,9 @@ class SpMM25D { #elif defined(TTG_HAVE_HIPBLAS) static constexpr bool have_hip_op = true; #warning SPMM using HIP implementation +#elif defined(TTG_HAVE_LEVEL_ZERO) + static constexpr bool have_level_zero_op = true; +#warning SPMM using LEVEL_ZERO implementation #else #error No valid device implementation found! #endif