Skip to content

Commit

Permalink
Merge branch 'master' into serialize-buffer-query
Browse files Browse the repository at this point in the history
  • Loading branch information
devreal committed Nov 13, 2024
2 parents 2c62403 + 3636049 commit 2d70423
Show file tree
Hide file tree
Showing 12 changed files with 314 additions and 285 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/cmake.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ jobs:
-DCMAKE_CXX_STANDARD=20
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4

- name: Install prerequisite MacOS packages
if: ${{ matrix.os == 'macos-latest' }}
Expand Down Expand Up @@ -72,7 +72,7 @@ jobs:
message("::set-output name=timestamp::${current_date}")
- name: Setup ccache cache files
uses: actions/cache@v1.1.0
uses: actions/cache@v4
with:
path: ${{github.workspace}}/build/.ccache
key: ${{ matrix.config.name }}-ccache-${{ steps.ccache_cache_timestamp.outputs.timestamp }}
Expand Down
4 changes: 3 additions & 1 deletion ttg/ttg/buffer.h
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
#ifndef TTG_BUFFER_H
#define TTG_BUFFER_H

#include <memory>

#include "ttg/fwd.h"
#include "ttg/serialization.h"
#include <memory>

namespace ttg {

template<typename T, typename Allocator = std::allocator<T>>
template<typename T, typename Allocator = std::allocator<std::decay_t<T>>>
using Buffer = TTG_IMPL_NS::Buffer<T, Allocator>;

namespace meta {
Expand Down
115 changes: 49 additions & 66 deletions ttg/ttg/device/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,12 @@
#include "ttg/execution.h"
#include "ttg/impl_selector.h"
#include "ttg/fwd.h"
#include "ttg/util/meta.h"



namespace ttg::device {

#if defined(TTG_HAVE_CUDA)
constexpr ttg::ExecutionSpace available_execution_space = ttg::ExecutionSpace::CUDA;
#elif defined(TTG_HAVE_HIP)
constexpr ttg::ExecutionSpace available_execution_space = ttg::ExecutionSpace::HIP;
#elif defined(TTG_HAVE_LEVEL_ZERO)
constexpr ttg::ExecutionSpace available_execution_space = ttg::ExecutionSpace::L0;
#else
constexpr ttg::ExecutionSpace available_execution_space = ttg::ExecutionSpace::Invalid;
#endif

/// Represents a device in a specific execution space
class Device {
int m_id = 0;
Expand Down Expand Up @@ -74,52 +65,64 @@ namespace std {
}
} // namespace std

#if defined(TTG_HAVE_CUDA)
#include <cuda_runtime.h>

namespace ttg::device {
namespace detail {
inline thread_local ttg::device::Device current_device_ts = {};
inline thread_local cudaStream_t current_stream_ts = 0; // default stream

inline void reset_current() {
current_device_ts = {};
current_stream_ts = 0;
}

inline void set_current(int device, cudaStream_t stream) {
current_device_ts = ttg::device::Device(device, ttg::ExecutionSpace::CUDA);
current_stream_ts = stream;
}
namespace detail {
/// Trait providing the default value of a stream handle type.
/// The primary template initializes the handle from the literal 0.
template <typename StreamT>
struct default_stream {
  static constexpr StreamT value = 0;
};

/// Shorthand for default_stream<StreamT>::value.
template <typename StreamT>
constexpr StreamT default_stream_v = default_stream<StreamT>::value;
} // namespace detail

inline
Device current_device() {
return detail::current_device_ts;
}

inline
cudaStream_t current_stream() {
return detail::current_stream_ts;
}
} // namespace ttg

#if defined(TTG_HAVE_CUDA)
#include <cuda_runtime.h>
namespace ttg::device {
  /// Stream handle type for the CUDA configuration.
  using Stream = cudaStream_t;
  /// Execution space selected by the CUDA configuration.
  constexpr ttg::ExecutionSpace available_execution_space = ttg::ExecutionSpace::CUDA;
}  // namespace ttg::device
#elif defined(TTG_HAVE_HIP)

#include <hip/hip_runtime.h>
namespace ttg::device {
  /// Stream handle type for the HIP configuration.
  using Stream = hipStream_t;
  /// Execution space selected by the HIP configuration.
  constexpr ttg::ExecutionSpace available_execution_space = ttg::ExecutionSpace::HIP;
}  // namespace ttg::device
#elif defined(TTG_HAVE_LEVEL_ZERO)
#include <CL/sycl.hpp>
namespace ttg::device {
constexpr ttg::ExecutionSpace available_execution_space = ttg::ExecutionSpace::L0;
using Stream = std::add_reference_t<sycl::queue>;
} // namespace ttg::device
#else
namespace ttg::device {
  /// Execution space reported when none of the device backends is enabled.
  constexpr ttg::ExecutionSpace available_execution_space = ttg::ExecutionSpace::Host;

