Skip to content

Commit

Permalink
SPMM: Only use DeviceTensor for device execution
Browse files Browse the repository at this point in the history
We cannot fully serialize the DeviceTensor and we don't need
the DeviceTensor in host execution.

Signed-off-by: Joseph Schuchart <[email protected]>
  • Loading branch information
devreal committed Oct 29, 2024
1 parent c3df91c commit 5d06eec
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions examples/spmm/spmm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,15 @@ static constexpr ttg::ExecutionSpace space = ttg::ExecutionSpace::Host;
#if defined(BLOCK_SPARSE_GEMM) && defined(BTAS_IS_USABLE)
using scalar_t = double;

using blk_t = DeviceTensor<scalar_t, btas::DEFAULT::range,
btas::mohndle<btas::varray<scalar_t
#if HAVE_SPMM_DEVICE
, TiledArray::device_pinned_allocator<scalar_t>
#endif // HAVE_SPMM_DEVICE
>,
using blk_t = DeviceTensor<scalar_t, btas::DEFAULT::range,
btas::mohndle<btas::varray<scalar_t,
TiledArray::device_pinned_allocator<scalar_t>>,
btas::Handle::shared_ptr>>;
//#include <atomic>
//static std::atomic<uint64_t> reduce_count = 0;
#else // HAVE_SPMM_DEVICE
using blk_t = btas::Tensor<scalar_t, btas::DEFAULT::range, btas::mohndle<btas::varray<scalar_t>, btas::Handle::shared_ptr>>;
#endif // HAVE_SPMM_DEVICE


#if defined(TTG_USE_PARSEC)
namespace ttg {
Expand Down Expand Up @@ -1586,7 +1586,7 @@ int main(int argc, char **argv) {
initialize(1, argv, cores);
}

#ifdef BTAS_IS_USABLE
#if defined(BTAS_IS_USABLE) && defined(TTG_PARSEC_IMPORTED)
// initialize MADNESS so that TA allocators can be created
madness::ParsecRuntime::initialize_with_existing_context(ttg::default_execution_context().impl().context());
madness::initialize(argc, argv, /* nthread = */ 1, /* quiet = */ true);
Expand Down

0 comments on commit 5d06eec

Please sign in to comment.