Skip to content

Commit

Permalink
make_tt: fixes for passing empty tuple to callback
Browse files Browse the repository at this point in the history
Signed-off-by: Joseph Schuchart <[email protected]>
  • Loading branch information
devreal committed Oct 30, 2024
1 parent fb55a37 commit 70a1f0e
Showing 1 changed file with 20 additions and 5 deletions.
25 changes: 20 additions & 5 deletions ttg/ttg/make_tt.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ class CallableWrapTTArgs
}
};
auto unpack_input_tuple_if_needed = [&](auto&&... args){
static_assert(!funcT_receives_input_tuple);
if constexpr (funcT_receives_input_tuple) {
return invoke_func_handle_ret(std::forward<Tuple>(args_tuple), std::forward<decltype(args)>(args)...);
} else {
Expand Down Expand Up @@ -179,6 +178,14 @@ class CallableWrapTTArgs
}
};

auto invoke_func_empty_tuple = [&](auto&&... args){
if constexpr(funcT_receives_input_tuple) {
invoke_func_handle_ret(std::tuple<>{}, std::forward<decltype(args)>(args)...);
} else {
invoke_func_handle_ret(std::forward<decltype(args)>(args)...);
}
};

if constexpr (funcT_receives_outterm_tuple) {
invoke_func_handle_ret(out);
} else {
Expand All @@ -204,8 +211,16 @@ class CallableWrapTTArgs
}
};

auto invoke_func_empty_tuple = [&](auto&&... args){
if constexpr(funcT_receives_input_tuple) {
invoke_func_handle_ret(std::tuple<>{}, std::forward<decltype(args)>(args)...);
} else {
invoke_func_handle_ret(std::forward<decltype(args)>(args)...);
}
};

if constexpr (funcT_receives_outterm_tuple) {
return invoke_func_handle_ret(out);
return invoke_func_empty_tuple(out);
} else {
auto old_output_tls_ptr = this->outputs_tls_ptr_accessor();
this->set_outputs_tls_ptr();
Expand All @@ -214,7 +229,7 @@ class CallableWrapTTArgs
[this, old_output_tls_ptr](){
this->set_outputs_tls_ptr(old_output_tls_ptr);
});
return invoke_func_handle_ret();
return invoke_func_empty_tuple();
}
}

Expand Down Expand Up @@ -406,7 +421,7 @@ auto make_tt_tpl(funcT &&func, const std::tuple<ttg::Edge<keyT, input_edge_value
"ref; this is illegal, should only pass arguments as const lvalue ref or (nonconst) rvalue ref");
using input_args_t = std::decay_t<nondecayed_input_args_t>;
using decayed_input_args_t = ttg::meta::decayed_typelist_t<input_args_t>;
using wrapT = typename CallableWrapTTArgsAsTypelist<funcT, func_return_t, false, have_outterm_tuple, space, keyT,
using wrapT = typename CallableWrapTTArgsAsTypelist<funcT, func_return_t, true, have_outterm_tuple, space, keyT,
output_terminals_type, input_args_t>::type;
static_assert(std::is_same_v<decayed_input_args_t, std::tuple<input_edge_valuesT...>>,
"ttg::make_tt_tpl(func, inedges, outedges): inedges value types do not match argument types of func");
Expand All @@ -425,7 +440,7 @@ auto make_tt_tpl(funcT &&func, const std::tuple<ttg::Edge<keyT, input_edge_value
const std::vector<std::string> &outnames = std::vector<std::string>(sizeof...(output_edgesT),
"output"))
{
return make_tt_tpl<ttg::ExecutionSpace::Host, keyT, funcT, input_edge_valuesT..., output_edgesT...>(
return make_tt_tpl<ttg::ExecutionSpace::Host, keyT>(
std::forward<funcT>(func), inedges, outedges, name, innames, outnames);
}
// clang-format off
Expand Down

0 comments on commit 70a1f0e

Please sign in to comment.