Skip to content

Commit

Permalink
Protect parsec data buffer inspection against empty buffers
Browse files Browse the repository at this point in the history
Signed-off-by: Joseph Schuchart <[email protected]>
  • Loading branch information
devreal committed Nov 19, 2024
1 parent 915e7d8 commit da4545c
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 37 deletions.
5 changes: 4 additions & 1 deletion ttg/ttg/parsec/parsec_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ namespace ttg_parsec::detail {
/* protect for non-serializable types, allowed if the TT has no device op */
if constexpr (ttg::detail::has_buffer_apply_v<Value>) {
ttg::detail::buffer_apply(value, [&]<typename B>(B&& b){
fn(detail::get_parsec_data(b));
parsec_data_t *data = detail::get_parsec_data(b);
if (nullptr != data) {
fn(data);
}
});
}
}
Expand Down
67 changes: 31 additions & 36 deletions ttg/ttg/parsec/ttg.h
Original file line number Diff line number Diff line change
Expand Up @@ -748,9 +748,8 @@ namespace ttg_parsec {

template<typename T>
inline void transfer_ownership_impl(T&& arg, int device) {
if constexpr(!std::is_const_v<std::remove_reference_t<T>> && ttg::detail::has_buffer_apply_v<T>) {
ttg::detail::buffer_apply(arg, [&](auto&& buffer){
auto *data = detail::get_parsec_data(buffer);
if constexpr(!std::is_const_v<std::remove_reference_t<T>>) {
detail::foreach_parsec_data(arg, [&](parsec_data_t *data){
parsec_data_transfer_ownership_to_copy(data, device, PARSEC_FLOW_ACCESS_RW);
/* make sure we increment the version since we will modify the data */
data->device_copies[0]->version++;
Expand Down Expand Up @@ -3441,41 +3440,37 @@ namespace ttg_parsec {
template<typename Value>
void copy_mark_pushout(const Value& value) {

if constexpr (ttg::detail::has_buffer_apply_v<Value>) {
assert(detail::parsec_ttg_caller->dev_ptr && detail::parsec_ttg_caller->dev_ptr->gpu_task);
parsec_gpu_task_t *gpu_task = detail::parsec_ttg_caller->dev_ptr->gpu_task;
auto check_parsec_data = [&](parsec_data_t* data) {
if (data->owner_device != 0) {
/* find the flow */
int flowidx = 0;
while (flowidx < MAX_PARAM_COUNT &&
gpu_task->flow[flowidx]->flow_flags != PARSEC_FLOW_ACCESS_NONE) {
if (detail::parsec_ttg_caller->parsec_task.data[flowidx].data_in->original == data) {
/* found the right data, set the corresponding flow as pushout */
break;
}
++flowidx;
}
if (flowidx == MAX_PARAM_COUNT) {
throw std::runtime_error("Cannot add more than MAX_PARAM_COUNT flows to a task!");
}
if (gpu_task->flow[flowidx]->flow_flags == PARSEC_FLOW_ACCESS_NONE) {
/* no flow found, add one and mark it pushout */
detail::parsec_ttg_caller->parsec_task.data[flowidx].data_in = data->device_copies[0];
gpu_task->flow_nb_elts[flowidx] = data->nb_elts;
assert(detail::parsec_ttg_caller->dev_ptr && detail::parsec_ttg_caller->dev_ptr->gpu_task);
parsec_gpu_task_t *gpu_task = detail::parsec_ttg_caller->dev_ptr->gpu_task;
auto check_parsec_data = [&](parsec_data_t* data) {
if (data->owner_device != 0) {
/* find the flow */
int flowidx = 0;
while (flowidx < MAX_PARAM_COUNT &&
gpu_task->flow[flowidx]->flow_flags != PARSEC_FLOW_ACCESS_NONE) {
if (detail::parsec_ttg_caller->parsec_task.data[flowidx].data_in->original == data) {
/* found the right data, set the corresponding flow as pushout */
break;
}
/* need to mark the flow RW to make PaRSEC happy */
((parsec_flow_t *)gpu_task->flow[flowidx])->flow_flags |= PARSEC_FLOW_ACCESS_RW;
gpu_task->pushout |= 1<<flowidx;
++flowidx;
}
};
ttg::detail::buffer_apply(value,
[&]<typename T, typename Allocator>(const ttg::Buffer<T, Allocator>& buffer){
check_parsec_data(detail::get_parsec_data(buffer));
});
} else {
throw std::runtime_error("Value type must be serializable with ttg::BufferVisitorArchive");
}
if (flowidx == MAX_PARAM_COUNT) {
throw std::runtime_error("Cannot add more than MAX_PARAM_COUNT flows to a task!");
}
if (gpu_task->flow[flowidx]->flow_flags == PARSEC_FLOW_ACCESS_NONE) {
/* no flow found, add one and mark it pushout */
detail::parsec_ttg_caller->parsec_task.data[flowidx].data_in = data->device_copies[0];
gpu_task->flow_nb_elts[flowidx] = data->nb_elts;
}
/* need to mark the flow RW to make PaRSEC happy */
((parsec_flow_t *)gpu_task->flow[flowidx])->flow_flags |= PARSEC_FLOW_ACCESS_RW;
gpu_task->pushout |= 1<<flowidx;
}
};
detail::foreach_parsec_data(value,
[&](parsec_data_t* data){
check_parsec_data(data);
});
}


Expand Down

0 comments on commit da4545c

Please sign in to comment.