Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upstream improvements needed for MRA on devices #298

Merged
merged 8 commits into from
Nov 12, 2024
4 changes: 3 additions & 1 deletion ttg/ttg/buffer.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
#ifndef TTG_BUFFER_H
#define TTG_BUFFER_H

#include <memory>

#include "ttg/fwd.h"

namespace ttg {

template<typename T, typename Allocator = std::allocator<T>>
template<typename T, typename Allocator = std::allocator<std::decay_t<T>>>
using Buffer = TTG_IMPL_NS::Buffer<T, Allocator>;

} // namespace ttg
Expand Down
115 changes: 49 additions & 66 deletions ttg/ttg/device/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,12 @@
#include <type_traits>

#include "ttg/execution.h"
#include "ttg/impl_selector.h"
#include "ttg/fwd.h"
#include "ttg/util/meta.h"



namespace ttg::device {

#if defined(TTG_HAVE_CUDA)
constexpr ttg::ExecutionSpace available_execution_space = ttg::ExecutionSpace::CUDA;
#elif defined(TTG_HAVE_HIP)
constexpr ttg::ExecutionSpace available_execution_space = ttg::ExecutionSpace::HIP;
#elif defined(TTG_HAVE_LEVEL_ZERO)
constexpr ttg::ExecutionSpace available_execution_space = ttg::ExecutionSpace::L0;
#else
constexpr ttg::ExecutionSpace available_execution_space = ttg::ExecutionSpace::Invalid;
#endif

/// Represents a device in a specific execution space
class Device {
int m_id = 0;
Expand Down Expand Up @@ -74,52 +65,64 @@ namespace std {
}
} // namespace std

#if defined(TTG_HAVE_CUDA)
#include <cuda_runtime.h>

namespace ttg::device {
namespace detail {
inline thread_local ttg::device::Device current_device_ts = {};
inline thread_local cudaStream_t current_stream_ts = 0; // default stream

inline void reset_current() {
current_device_ts = {};
current_stream_ts = 0;
}

inline void set_current(int device, cudaStream_t stream) {
current_device_ts = ttg::device::Device(device, ttg::ExecutionSpace::CUDA);
current_stream_ts = stream;
}
namespace detail {
  /// Maps a stream type to the value used as its default stream.
  /// The generic definition initializes the value from 0; a stream type that
  /// cannot be initialized from 0 needs its own specialization.
  template<typename Stream>
  struct default_stream {
    static constexpr Stream value = 0;
  };

  /// Shorthand for default_stream<Stream>::value.
  template<typename Stream>
  constexpr Stream default_stream_v = default_stream<Stream>::value;
} // namespace detail

inline
Device current_device() {
return detail::current_device_ts;
}

inline
cudaStream_t current_stream() {
return detail::current_stream_ts;
}
} // namespace ttg

#if defined(TTG_HAVE_CUDA)
#include <cuda_runtime.h>
namespace ttg::device {
// Selected when TTG_HAVE_CUDA is defined: the available execution space is CUDA.
constexpr ttg::ExecutionSpace available_execution_space = ttg::ExecutionSpace::CUDA;
// Stream type of the device layer in this configuration: the CUDA runtime's stream handle.
using Stream = cudaStream_t;
} // namespace ttg::device
#elif defined(TTG_HAVE_HIP)

#include <hip/hip_runtime.h>
namespace ttg::device {
// Selected when TTG_HAVE_HIP is defined (and TTG_HAVE_CUDA is not): the available execution space is HIP.
constexpr ttg::ExecutionSpace available_execution_space = ttg::ExecutionSpace::HIP;
// Stream type of the device layer in this configuration: the HIP runtime's stream handle.
using Stream = hipStream_t;
} // namespace ttg::device
#elif defined(TTG_HAVE_LEVEL_ZERO)
#include <CL/sycl.hpp>
namespace ttg::device {
constexpr ttg::ExecutionSpace available_execution_space = ttg::ExecutionSpace::L0;
using Stream = std::add_reference_t<sycl::queue>;
} // namespace ttg::device
#else
namespace ttg::device {
  /// Empty placeholder stream type for builds where none of the device macros is defined.
  struct Stream { };

