Skip to content

Commit

Permalink
Restrict calls to buffer_apply() to serializable types
Browse files Browse the repository at this point in the history
Signed-off-by: Joseph Schuchart <[email protected]>
  • Loading branch information
devreal committed Nov 11, 2024
1 parent f5c39ad commit 9ced4aa
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 31 deletions.
28 changes: 28 additions & 0 deletions tests/unit/device_coro.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,34 @@ struct nested_value_t {
}
};

struct derived_value_t {
nested_value_t v;
};

#ifdef TTG_SERIALIZATION_SUPPORTS_MADNESS
namespace madness {
namespace archive {

template <class Archive>
struct ArchiveLoadImpl<Archive, derived_value_t> {
static inline void load(const Archive& ar, derived_value_t& v) {
ar& v.v;
}
};

template <class Archive>
struct ArchiveStoreImpl<Archive, derived_value_t> {
static inline void store(const Archive& ar, const derived_value_t& v) {
ar& v.v;
}
};
} // namespace archive
} // namespace madness
#endif // TTG_SERIALIZATION_SUPPORTS_MADNESS

static_assert(madness::is_serializable_v<madness::archive::BufferInspectorArchive<ttg::detail::buffer_apply_dummy_fn>, derived_value_t>);
static_assert(ttg::detail::has_buffer_apply_v<derived_value_t>);

