diff --git a/ttg/ttg/device/device.h b/ttg/ttg/device/device.h index 244e9c944..6dcb3c722 100644 --- a/ttg/ttg/device/device.h +++ b/ttg/ttg/device/device.h @@ -4,21 +4,12 @@ #include "ttg/execution.h" #include "ttg/impl_selector.h" #include "ttg/fwd.h" +#include "ttg/util/meta.h" namespace ttg::device { -#if defined(TTG_HAVE_CUDA) - constexpr ttg::ExecutionSpace available_execution_space = ttg::ExecutionSpace::CUDA; -#elif defined(TTG_HAVE_HIP) - constexpr ttg::ExecutionSpace available_execution_space = ttg::ExecutionSpace::HIP; -#elif defined(TTG_HAVE_LEVEL_ZERO) - constexpr ttg::ExecutionSpace available_execution_space = ttg::ExecutionSpace::L0; -#else - constexpr ttg::ExecutionSpace available_execution_space = ttg::ExecutionSpace::Invalid; -#endif - /// Represents a device in a specific execution space class Device { int m_id = 0; @@ -74,52 +65,64 @@ namespace std { } } // namespace std -#if defined(TTG_HAVE_CUDA) -#include - namespace ttg::device { - namespace detail { - inline thread_local ttg::device::Device current_device_ts = {}; - inline thread_local cudaStream_t current_stream_ts = 0; // default stream - - inline void reset_current() { - current_device_ts = {}; - current_stream_ts = 0; - } - inline void set_current(int device, cudaStream_t stream) { - current_device_ts = ttg::device::Device(device, ttg::ExecutionSpace::CUDA); - current_stream_ts = stream; - } + namespace detail { + template + struct default_stream { + static constexpr const Stream value = 0; + }; + template + constexpr const Stream default_stream_v = default_stream::value; } // namespace detail - inline - Device current_device() { - return detail::current_device_ts; - } - - inline - cudaStream_t current_stream() { - return detail::current_stream_ts; - } } // namespace ttg +#if defined(TTG_HAVE_CUDA) +#include +namespace ttg::device { + constexpr ttg::ExecutionSpace available_execution_space = ttg::ExecutionSpace::CUDA; + using Stream = cudaStream_t; +} // namespace ttg::device #elif defined(TTG_HAVE_HIP) - #include +namespace ttg::device { + constexpr ttg::ExecutionSpace available_execution_space = ttg::ExecutionSpace::HIP; + using Stream = hipStream_t; +} // namespace ttg::device +#elif defined(TTG_HAVE_LEVEL_ZERO) +#include +namespace ttg::device { + constexpr ttg::ExecutionSpace available_execution_space = ttg::ExecutionSpace::L0; + using Stream = std::add_reference_t; +} // namespace ttg::device +#else +namespace ttg::device { + struct Stream { }; + namespace detail { + template<> + struct default_stream { + static constexpr const Stream value = {}; + }; + } // namespace detail + constexpr ttg::ExecutionSpace available_execution_space = ttg::ExecutionSpace::Host; +} // namespace ttg::device +#endif namespace ttg::device { + +#if !defined(TTG_HAVE_LEVEL_ZERO) namespace detail { inline thread_local ttg::device::Device current_device_ts = {}; - inline thread_local hipStream_t current_stream_ts = 0; // default stream + inline thread_local Stream current_stream_ts = detail::default_stream_v; // default stream inline void reset_current() { current_device_ts = {}; - current_stream_ts = 0; + current_stream_ts = detail::default_stream_v; } - inline void set_current(int device, hipStream_t stream) { - current_device_ts = ttg::device::Device(device, ttg::ExecutionSpace::HIP); + inline void set_current(int device, Stream stream) { + current_device_ts = ttg::device::Device(device, available_execution_space); current_stream_ts = stream; } } // namespace detail @@ -130,16 +133,16 @@ namespace ttg::device { } inline - hipStream_t current_stream() { + Stream current_stream() { return detail::current_stream_ts; } -} // namespace ttg - -#elif defined(TTG_HAVE_LEVEL_ZERO) -#include + inline int num_devices() { + return TTG_IMPL_NS::num_devices(); + } -namespace ttg::device { +#else // TTG_HAVE_LEVEL_ZERO + /* SYCL needs special treatment because it uses pointers/references */ namespace detail { inline thread_local ttg::device::Device current_device_ts = {}; inline thread_local sycl::queue* current_stream_ts = nullptr; // default stream @@ -165,26 +168,6 @@ namespace ttg::device { sycl::queue& current_stream() { return *detail::current_stream_ts; } -} // namespace ttg - -#else +#endif // TTG_HAVE_LEVEL_ZERO -namespace ttg::device { - inline Device current_device() { - return {}; - } - - template - inline const void* current_stream() { - static_assert(Space != ttg::ExecutionSpace::Invalid, - "TTG was built without any known device support so we cannot provide a current stream!"); - return nullptr; - } } // namespace ttg -#endif // defined(TTG_HAVE_HIP) - -namespace ttg::device { - inline int num_devices() { - return TTG_IMPL_NS::num_devices(); - } -}