  namespace detail {
    /// The placeholder Stream is an empty struct, so its default value is a
    /// value-initialized object rather than 0.
    template<>
    struct default_stream<Stream> {
      static constexpr Stream value{};
    };
  } // namespace detail

  /// Without a device backend the only available execution space is Host.
  constexpr ttg::ExecutionSpace available_execution_space = ttg::ExecutionSpace::Host;
} // namespace ttg::device
#endif

namespace ttg::device {

#if !defined(TTG_HAVE_LEVEL_ZERO)
namespace detail {
inline thread_local ttg::device::Device current_device_ts = {};
inline thread_local hipStream_t current_stream_ts = 0; // default stream
inline thread_local Stream current_stream_ts = detail::default_stream_v<Stream>; // default stream

inline void reset_current() {
current_device_ts = {};
current_stream_ts = 0;
current_stream_ts = detail::default_stream_v<Stream>;
}

inline void set_current(int device, hipStream_t stream) {
current_device_ts = ttg::device::Device(device, ttg::ExecutionSpace::HIP);
inline void set_current(int device, Stream stream) {
current_device_ts = ttg::device::Device(device, available_execution_space);
current_stream_ts = stream;
}
} // namespace detail
Expand All @@ -130,16 +133,16 @@ namespace ttg::device {
}

inline
hipStream_t current_stream() {
Stream current_stream() {
return detail::current_stream_ts;
}
} // namespace ttg

#elif defined(TTG_HAVE_LEVEL_ZERO)

#include <CL/sycl.hpp>
/// Number of devices; forwards to TTG_IMPL_NS::num_devices().
inline auto num_devices() -> int {
  return TTG_IMPL_NS::num_devices();
}

namespace ttg::device {
#else // TTG_HAVE_LEVEL_ZERO
/* SYCL needs special treatment because it uses pointers/references */
namespace detail {
inline thread_local ttg::device::Device current_device_ts = {};
inline thread_local sycl::queue* current_stream_ts = nullptr; // default stream
Expand All @@ -165,26 +168,6 @@ namespace ttg::device {
sycl::queue& current_stream() {
return *detail::current_stream_ts;
}
} // namespace ttg

#else
#endif // TTG_HAVE_LEVEL_ZERO

namespace ttg::device {
  /// Returns a default-constructed Device.
  inline Device current_device() {
    return Device{};
  }