  /// Empty placeholder stream type for builds without device support.
  struct Stream {};

  namespace detail {
    /// The placeholder stream cannot be initialized from 0 like the primary
    /// template does, so its default value is a value-initialized object.
    template <>
    struct default_stream<Stream> {
      static constexpr Stream value{};
    };
  }  // namespace detail
}  // namespace ttg::device
#endif

namespace ttg::device {

#if !defined(TTG_HAVE_LEVEL_ZERO)
namespace detail {
inline thread_local ttg::device::Device current_device_ts = {};
inline thread_local hipStream_t current_stream_ts = 0; // default stream
inline thread_local Stream current_stream_ts = detail::default_stream_v<Stream>; // default stream

inline void reset_current() {
current_device_ts = {};
current_stream_ts = 0;
current_stream_ts = detail::default_stream_v<Stream>;
}

inline void set_current(int device, hipStream_t stream) {
current_device_ts = ttg::device::Device(device, ttg::ExecutionSpace::HIP);
inline void set_current(int device, Stream stream) {
current_device_ts = ttg::device::Device(device, available_execution_space);
current_stream_ts = stream;
}
} // namespace detail
Expand All @@ -130,16 +133,16 @@ namespace ttg::device {
}

inline
hipStream_t current_stream() {
Stream current_stream() {
return detail::current_stream_ts;
}
} // namespace ttg

#elif defined(TTG_HAVE_LEVEL_ZERO)

#include <CL/sycl.hpp>
/// Returns the device count by forwarding to TTG_IMPL_NS::num_devices().
inline auto num_devices() -> int { return TTG_IMPL_NS::num_devices(); }

namespace ttg::device {
#else // TTG_HAVE_LEVEL_ZERO
/* SYCL needs special treatment because it uses pointers/references */
namespace detail {
inline thread_local ttg::device::Device current_device_ts = {};
inline thread_local sycl::queue* current_stream_ts = nullptr; // default stream
Expand All @@ -165,26 +168,6 @@ namespace ttg::device {
sycl::queue& current_stream() {
return *detail::current_stream_ts;
}
} // namespace ttg

#else
#endif // TTG_HAVE_LEVEL_ZERO

namespace ttg::device {
inline Device current_device() {
return {};
}

template<ttg::ExecutionSpace Space = ttg::ExecutionSpace::Invalid>
inline const void* current_stream() {
static_assert(Space != ttg::ExecutionSpace::Invalid,
"TTG was built without any known device support so we cannot provide a current stream!");
return nullptr;
}
} // namespace ttg
#endif // defined(TTG_HAVE_HIP)

namespace ttg::device {
inline int num_devices() {
return TTG_IMPL_NS::num_devices();
}
}
95 changes: 90 additions & 5 deletions ttg/ttg/device/task.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ namespace ttg::device {

template <typename... Ts>
struct wait_kernel_t {
std::tuple<Ts &...> ties;
std::tuple<std::add_lvalue_reference_t<Ts>...> ties;

/* always suspend */
constexpr bool await_ready() const noexcept { return false; }
Expand Down Expand Up @@ -270,7 +270,7 @@ namespace ttg::device {
/* overload for iterable types that extracts the type of the first element */
template<typename T>
struct broadcast_keylist_trait<T, std::enable_if_t<ttg::meta::is_iterable_v<T>>> {
using key_type = decltype(*std::begin(std::get<0>(std::declval<T>())));
using key_type = decltype(*std::begin(std::declval<T>()));
};

template <size_t KeyId, size_t I, size_t... Is, typename... RangesT, typename valueT,
Expand Down Expand Up @@ -306,8 +306,7 @@ namespace ttg::device {
}
}

template <size_t KeyId, size_t I, size_t... Is, typename... RangesT, typename valueT,
typename... out_keysT, typename... out_valuesT>
template <size_t KeyId, size_t I, size_t... Is, typename... RangesT, typename valueT>
inline void broadcast(const std::tuple<RangesT...> &keylists, valueT &&value) {
using key_t = typename broadcast_keylist_trait<
std::tuple_element_t<KeyId, std::tuple<std::remove_reference_t<RangesT>...>>
Expand Down Expand Up @@ -364,6 +363,70 @@ namespace ttg::device {
ttg::device::detail::broadcast<0, I, Is...>(std::tie(kl), std::forward<std::decay_t<decltype(value)>>(value));
}
}

/**
* broadcastk
*/

/// Broadcasts key lists through explicitly provided output terminals.
///
/// Terminal `I` of `t` receives key list `KeyId` of `keylists`. Each further
/// terminal index in `Is...` is paired with the next key list (`KeyId+1`, ...)
/// by recursing on the remaining indices.
template <size_t KeyId, size_t I, size_t... Is, typename... RangesT,
          typename... out_keysT, typename... out_valuesT>
inline void broadcastk(const std::tuple<RangesT...> &keylists,
                       std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
  auto &terminal = std::get<I>(t);
  // decltype(auto) keeps the exact reference type returned by std::get
  decltype(auto) keys = std::get<KeyId>(keylists);
  terminal.broadcast(keys);
  if constexpr (sizeof...(Is) != 0) {
    detail::broadcastk<KeyId + 1, Is...>(keylists, t);
  }
}

/// Broadcasts key lists through implicitly resolved output terminals.
///
/// The terminal with index `I` is looked up via ttg::detail::get_out_terminal
/// using the key type derived from key list `KeyId`, and receives that key
/// list. Remaining indices `Is...` are paired with the following key lists
/// by recursion.
template <size_t KeyId, size_t I, size_t... Is, typename... RangesT>
inline void broadcastk(const std::tuple<RangesT...> &keylists) {
  using keylist_t = std::tuple_element_t<KeyId, std::tuple<std::remove_reference_t<RangesT>...>>;
  using key_t = typename broadcast_keylist_trait<keylist_t>::key_type;
  auto *out = ttg::detail::get_out_terminal<key_t, void>(I, "ttg::device::broadcastk(keylists)");
  out->broadcast(std::get<KeyId>(keylists));
  if constexpr (sizeof...(Is) != 0) {
    ttg::device::detail::broadcastk<KeyId + 1, Is...>(keylists);
  }
}

/* Coroutine performing a key-only broadcast through explicit terminals.
 *
 * `keylists` is either a tuple of key lists (one per terminal index in
 * `I, Is...`) or a single key list, which is wrapped with std::tie before
 * being handed to detail::broadcastk.
 *
 * NOTE(review): `RangesT captured` is a copy only when RangesT is a
 * non-reference type; if RangesT deduces to an lvalue reference it is a
 * reference to the caller's object. Confirm callers keep it alive, and how
 * ttg::meta::is_tuple_v treats reference types.
 */
template <size_t I, size_t... Is, typename RangesT,
          typename... out_keysT, typename... out_valuesT,
          ttg::Runtime Runtime = ttg::ttg_runtime>
inline send_coro_state
broadcastk_coro(RangesT &&keylists,
                std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
  RangesT captured = std::forward<RangesT>(keylists);  // bind the keylist(s) for use after the co_await
  // per the original author's note, control returns here once the task is done
  co_await ttg::Void{};
  if constexpr (ttg::meta::is_tuple_v<RangesT>) {
    // already a tuple of key lists
    ttg::device::detail::broadcastk<0, I, Is...>(captured, t);
  } else {
    // single key list: present it as a one-element tuple of references
    ttg::device::detail::broadcastk<0, I, Is...>(std::tie(captured), t);
  }
}

/* Coroutine performing a key-only broadcast through implicit terminals.
 *
 * `keylists` is either a tuple of key lists, whose size must equal the
 * number of terminal indices `I, Is...` (checked at compile time), or a
 * single key list, which is wrapped with std::tie before being handed to
 * detail::broadcastk.
 *
 * NOTE(review): `RangesT captured` is a copy only when RangesT is a
 * non-reference type; confirm the lifetime of referenced key lists.
 */
template <size_t I, size_t... Is, typename RangesT,
          ttg::Runtime Runtime = ttg::ttg_runtime>
inline send_coro_state
broadcastk_coro(RangesT &&keylists) {
  RangesT captured = std::forward<RangesT>(keylists);  // bind the keylist(s) for use after the co_await
  // per the original author's note, control returns here once the task is done
  co_await ttg::Void{};
  if constexpr (ttg::meta::is_tuple_v<RangesT>) {
    static_assert(sizeof...(Is)+1 == std::tuple_size_v<RangesT>,
                  "Size of keylist tuple must match the number of output terminals");
    // already a tuple of key lists
    ttg::device::detail::broadcastk<0, I, Is...>(captured);
  } else {
    // single key list: present it as a one-element tuple of references
    ttg::device::detail::broadcastk<0, I, Is...>(std::tie(captured));
  }
}
} // namespace detail

/* overload with explicit terminals and keylist passed by const reference */
Expand All @@ -389,11 +452,33 @@ namespace ttg::device {
std::move(copy_handler))};
}

/* Key-only broadcast with explicit terminals.
 *
 * The keylist is taken by forwarding reference and perfectly forwarded into
 * detail::broadcastk_coro together with the terminal tuple `t`; the resulting
 * coroutine is wrapped in a detail::send_t. The result is [[nodiscard]]:
 * the caller is expected to do something with the returned send_t.
 *
 * No value is sent, so no value_copy_handler is needed here; the Runtime
 * template parameter is kept for interface compatibility.
 */
template <size_t I, size_t... Is, typename rangeT, typename... out_keysT, typename... out_valuesT,
          ttg::Runtime Runtime = ttg::ttg_runtime>
[[nodiscard]]
inline detail::send_t broadcastk(rangeT &&keylist,
                                 std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
  return detail::send_t{
      detail::broadcastk_coro<I, Is...>(std::forward<rangeT>(keylist), t)};
}

/* Key-only broadcast with an implicit terminal (index `i`).
 *
 * The keylist is taken by forwarding reference:
 *  - an rvalue keylist is forwarded into detail::broadcastk_coro;
 *  - an lvalue keylist is wrapped with std::tie, so the coroutine receives a
 *    tuple holding a reference to the caller's object rather than a copy.
 *    NOTE(review): the caller's keylist presumably has to stay alive until
 *    the coroutine has run -- confirm.
 *
 * The returned detail::send_t is marked [[nodiscard]] so that a call whose
 * result is dropped is diagnosed.
 */
template <size_t i, typename rangeT,
          ttg::Runtime Runtime = ttg::ttg_runtime>
[[nodiscard]]
inline detail::send_t broadcastk(rangeT &&keylist) {
  if constexpr (std::is_rvalue_reference_v<decltype(keylist)>) {
    return detail::send_t{detail::broadcastk_coro<i>(std::forward<rangeT>(keylist))};
  } else {
    return detail::send_t{detail::broadcastk_coro<i>(std::tie(keylist))};
  }
}

template<typename... Args, ttg::Runtime Runtime = ttg::ttg_runtime>
[[nodiscard]]
std::vector<device::detail::send_t> forward(Args&&... args) {
// TODO: check the cost of this!
return std::vector{std::forward<Args>(args)...};
return std::vector<device::detail::send_t>{std::forward<Args>(args)...};
}

/*******************************************
Expand Down
8 changes: 3 additions & 5 deletions ttg/ttg/func.h
Original file line number Diff line number Diff line change
Expand Up @@ -416,17 +416,15 @@ namespace ttg {
std::get<i>(t).broadcast(keylist, copy_handler(std::forward<valueT>(value)));
}

template <typename rangeT, typename valueT, typename... out_keysT, typename... out_valuesT,
ttg::Runtime Runtime = ttg::ttg_runtime>
template <typename rangeT, typename valueT, ttg::Runtime Runtime = ttg::ttg_runtime>
inline void broadcast(std::size_t i, const rangeT &keylist, valueT &&value) {
detail::value_copy_handler<Runtime> copy_handler;
using key_t = decltype(*std::begin(keylist));
auto *terminal_ptr = detail::get_out_terminal<key_t, valueT>(i, "ttg::broadcast(keylist, value)");
terminal_ptr->broadcast(keylist, copy_handler(std::forward<valueT>(value)));
}

template <size_t i, typename rangeT, typename valueT, typename... out_keysT, typename... out_valuesT,
ttg::Runtime Runtime = ttg::ttg_runtime>
template <size_t i, typename rangeT, typename valueT, ttg::Runtime Runtime = ttg::ttg_runtime>
inline void broadcast(const rangeT &keylist, valueT &&value) {
broadcast(i, keylist, std::forward<valueT>(value));
}
Expand Down Expand Up @@ -505,7 +503,7 @@ namespace ttg {
terminal_ptr->set_size(size);
}

template <size_t i, typename keyT, typename... out_keysT, typename... out_valuesT>
template <size_t i, typename keyT>
inline std::enable_if_t<!meta::is_void_v<keyT>, void> set_size(const keyT &key, const std::size_t size) {
set_size(i, key, size);
}
Expand Down
Loading

0 comments on commit 2d70423

Please sign in to comment.