Skip to content

Commit

Permalink
Add ttg::device::Stream encapsulating a lower-level stream object
Browse files Browse the repository at this point in the history
Signed-off-by: Joseph Schuchart <[email protected]>
  • Loading branch information
devreal committed Oct 14, 2024
1 parent 874e769 commit ba03182
Showing 1 changed file with 49 additions and 66 deletions.
115 changes: 49 additions & 66 deletions ttg/ttg/device/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,12 @@
#include "ttg/execution.h"
#include "ttg/impl_selector.h"
#include "ttg/fwd.h"
#include "ttg/util/meta.h"



namespace ttg::device {

#if defined(TTG_HAVE_CUDA)
constexpr ttg::ExecutionSpace available_execution_space = ttg::ExecutionSpace::CUDA;
#elif defined(TTG_HAVE_HIP)
constexpr ttg::ExecutionSpace available_execution_space = ttg::ExecutionSpace::HIP;
#elif defined(TTG_HAVE_LEVEL_ZERO)
constexpr ttg::ExecutionSpace available_execution_space = ttg::ExecutionSpace::L0;
#else
constexpr ttg::ExecutionSpace available_execution_space = ttg::ExecutionSpace::Invalid;
#endif

/// Represents a device in a specific execution space
class Device {
int m_id = 0;
Expand Down Expand Up @@ -74,52 +65,64 @@ namespace std {
}
} // namespace std

#if defined(TTG_HAVE_CUDA)
#include <cuda_runtime.h>

namespace ttg::device {
namespace detail {
inline thread_local ttg::device::Device current_device_ts = {};
inline thread_local cudaStream_t current_stream_ts = 0; // default stream

inline void reset_current() {
current_device_ts = {};
current_stream_ts = 0;
}

inline void set_current(int device, cudaStream_t stream) {
current_device_ts = ttg::device::Device(device, ttg::ExecutionSpace::CUDA);
current_stream_ts = stream;
}
namespace detail {
template<typename Stream>
struct default_stream {
static constexpr const Stream value = 0;
};
template<typename Stream>
constexpr const Stream default_stream_v = default_stream<Stream>::value;
} // namespace detail

inline
Device current_device() {
return detail::current_device_ts;
}

inline
cudaStream_t current_stream() {
return detail::current_stream_ts;
}
} // namespace ttg

#if defined(TTG_HAVE_CUDA)
#include <cuda_runtime.h>
namespace ttg::device {
constexpr ttg::ExecutionSpace available_execution_space = ttg::ExecutionSpace::CUDA;
using Stream = cudaStream_t;
} // namespace ttg::device
#elif defined(TTG_HAVE_HIP)

#include <hip/hip_runtime.h>
namespace ttg::device {
constexpr ttg::ExecutionSpace available_execution_space = ttg::ExecutionSpace::HIP;
using Stream = hipStream_t;
} // namespace ttg::device
#elif defined(TTG_HAVE_LEVEL_ZERO)
#include <CL/sycl.hpp>
namespace ttg::device {
constexpr ttg::ExecutionSpace available_execution_space = ttg::ExecutionSpace::L0;
using Stream = std::add_reference_t<sycl::queue>;
} // namespace ttg::device
#else
namespace ttg::device {
struct Stream { };
namespace detail {
template<>
struct default_stream<Stream> {
static constexpr const Stream value = {};
};
} // namespace detail
constexpr ttg::ExecutionSpace available_execution_space = ttg::ExecutionSpace::Host;
} // namespace ttg::device
#endif

namespace ttg::device {

#if !defined(TTG_HAVE_LEVEL_ZERO)
namespace detail {
inline thread_local ttg::device::Device current_device_ts = {};
inline thread_local hipStream_t current_stream_ts = 0; // default stream
inline thread_local Stream current_stream_ts = detail::default_stream_v<Stream>; // default stream

inline void reset_current() {
current_device_ts = {};
current_stream_ts = 0;
current_stream_ts = detail::default_stream_v<Stream>;
}

inline void set_current(int device, hipStream_t stream) {
current_device_ts = ttg::device::Device(device, ttg::ExecutionSpace::HIP);
inline void set_current(int device, Stream stream) {
current_device_ts = ttg::device::Device(device, available_execution_space);
current_stream_ts = stream;
}
} // namespace detail
Expand All @@ -130,16 +133,16 @@ namespace ttg::device {
}

inline
hipStream_t current_stream() {
Stream current_stream() {
return detail::current_stream_ts;
}
} // namespace ttg

#elif defined(TTG_HAVE_LEVEL_ZERO)

#include <CL/sycl.hpp>
inline int num_devices() {
return TTG_IMPL_NS::num_devices();
}

namespace ttg::device {
#else // TTG_HAVE_LEVEL_ZERO
/* SYCL needs special treatment because it uses pointers/references */
namespace detail {
inline thread_local ttg::device::Device current_device_ts = {};
inline thread_local sycl::queue* current_stream_ts = nullptr; // default stream
Expand All @@ -165,26 +168,6 @@ namespace ttg::device {
sycl::queue& current_stream() {
return *detail::current_stream_ts;
}
} // namespace ttg

#else
#endif // TTG_HAVE_LEVEL_ZERO

namespace ttg::device {
inline Device current_device() {
return {};
}

template<ttg::ExecutionSpace Space = ttg::ExecutionSpace::Invalid>
inline const void* current_stream() {
static_assert(Space != ttg::ExecutionSpace::Invalid,
"TTG was built without any known device support so we cannot provide a current stream!");
return nullptr;
}
} // namespace ttg
#endif // defined(TTG_HAVE_HIP)

namespace ttg::device {
inline int num_devices() {
return TTG_IMPL_NS::num_devices();
}
}

0 comments on commit ba03182

Please sign in to comment.