  /// Stream accessor for builds without device support.
  /// Instantiating it with the default Space (Invalid) trips the static_assert below;
  /// for any other Space it returns a null pointer.
  template<ttg::ExecutionSpace Space = ttg::ExecutionSpace::Invalid>
  inline const void* current_stream() {
    static_assert(Space != ttg::ExecutionSpace::Invalid,
                  "TTG was built without any known device support so we cannot provide a current stream!");
    return nullptr;
  }
} // namespace ttg::device
#endif // defined(TTG_HAVE_HIP)

namespace ttg::device {
  /// Number of devices; forwards to TTG_IMPL_NS::num_devices().
  inline auto num_devices() -> int {
    return TTG_IMPL_NS::num_devices();
  }
} // namespace ttg::device
93 changes: 90 additions & 3 deletions ttg/ttg/device/task.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ namespace ttg::device {

template <typename... Ts>
struct wait_kernel_t {
std::tuple<Ts &...> ties;
std::tuple<std::add_lvalue_reference_t<Ts>...> ties;

/* always suspend */
constexpr bool await_ready() const noexcept { return false; }
Expand Down Expand Up @@ -270,7 +270,7 @@ namespace ttg::device {
/* overload for iterable types that extracts the type of the first element */
template<typename T>
struct broadcast_keylist_trait<T, std::enable_if_t<ttg::meta::is_iterable_v<T>>> {
using key_type = decltype(*std::begin(std::get<0>(std::declval<T>())));
using key_type = decltype(*std::begin(std::declval<T>()));
};

template <size_t KeyId, size_t I, size_t... Is, typename... RangesT, typename valueT,
Expand Down Expand Up @@ -364,6 +364,71 @@ namespace ttg::device {
ttg::device::detail::broadcast<0, I, Is...>(std::tie(kl), std::forward<std::decay_t<decltype(value)>>(value));
}
}

/**
 * broadcastk (explicit terminals)
 *
 * Calls broadcast() on terminal I of t with keylist number KeyId of keylists,
 * then recurses so that every further terminal index in Is... is paired with
 * the next keylist (KeyId+1, KeyId+2, ...).
 */
template <size_t KeyId, size_t I, size_t... Is, typename... RangesT,
          typename... out_keysT, typename... out_valuesT>
inline void broadcastk(const std::tuple<RangesT...> &keylists,
                       std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
  auto &terminal = std::get<I>(t);
  auto &&keys = std::get<KeyId>(keylists);
  terminal.broadcast(keys);
  if constexpr (sizeof...(Is) != 0) {
    // more terminals left: continue with the next keylist
    detail::broadcastk<KeyId+1, Is...>(keylists, t);
  }
}

/**
 * broadcastk (implicit terminals)
 *
 * Obtains output terminal I through ttg::detail::get_out_terminal, using the key
 * type derived from keylist number KeyId, and broadcasts that keylist on it.
 * Recurses over the remaining terminal indices Is..., advancing KeyId by one each time.
 *
 * NOTE(review): out_keysT / out_valuesT are not referenced by this overload.
 */
template <size_t KeyId, size_t I, size_t... Is, typename... RangesT,
          typename... out_keysT, typename... out_valuesT>
inline void broadcastk(const std::tuple<RangesT...> &keylists) {
  // type of the selected keylist with any reference stripped
  using keylist_t = std::tuple_element_t<KeyId, std::tuple<std::remove_reference_t<RangesT>...>>;
  using key_t = typename broadcast_keylist_trait<keylist_t>::key_type;
  auto *out = ttg::detail::get_out_terminal<key_t, void>(I, "ttg::device::broadcastk(keylists)");
  out->broadcast(std::get<KeyId>(keylists));
  if constexpr (sizeof...(Is) != 0) {
    ttg::device::detail::broadcastk<KeyId+1, Is...>(keylists);
  }
}

/**
 * Coroutine behind broadcastk with explicit terminals.
 *
 * Stores the keylist(s) in the coroutine, suspends on ttg::Void{}, and performs
 * the key broadcast on the terminals in t after being resumed. A keylist that is
 * not a tuple is wrapped with std::tie before being handed to detail::broadcastk.
 *
 * NOTE(review): when RangesT is deduced as an lvalue reference the local below
 * is a reference to the caller's object, not a copy - confirm callers keep it alive.
 */
template <size_t I, size_t... Is, typename RangesT,
          typename... out_keysT, typename... out_valuesT,
          ttg::Runtime Runtime = ttg::ttg_runtime>
inline send_coro_state
broadcastk_coro(RangesT &&keylists,
                std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
  RangesT captured = std::forward<RangesT>(keylists); // capture the keylist(s)
  co_await ttg::Void{}; // we'll come back once the task is done
  if constexpr (ttg::meta::is_tuple_v<RangesT>) {
    // already a tuple of keylists
    ttg::device::detail::broadcastk<0, I, Is...>(captured, t);
  } else {
    // single keylist: create a tie to the captured keylist
    ttg::device::detail::broadcastk<0, I, Is...>(std::tie(captured), t);
  }
}

/**
 * Coroutine behind broadcastk with implicit terminals.
 *
 * Stores the keylist(s) in the coroutine, suspends on ttg::Void{}, and performs
 * the key broadcast on terminals I, Is... after being resumed. For a tuple of
 * keylists the tuple size must equal the number of terminal indices (checked at
 * compile time); a non-tuple keylist is wrapped with std::tie.
 *
 * NOTE(review): when RangesT is deduced as an lvalue reference the local below
 * is a reference to the caller's object, not a copy - confirm callers keep it alive.
 */
template <size_t I, size_t... Is, typename RangesT,
          ttg::Runtime Runtime = ttg::ttg_runtime>
inline send_coro_state
broadcastk_coro(RangesT &&keylists) {
  RangesT captured = std::forward<RangesT>(keylists); // capture the keylist(s)
  co_await ttg::Void{}; // we'll come back once the task is done
  if constexpr (ttg::meta::is_tuple_v<RangesT>) {
    // already a tuple of keylists
    static_assert(sizeof...(Is)+1 == std::tuple_size_v<RangesT>,
                  "Size of keylist tuple must match the number of output terminals");
    ttg::device::detail::broadcastk<0, I, Is...>(captured);
  } else {
    // single keylist: create a tie to the captured keylist
    ttg::device::detail::broadcastk<0, I, Is...>(std::tie(captured));
  }
}
} // namespace detail

/* overload with explicit terminals and keylist passed by const reference */
Expand All @@ -389,11 +454,33 @@ namespace ttg::device {
std::move(copy_handler))};
}

/**
 * Key-only broadcast to explicitly provided output terminals.
 *
 * The keylist is taken by forwarding reference and passed on to
 * detail::broadcastk_coro together with the terminal tuple t; the resulting
 * coroutine is wrapped in a detail::send_t, which the caller must not discard.
 *
 * I, Is... select the terminals in t. Runtime is kept in the template
 * parameter list for interface compatibility; it is not used in the body.
 */
template <size_t I, size_t... Is, typename rangeT, typename... out_keysT, typename... out_valuesT,
          ttg::Runtime Runtime = ttg::ttg_runtime>
[[nodiscard]]
inline detail::send_t broadcastk(rangeT &&keylist,
                                 std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
  // No value is sent here, so no value_copy_handler is needed (the previous
  // local instance was never used).
  return detail::send_t{
      detail::broadcastk_coro<I, Is...>(std::forward<rangeT>(keylist), t)};
}

/**
 * Key-only broadcast to output terminal i, looked up implicitly.
 *
 * The keylist is taken by forwarding reference:
 *  - an rvalue keylist is forwarded into detail::broadcastk_coro as is;
 *  - an lvalue keylist is passed as std::tie(keylist), i.e. a tuple holding a
 *    reference to the caller's object.
 *    NOTE(review): presumably the caller must keep that object alive until the
 *    returned send_t has been processed - confirm.
 *
 * The returned detail::send_t must not be discarded, matching the overload that
 * takes explicit terminals.
 */
template <size_t i, typename rangeT,
          ttg::Runtime Runtime = ttg::ttg_runtime>
[[nodiscard]]
inline detail::send_t broadcastk(rangeT &&keylist) {
  if constexpr (std::is_rvalue_reference_v<decltype(keylist)>) {
    return detail::send_t{detail::broadcastk_coro<i>(std::forward<rangeT>(keylist))};
  } else {
    return detail::send_t{detail::broadcastk_coro<i>(std::tie(keylist))};
  }
}

template<typename... Args, ttg::Runtime Runtime = ttg::ttg_runtime>
[[nodiscard]]
std::vector<device::detail::send_t> forward(Args&&... args) {
// TODO: check the cost of this!
return std::vector{std::forward<Args>(args)...};
return std::vector<device::detail::send_t>{std::forward<Args>(args)...};
}

/*******************************************
Expand Down
Loading
Loading