TEST_CASE("Device", "coro") {
SECTION("buffer-inspection") {
value_t v1;
Expand Down
25 changes: 24 additions & 1 deletion ttg/ttg/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,24 @@ namespace ttg {
template<typename T, typename Allocator = std::allocator<T>>
using Buffer = TTG_IMPL_NS::Buffer<T, Allocator>;

namespace detail {
/**
* Type trait to check whether we can use serialization
* to inspect the buffers owned by an object passing
* through a task graph.
*/
template<typename T, typename Enabler = void>
struct has_buffer_apply : std::false_type
{ };

template<typename T>
constexpr const bool has_buffer_apply_v = has_buffer_apply<T>::value;
} // namespace detail

} // namespace ttg



#ifdef TTG_SERIALIZATION_SUPPORTS_MADNESS
#include <madness/world/buffer_archive.h>

Expand Down Expand Up @@ -87,11 +103,18 @@ namespace madness {

namespace ttg::detail {
template<typename T, typename Fn>
requires(madness::is_serializable_v<madness::archive::BufferInspectorArchive<Fn>, T>)
requires(madness::is_serializable_v<madness::archive::BufferInspectorArchive<Fn>, std::decay<T>>)
void buffer_apply(T&& t, Fn&& fn) {
madness::archive::BufferInspectorArchive ar(std::forward<Fn>(fn));
ar & t;
}

using buffer_apply_dummy_fn = decltype([]<typename T, typename A>(const ttg::Buffer<T, A>&){});
template<typename T>
struct has_buffer_apply<T, std::enable_if_t<madness::is_serializable_v<madness::archive::BufferInspectorArchive<buffer_apply_dummy_fn>, std::decay_t<T>>>>
: std::true_type
{ };

} // namespace ttg::detail

#endif // TTG_SERIALIZATION_SUPPORTS_MADNESS
Expand Down
4 changes: 3 additions & 1 deletion ttg/ttg/parsec/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -512,9 +512,11 @@ struct Buffer {
}
#endif // TTG_SERIALIZATION_SUPPORTS_MADNESS


};


static_assert(madness::is_serializable_v<madness::archive::BufferInspectorArchive<ttg::detail::buffer_apply_dummy_fn>, const Buffer<double, std::allocator<double>>&>);

namespace detail {
template<typename T, typename A>
parsec_data_t* get_parsec_data(const ttg_parsec::Buffer<T, A>& db) {
Expand Down
63 changes: 34 additions & 29 deletions ttg/ttg/parsec/ttg.h
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,7 @@ namespace ttg_parsec {

template<typename T>
inline void transfer_ownership_impl(T&& arg, int device) {
if constexpr(!std::is_const_v<std::remove_reference_t<T>>) {
if constexpr(!std::is_const_v<std::remove_reference_t<T>> && ttg::detail::has_buffer_apply_v<T>) {
ttg::detail::buffer_apply(arg, [&](auto&& buffer){
auto *data = detail::get_parsec_data(buffer);
parsec_data_transfer_ownership_to_copy(data, device, PARSEC_FLOW_ACCESS_RW);
Expand Down Expand Up @@ -3357,36 +3357,41 @@ namespace ttg_parsec {
template<typename Value>
void copy_mark_pushout(const Value& value) {

assert(detail::parsec_ttg_caller->dev_ptr && detail::parsec_ttg_caller->dev_ptr->gpu_task);
parsec_gpu_task_t *gpu_task = detail::parsec_ttg_caller->dev_ptr->gpu_task;
auto check_parsec_data = [&](parsec_data_t* data) {
if (data->owner_device != 0) {
/* find the flow */
int flowidx = 0;
while (flowidx < MAX_PARAM_COUNT &&
gpu_task->flow[flowidx]->flow_flags != PARSEC_FLOW_ACCESS_NONE) {
if (detail::parsec_ttg_caller->parsec_task.data[flowidx].data_in->original == data) {
/* found the right data, set the corresponding flow as pushout */
break;
if constexpr (ttg::detail::has_buffer_apply_v<Value>) {
assert(detail::parsec_ttg_caller->dev_ptr && detail::parsec_ttg_caller->dev_ptr->gpu_task);
parsec_gpu_task_t *gpu_task = detail::parsec_ttg_caller->dev_ptr->gpu_task;
auto check_parsec_data = [&](parsec_data_t* data) {
if (data->owner_device != 0) {
/* find the flow */
int flowidx = 0;
while (flowidx < MAX_PARAM_COUNT &&
gpu_task->flow[flowidx]->flow_flags != PARSEC_FLOW_ACCESS_NONE) {
if (detail::parsec_ttg_caller->parsec_task.data[flowidx].data_in->original == data) {
/* found the right data, set the corresponding flow as pushout */
break;
}
++flowidx;
}
++flowidx;
}
if (flowidx == MAX_PARAM_COUNT) {
throw std::runtime_error("Cannot add more than MAX_PARAM_COUNT flows to a task!");
}
if (gpu_task->flow[flowidx]->flow_flags == PARSEC_FLOW_ACCESS_NONE) {
/* no flow found, add one and mark it pushout */
detail::parsec_ttg_caller->parsec_task.data[flowidx].data_in = data->device_copies[0];
gpu_task->flow_nb_elts[flowidx] = data->nb_elts;
if (flowidx == MAX_PARAM_COUNT) {
throw std::runtime_error("Cannot add more than MAX_PARAM_COUNT flows to a task!");
}
if (gpu_task->flow[flowidx]->flow_flags == PARSEC_FLOW_ACCESS_NONE) {
/* no flow found, add one and mark it pushout */
detail::parsec_ttg_caller->parsec_task.data[flowidx].data_in = data->device_copies[0];
gpu_task->flow_nb_elts[flowidx] = data->nb_elts;
}
/* need to mark the flow RW to make PaRSEC happy */
((parsec_flow_t *)gpu_task->flow[flowidx])->flow_flags |= PARSEC_FLOW_ACCESS_RW;
gpu_task->pushout |= 1<<flowidx;
}
/* need to mark the flow RW to make PaRSEC happy */
((parsec_flow_t *)gpu_task->flow[flowidx])->flow_flags |= PARSEC_FLOW_ACCESS_RW;
gpu_task->pushout |= 1<<flowidx;
}
};
ttg::detail::buffer_apply(value, [&]<typename T, typename Allocator>(const ttg::Buffer<T, Allocator>& buffer){
check_parsec_data(detail::get_parsec_data(buffer));
});
};
ttg::detail::buffer_apply(value,
[&]<typename T, typename Allocator>(const ttg::Buffer<T, Allocator>& buffer){
check_parsec_data(detail::get_parsec_data(buffer));
});
} else {
throw std::runtime_error("Value type must be serializable with ttg::BufferInspectorArchive");
}
}


Expand Down

0 comments on commit 9ced4aa

Please sign in to comment.