diff --git a/examples/spmm/spmm.cc b/examples/spmm/spmm.cc index 51fb39d5f..d9f8f4ecb 100644 --- a/examples/spmm/spmm.cc +++ b/examples/spmm/spmm.cc @@ -31,8 +31,6 @@ #include "ttg.h" #include "../ttg_matrix.h" -using namespace ttg; - #include "ttg/util/future.h" #include "ttg/util/multiindex.h" @@ -40,6 +38,11 @@ using namespace ttg; #include "ttg/util/bug.h" +#include "devicetensor.h" +#include "devicegemm.h" + +using namespace ttg; + #if defined(TTG_ENABLE_CUDA) #define HAVE_SPMM_DEVICE 1 static constexpr ttg::ExecutionSpace space = ttg::ExecutionSpace::CUDA; @@ -572,10 +575,6 @@ class SpMM25D { ttg::typelist> { static constexpr const bool is_device_space = (Space_ != ttg::ExecutionSpace::Host); using task_return_type = std::conditional_t; - /* communicate to the runtime which device we support (if any) */ - static constexpr bool have_cuda_op = (Space_ == ttg::ExecutionSpace::CUDA); - static constexpr bool have_hip_op = (Space_ == ttg::ExecutionSpace::HIP); - static constexpr bool have_level_zero_op = (Space_ == ttg::ExecutionSpace::L0); void release_next_k(long k) { assert(k_cnt_.size() > k); @@ -597,6 +596,11 @@ class SpMM25D { public: using baseT = typename MultiplyAdd::ttT; + /* communicate to the runtime which device we support (if any) */ + static constexpr bool have_cuda_op = (Space_ == ttg::ExecutionSpace::CUDA); + static constexpr bool have_hip_op = (Space_ == ttg::ExecutionSpace::HIP); + static constexpr bool have_level_zero_op = (Space_ == ttg::ExecutionSpace::L0); + MultiplyAdd(Edge, Blk> &a_ijk, Edge, Blk> &b_ijk, Edge, Blk> &c_ijk, Edge, Blk> &c, const std::vector> &a_cols_of_row, const std::vector> &b_rows_of_col, const std::vector &mTiles,