Skip to content

Commit

Permalink
Merge branch 'ttg-device-support-master-coro-with-stream-tasks' of gi…
Browse files Browse the repository at this point in the history
…thub.com:devreal/ttg into ttg-device-support-master-coro-with-stream-tasks
  • Loading branch information
devreal committed Nov 16, 2023
2 parents e20758a + e9271ff commit af6f684
Showing 1 changed file with 39 additions and 2 deletions.
41 changes: 39 additions & 2 deletions examples/spmm/spmm_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ using namespace ttg;

#include "ttg/serialization/std/pair.h"

#if defined(TTG_HAVE_LEVEL_ZERO)
#include <oneapi/mkl.hpp>
#include <sys/time.h>
#endif

#if defined(BLOCK_SPARSE_GEMM) && defined(BTAS_IS_USABLE)

template <typename _T, class _Range, class _Storage>
Expand Down Expand Up @@ -249,7 +254,7 @@ struct DeviceTensor : public ttg::TTValue<DeviceTensor<_T, _Range, _Storage>>
};

using scalar_t = double;
#if defined(TTG_HAVE_CUDA) || defined(TTG_HAVE_HIPBLAS)
#if defined(TTG_HAVE_CUDA) || defined(TTG_HAVE_HIPBLAS) || defined(TTG_HAVE_LEVEL_ZERO)
using blk_t = DeviceTensor<scalar_t, btas::DEFAULT::range,
btas::mohndle<btas::varray<scalar_t, TiledArray::device_pinned_allocator<scalar_t>>,
btas::Handle::shared_ptr>>;
Expand Down Expand Up @@ -306,7 +311,36 @@ static void device_gemm(Blk &C, const Blk &A, const Blk &B) {
B.b.current_device_ptr(), B.extent(0), &beta,
C.b.current_device_ptr(), C.extent(0));
}

#elif defined(TTG_HAVE_LEVEL_ZERO)

#if defined(DEBUG_SYNCHRONOUS)
try {
#endif /* DEBUG_SYNCHRONOUS */
cl::sycl::event gemm_event;
gemm_event = oneapi::mkl::blas::gemm(lz_queue(),
oneapi::mkl::transpose::N, oneapi::mkl::transpose::N,
C.extent(0), C.extent(1), A.extent(1),
alpha, A.b.current_device_ptr(), A.extent(0),
B.b.current_device_ptr(), B.extent(0),
beta, C.b.current_device_ptr(), C.extent(0));
#if defined(DEBUG_SYNCHRONOUS)
gemm_event.wait();
} catch (const oneapi::mkl::invalid_argument &e) {
std::cerr << "OneAPI MKL BLAS GEMM throws invalid argument exception" << std::endl;
} catch (const oneapi::mkl::unsupported_device &e) {
std::cerr << "OneAPI MKL BLAS GEMM throws unsuported device exception" << std::endl;
} catch (const oneapi::mkl::host_bad_alloc &e) {
std::cerr << "OneAPI MKL BLAS GEMM throws host bad allocation exception" << std::endl;
} catch (const oneapi::mkl::device_bad_alloc &e) {
std::cerr << "OneAPI MKL BLAS GEMM throws device bad allocation exception" << std::endl;
} catch (const oneapi::mkl::unimplemented &e) {
std::cerr << "OneAPI MKL BLAS GEMM throws unimplemented exception" << std::endl;
} catch (const std::exception& e) {
std::cerr << "OneAPI MKL BLAS GEMM throws unexpected exception" << std::endl;
} catch (...) {
std::cerr << "OneAPI MKL BLAS GEMM throws unexpected exception that is also badly formatted..." << std::endl;
}
#endif /* DEBUG_SYNCHRONOUS */
#endif
}

Expand Down Expand Up @@ -734,6 +768,9 @@ class SpMM25D {
#elif defined(TTG_HAVE_HIPBLAS)
static constexpr bool have_hip_op = true;
#warning SPMM using HIP implementation
#elif defined(TTG_HAVE_LEVEL_ZERO)
static constexpr bool have_level_zero_op = true;
#warning SPMM using LEVEL_ZERO implementation
#else
#error No valid device implementation found!
#endif
Expand Down

0 comments on commit af6f684

Please sign in to comment.