From cb285cf82447705e7f34f75b06d54cb87df39b48 Mon Sep 17 00:00:00 2001 From: Joseph Schuchart Date: Tue, 3 Sep 2024 18:15:10 -0400 Subject: [PATCH] Consolidate make_tt_tpl and make_tt We can map make_tt_tpl onto the same mechanism we use for make_tt by adding a template template parameter that signals whether arguments are passed as tuple or unpacked. Signed-off-by: Joseph Schuchart --- ttg/ttg/make_tt.h | 297 ++++++++++++++++------------------------------ 1 file changed, 103 insertions(+), 194 deletions(-) diff --git a/ttg/ttg/make_tt.h b/ttg/ttg/make_tt.h index 81897b816..7ba01b56d 100644 --- a/ttg/ttg/make_tt.h +++ b/ttg/ttg/make_tt.h @@ -3,127 +3,6 @@ #ifndef TTG_MAKE_TT_H #define TTG_MAKE_TT_H -// Class to wrap a callable with signature -// -// case 1 (keyT != void): void op(auto&& key, std::tuple&&, std::tuple&) -// case 2 (keyT == void): void op(std::tuple&&, std::tuple&) -// -template -class CallableWrapTT - : public TT, - ttg::typelist> { - using baseT = typename CallableWrapTT::ttT; - - using input_values_tuple_type = typename baseT::input_values_tuple_type; - using input_refs_tuple_type = typename baseT::input_refs_tuple_type; - using input_edges_type = typename baseT::input_edges_type; - using output_edges_type = typename baseT::output_edges_type; - - using noref_funcT = std::remove_reference_t; - std::conditional_t, std::add_pointer_t, noref_funcT> func; - - template - void call_func(Key &&key, Tuple &&args, output_terminalsT &out) { - if constexpr (funcT_receives_outterm_tuple) - func(std::forward(key), std::forward(args), out); - else { - auto old_output_tls_ptr = this->outputs_tls_ptr_accessor(); - this->set_outputs_tls_ptr(); - func(std::forward(key), std::forward(args)); - this->set_outputs_tls_ptr(old_output_tls_ptr); - } - } - - template - void call_func(TupleOrKey &&args, output_terminalsT &out) { - if constexpr (funcT_receives_outterm_tuple) - func(std::forward(args), out); - else { - auto old_output_tls_ptr = this->outputs_tls_ptr_accessor(); - this->set_outputs_tls_ptr(); - func(std::forward(args)); - this->set_outputs_tls_ptr(old_output_tls_ptr); - } - } - - void call_func(output_terminalsT &out) { - if constexpr (funcT_receives_outterm_tuple) - func(std::tuple<>(), out); - else { - auto old_output_tls_ptr = this->outputs_tls_ptr_accessor(); - this->set_outputs_tls_ptr(); - func(std::tuple<>()); - this->set_outputs_tls_ptr(old_output_tls_ptr); - } - } - - public: - template - CallableWrapTT(funcT_ &&f, const input_edges_type &inedges, const output_edges_type &outedges, - const std::string &name, const std::vector &innames, - const std::vector &outnames) - : baseT(inedges, outedges, name, innames, outnames), func(std::forward(f)) {} - - template - CallableWrapTT(funcT_ &&f, const std::string &name, const std::vector &innames, - const std::vector &outnames) - : baseT(name, innames, outnames), func(std::forward(f)) {} - - template - std::enable_if_t && !ttg::meta::is_empty_tuple_v && - !ttg::meta::is_void_v, - void> - op(Key &&key, ArgsTuple &&args_tuple, output_terminalsT &out) { - call_func(std::forward(key), std::forward(args_tuple), out); - } - - template - std::enable_if_t && !ttg::meta::is_empty_tuple_v && - ttg::meta::is_void_v, - void> - op(ArgsTuple &&args_tuple, output_terminalsT &out) { - call_func(std::forward(args_tuple), out); - } - - template - std::enable_if_t && !ttg::meta::is_void_v, void> op( - Key &&key, output_terminalsT &out) { - call_func(std::forward(key), out); - } - - template - std::enable_if_t && ttg::meta::is_void_v, void> op( - output_terminalsT &out) { - call_func(out); - } -}; - -template -struct CallableWrapTTUnwrapTypelist; - -template -struct CallableWrapTTUnwrapTypelist> { - using type = CallableWrapTT...>; -}; - -template -struct CallableWrapTTUnwrapTypelist> { - using type = CallableWrapTT...>; -}; // Class to wrap a callable with signature // @@ -131,12 +10,13 @@ struct CallableWrapTTUnwrapTypelist&) // // returnT is void for funcT = synchronous (ordinary) function and the appropriate return type for funcT=coroutine -template class CallableWrapTTArgs : public TT< keyT, output_terminalsT, - CallableWrapTTArgs, + CallableWrapTTArgs, ttg::typelist> { using baseT = typename CallableWrapTTArgs::ttT; @@ -224,104 +104,105 @@ class CallableWrapTTArgs template auto call_func(Key &&key, Tuple &&args_tuple, output_terminalsT &out, std::index_sequence) { using func_args_t = ttg::meta::tuple_concat_t, input_refs_tuple_type, output_edges_type>; - - if constexpr (funcT_receives_outterm_tuple) { + auto invoke_func_handle_ret = [&](auto&&... args){ if constexpr (std::is_void_v) { - func(std::forward(key), - baseT::template get>(std::forward(args_tuple))..., out); - return; + func(std::forward(key), std::forward(args)...); } else { - auto ret = func( - std::forward(key), - baseT::template get>(std::forward(args_tuple))..., out); - - return process_return(std::move(ret), out); + return process_return(func(std::forward(key), std::forward(args)...), out); + } + }; + auto unpack_input_tuple_if_needed = [&](auto&&... args){ + static_assert(!funcT_receives_input_tuple); + if constexpr (funcT_receives_input_tuple) { + return invoke_func_handle_ret(std::forward(args_tuple), std::forward(args)...); + } else { + return invoke_func_handle_ret(baseT::template get>(std::forward(args_tuple))..., + std::forward(args)...); } + }; + + if constexpr (funcT_receives_outterm_tuple) { + return unpack_input_tuple_if_needed(out); } else { auto old_output_tls_ptr = this->outputs_tls_ptr_accessor(); this->set_outputs_tls_ptr(); - if constexpr (std::is_void_v) { - func(std::forward(key), - baseT::template get>(std::forward(args_tuple))...); - this->set_outputs_tls_ptr(old_output_tls_ptr); - return; - } else { - auto ret = - func(std::forward(key), - baseT::template get>(std::forward(args_tuple))...); - this->set_outputs_tls_ptr(old_output_tls_ptr); - return process_return(std::move(ret), out); - } + // make sure the output tls is reset + auto _ = ttg::detail::finally([this, old_output_tls_ptr](){ this->set_outputs_tls_ptr(old_output_tls_ptr); }); + return unpack_input_tuple_if_needed(); } } template auto call_func(Tuple &&args_tuple, output_terminalsT &out, std::index_sequence) { using func_args_t = ttg::meta::tuple_concat_t; - if constexpr (funcT_receives_outterm_tuple) { + + auto invoke_func_handle_ret = [&](auto&&... args){ if constexpr (std::is_void_v) { - func(baseT::template get>(std::forward(args_tuple))..., out); + func(std::forward(args)...); + } else { + return process_return(func(std::forward(args)...), out); + } + }; + auto unpack_input_tuple_if_needed = [&](auto&& fn, auto&&... args){ + if constexpr (funcT_receives_input_tuple) { + return fn(std::forward(args_tuple), std::forward(args)...); } else { - auto ret = func(baseT::template get>(std::forward(args_tuple))..., out); - return process_return(std::move(ret), out); + return fn(baseT::template get>(std::forward(args_tuple))..., + std::forward(args)...); } + }; + + if constexpr (funcT_receives_outterm_tuple) { + return unpack_input_tuple_if_needed(invoke_func_handle_ret, out); } else { auto old_output_tls_ptr = this->outputs_tls_ptr_accessor(); this->set_outputs_tls_ptr(); - if constexpr (std::is_void_v) { - func(baseT::template get>(std::forward(args_tuple))...); - this->set_outputs_tls_ptr(old_output_tls_ptr); - } else { - auto ret = func(baseT::template get>(std::forward(args_tuple))...); - this->set_outputs_tls_ptr(old_output_tls_ptr); - return process_return(std::move(ret), out); - } + // make sure the output tls is reset + auto _ = ttg::detail::finally([this, old_output_tls_ptr](){ this->set_outputs_tls_ptr(old_output_tls_ptr); }); + return unpack_input_tuple_if_needed(invoke_func_handle_ret); } } template auto call_func(Key &&key, output_terminalsT &out) { - if constexpr (funcT_receives_outterm_tuple) { + auto invoke_func_handle_ret = [&](auto&&... args){ if constexpr (std::is_void_v) { - func(std::forward(key), out); + func(std::forward(key), std::forward(args)...); } else { - auto ret = func(std::forward(key), out); - return process_return(std::move(ret), out); + return process_return(func(std::forward(key), std::forward(args)...), out); } + }; + + if constexpr (funcT_receives_outterm_tuple) { + invoke_func_handle_ret(out); } else { auto old_output_tls_ptr = this->outputs_tls_ptr_accessor(); this->set_outputs_tls_ptr(); - if constexpr (std::is_void_v) { - func(std::forward(key)); - this->set_outputs_tls_ptr(old_output_tls_ptr); - } else { - auto ret = func(std::forward(key)); - this->set_outputs_tls_ptr(old_output_tls_ptr); - return process_return(std::move(ret), out); - } + // make sure the output tls is reset + auto _ = ttg::detail::finally([this, old_output_tls_ptr](){ this->set_outputs_tls_ptr(old_output_tls_ptr); }); + return invoke_func_handle_ret(); } } template auto call_func(OutputTerminals &out) { - if constexpr (funcT_receives_outterm_tuple) { + + auto invoke_func_handle_ret = [&](auto&&... args){ if constexpr (std::is_void_v) { - func(out); + func(std::forward(args)...); } else { - auto ret = func(out); - return process_return(std::move(ret), out); + return process_return(func(std::forward(args)...), out); } + }; + + if constexpr (funcT_receives_outterm_tuple) { + return invoke_func_handle_ret(out); } else { auto old_output_tls_ptr = this->outputs_tls_ptr_accessor(); this->set_outputs_tls_ptr(); - if constexpr (std::is_void_v) { - func(); - this->set_outputs_tls_ptr(old_output_tls_ptr); - } else { - auto ret = func(out); - this->set_outputs_tls_ptr(old_output_tls_ptr); - return process_return(std::move(ret), out); - } + // make sure the output tls is reset + auto _ = ttg::detail::finally([this, old_output_tls_ptr](){ this->set_outputs_tls_ptr(old_output_tls_ptr); }); + return invoke_func_handle_ret(); } } @@ -378,23 +259,30 @@ class CallableWrapTTArgs }; }; -template struct CallableWrapTTArgsAsTypelist; -template -struct CallableWrapTTArgsAsTypelist> { - using type = CallableWrapTTArgs...>; }; -template -struct CallableWrapTTArgsAsTypelist> { - using type = CallableWrapTTArgs...>; }; @@ -425,7 +313,11 @@ struct CallableWrapTTArgsAsTypelist +template auto make_tt_tpl(funcT &&func, const std::tuple...> &inedges = std::tuple<>{}, const std::tuple &outedges = std::tuple<>{}, const std::string &name = "wrapper", const std::vector &innames = std::vector(sizeof...(input_edge_valuesT), @@ -446,14 +338,17 @@ auto make_tt_tpl(funcT &&func, const std::tuple>, ttg::meta::candidate_argument_bindings_t< std::tuple::value_type>...>>, - ttg::meta::typelist>; + ttg::meta::typelist>; // net list of candidate argument types excludes the empty typelists for void arguments using candidate_func_args_t = ttg::meta::filter_t; // compute list of argument types with which func can be invoked constexpr static auto func_is_generic = ttg::meta::is_generic_callable_v; - using gross_func_args_t = decltype(ttg::meta::compute_arg_binding_types_r(func, candidate_func_args_t{})); + using return_type_typelist_and_gross_func_args_t = + decltype(ttg::meta::compute_arg_binding_types(func, candidate_func_args_t{})); + using func_return_t = std::tuple_element_t<0, std::tuple_element_t<0, return_type_typelist_and_gross_func_args_t>>; + using gross_func_args_t = std::tuple_element_t<1, return_type_typelist_and_gross_func_args_t>; constexpr auto DETECTED_HOW_TO_INVOKE_GENERIC_FUNC = func_is_generic ? !std::is_same_v> : true; static_assert(DETECTED_HOW_TO_INVOKE_GENERIC_FUNC, @@ -499,14 +394,28 @@ auto make_tt_tpl(funcT &&func, const std::tuple; using decayed_input_args_t = ttg::meta::decayed_typelist_t; - using wrapT = - typename CallableWrapTTUnwrapTypelist::type; + using wrapT = typename CallableWrapTTArgsAsTypelist::type; static_assert(std::is_same_v>, "ttg::make_tt_tpl(func, inedges, outedges): inedges value types do not match argument types of func"); return std::make_unique(std::forward(func), inedges, outedges, name, innames, outnames); } +template +auto make_tt_tpl(funcT &&func, const std::tuple...> &inedges = std::tuple<>{}, + const std::tuple &outedges = std::tuple<>{}, const std::string &name = "wrapper", + const std::vector &innames = std::vector(sizeof...(input_edge_valuesT), + "input"), + const std::vector &outnames = std::vector(sizeof...(output_edgesT), + "output")) +{ + return make_tt_tpl( + std::forward(func), inedges, outedges, name, innames, outnames); +} // clang-format off /// @brief Factory function to assist in wrapping a callable with signature /// @@ -631,7 +540,7 @@ auto make_tt(funcT &&func, const std::tuple. using decayed_input_args_t = ttg::meta::decayed_typelist_t; // 3. full_input_args_t = edge-types with non-void types replaced by input_args_t using full_input_args_t = ttg::meta::replace_nonvoid_t; - using wrapT = typename CallableWrapTTArgsAsTypelist::type; return std::make_unique(std::forward(func), inedges, outedges, name, innames, outnames);