diff --git a/examples/madness/mrattg.cc b/examples/madness/mrattg.cc index 2341ba864..8fcc2c6a9 100644 --- a/examples/madness/mrattg.cc +++ b/examples/madness/mrattg.cc @@ -177,7 +177,8 @@ namespace detail { using compress_out_type = std::tuple>; using compress_in_type = std::tuple; template - using compwrap_type = ttg::CallableWrapTT, compress_out_type, Rin, Rin>; + using compwrap_type = ttg::CallableWrapTT, compress_out_type, Rin, Rin>; }; template @@ -187,7 +188,8 @@ namespace detail { using compress_out_type = std::tuple>; using compress_in_type = std::tuple; template - using compwrap_type = ttg::CallableWrapTT, compress_out_type, Rin, Rin, Rin, Rin>; + using compwrap_type = ttg::CallableWrapTT, compress_out_type, Rin, Rin, Rin, Rin>; }; template @@ -198,7 +200,8 @@ namespace detail { using compress_in_type = std::tuple; template using compwrap_type = - ttg::CallableWrapTT, compress_out_type, Rin, Rin, Rin, Rin, Rin, Rin, Rin, Rin>; + ttg::CallableWrapTT, compress_out_type, Rin, Rin, Rin, Rin, Rin, Rin, Rin, Rin>; }; }; // namespace detail @@ -277,7 +280,8 @@ auto make_compress(rnodeEdge& in, cnodeEdge& out, const using sendfuncT = decltype(&send_leaves_up); using sendwrapT = - ttg::CallableWrapTT, typename ::detail::tree_types::compress_out_type, + ttg::CallableWrapTT, typename ::detail::tree_types::compress_out_type, FunctionReconstructedNode>; using compfuncT = decltype(&do_compress); using compwrapT = typename ::detail::tree_types::template compwrap_type; diff --git a/ttg/CMakeLists.txt b/ttg/CMakeLists.txt index b1fa72947..272f005ab 100644 --- a/ttg/CMakeLists.txt +++ b/ttg/CMakeLists.txt @@ -22,6 +22,7 @@ set(ttg-util-headers ${CMAKE_CURRENT_SOURCE_DIR}/ttg/util/meta.h ${CMAKE_CURRENT_SOURCE_DIR}/ttg/util/meta/callable.h ${CMAKE_CURRENT_SOURCE_DIR}/ttg/util/print.h + ${CMAKE_CURRENT_SOURCE_DIR}/ttg/util/scope_exit.h ${CMAKE_CURRENT_SOURCE_DIR}/ttg/util/span.h ${CMAKE_CURRENT_SOURCE_DIR}/ttg/util/trace.h ${CMAKE_CURRENT_SOURCE_DIR}/ttg/util/tree.h diff --git a/ttg/ttg/madness/ttg.h b/ttg/ttg/madness/ttg.h index 5d2360cfb..85dee437f 100644 --- a/ttg/ttg/madness/ttg.h +++ b/ttg/ttg/madness/ttg.h @@ -22,6 +22,7 @@ #include "ttg/util/macro.h" #include "ttg/util/meta.h" #include "ttg/util/meta/callable.h" +#include "ttg/util/scope_exit.h" #include "ttg/util/void.h" #include "ttg/world.h" #include "ttg/coroutine.h" diff --git a/ttg/ttg/make_tt.h b/ttg/ttg/make_tt.h index 81897b816..d3912576a 100644 --- a/ttg/ttg/make_tt.h +++ b/ttg/ttg/make_tt.h @@ -3,127 +3,6 @@ #ifndef TTG_MAKE_TT_H #define TTG_MAKE_TT_H -// Class to wrap a callable with signature -// -// case 1 (keyT != void): void op(auto&& key, std::tuple&&, std::tuple&) -// case 2 (keyT == void): void op(std::tuple&&, std::tuple&) -// -template -class CallableWrapTT - : public TT, - ttg::typelist> { - using baseT = typename CallableWrapTT::ttT; - - using input_values_tuple_type = typename baseT::input_values_tuple_type; - using input_refs_tuple_type = typename baseT::input_refs_tuple_type; - using input_edges_type = typename baseT::input_edges_type; - using output_edges_type = typename baseT::output_edges_type; - - using noref_funcT = std::remove_reference_t; - std::conditional_t, std::add_pointer_t, noref_funcT> func; - - template - void call_func(Key &&key, Tuple &&args, output_terminalsT &out) { - if constexpr (funcT_receives_outterm_tuple) - func(std::forward(key), std::forward(args), out); - else { - auto old_output_tls_ptr = this->outputs_tls_ptr_accessor(); - this->set_outputs_tls_ptr(); - func(std::forward(key), std::forward(args)); - this->set_outputs_tls_ptr(old_output_tls_ptr); - } - } - - template - void call_func(TupleOrKey &&args, output_terminalsT &out) { - if constexpr (funcT_receives_outterm_tuple) - func(std::forward(args), out); - else { - auto old_output_tls_ptr = this->outputs_tls_ptr_accessor(); - this->set_outputs_tls_ptr(); - func(std::forward(args)); - this->set_outputs_tls_ptr(old_output_tls_ptr); - } - } - - void call_func(output_terminalsT &out) { - if constexpr (funcT_receives_outterm_tuple) - func(std::tuple<>(), out); - else { - auto old_output_tls_ptr = this->outputs_tls_ptr_accessor(); - this->set_outputs_tls_ptr(); - func(std::tuple<>()); - this->set_outputs_tls_ptr(old_output_tls_ptr); - } - } - - public: - template - CallableWrapTT(funcT_ &&f, const input_edges_type &inedges, const output_edges_type &outedges, - const std::string &name, const std::vector &innames, - const std::vector &outnames) - : baseT(inedges, outedges, name, innames, outnames), func(std::forward(f)) {} - - template - CallableWrapTT(funcT_ &&f, const std::string &name, const std::vector &innames, - const std::vector &outnames) - : baseT(name, innames, outnames), func(std::forward(f)) {} - - template - std::enable_if_t && !ttg::meta::is_empty_tuple_v && - !ttg::meta::is_void_v, - void> - op(Key &&key, ArgsTuple &&args_tuple, output_terminalsT &out) { - call_func(std::forward(key), std::forward(args_tuple), out); - } - - template - std::enable_if_t && !ttg::meta::is_empty_tuple_v && - ttg::meta::is_void_v, - void> - op(ArgsTuple &&args_tuple, output_terminalsT &out) { - call_func(std::forward(args_tuple), out); - } - - template - std::enable_if_t && !ttg::meta::is_void_v, void> op( - Key &&key, output_terminalsT &out) { - call_func(std::forward(key), out); - } - - template - std::enable_if_t && ttg::meta::is_void_v, void> op( - output_terminalsT &out) { - call_func(out); - } -}; - -template -struct CallableWrapTTUnwrapTypelist; - -template -struct CallableWrapTTUnwrapTypelist> { - using type = CallableWrapTT...>; -}; - -template -struct CallableWrapTTUnwrapTypelist> { - using type = CallableWrapTT...>; -}; // Class to wrap a callable with signature // @@ -131,14 +10,15 @@ struct CallableWrapTTUnwrapTypelist&) // // returnT is void for funcT = synchronous (ordinary) function and the appropriate return type for funcT=coroutine -template -class CallableWrapTTArgs +class CallableWrapTT : public TT< keyT, output_terminalsT, - CallableWrapTTArgs, + CallableWrapTT, ttg::typelist> { - using baseT = typename CallableWrapTTArgs::ttT; + using baseT = typename CallableWrapTT::ttT; using input_values_tuple_type = typename baseT::input_values_tuple_type; using input_refs_tuple_type = typename baseT::input_refs_tuple_type; @@ -174,7 +54,7 @@ class CallableWrapTTArgs template auto process_return(ReturnT&& ret, output_terminalsT &out) { static_assert(std::is_same_v, returnT>, - "CallableWrapTTArgs: returnT does not match the actual return type of funcT"); + "CallableWrapTT: returnT does not match the actual return type of funcT"); if constexpr (!std::is_void_v) { // protect from compiling for void returnT #ifdef TTG_HAVE_COROUTINE if constexpr (std::is_same_v) { @@ -203,10 +83,10 @@ class CallableWrapTTArgs #endif { static_assert(std::tuple_size_v> == 1, - "CallableWrapTTArgs <= 2, - "CallableWrapTTArgs == 0) @@ -224,104 +104,132 @@ class CallableWrapTTArgs template auto call_func(Key &&key, Tuple &&args_tuple, output_terminalsT &out, std::index_sequence) { using func_args_t = ttg::meta::tuple_concat_t, input_refs_tuple_type, output_edges_type>; - - if constexpr (funcT_receives_outterm_tuple) { + auto invoke_func_handle_ret = [&](auto&&... args){ if constexpr (std::is_void_v) { - func(std::forward(key), - baseT::template get>(std::forward(args_tuple))..., out); - return; + func(std::forward(key), std::forward(args)...); } else { - auto ret = func( - std::forward(key), - baseT::template get>(std::forward(args_tuple))..., out); - - return process_return(std::move(ret), out); + return process_return(func(std::forward(key), std::forward(args)...), out); } + }; + auto unpack_input_tuple_if_needed = [&](auto&&... args){ + if constexpr (funcT_receives_input_tuple) { + return invoke_func_handle_ret(std::forward(args_tuple), std::forward(args)...); + } else { + return invoke_func_handle_ret(baseT::template get>(std::forward(args_tuple))..., + std::forward(args)...); + } + }; + + if constexpr (funcT_receives_outterm_tuple) { + return unpack_input_tuple_if_needed(out); } else { auto old_output_tls_ptr = this->outputs_tls_ptr_accessor(); this->set_outputs_tls_ptr(); - if constexpr (std::is_void_v) { - func(std::forward(key), - baseT::template get>(std::forward(args_tuple))...); - this->set_outputs_tls_ptr(old_output_tls_ptr); - return; - } else { - auto ret = - func(std::forward(key), - baseT::template get>(std::forward(args_tuple))...); - this->set_outputs_tls_ptr(old_output_tls_ptr); - return process_return(std::move(ret), out); - } + // make sure the output tls is reset + auto _ = ttg::detail::scope_exit( + [this, old_output_tls_ptr](){ + this->set_outputs_tls_ptr(old_output_tls_ptr); + }); + return unpack_input_tuple_if_needed(); } } template auto call_func(Tuple &&args_tuple, output_terminalsT &out, std::index_sequence) { using func_args_t = ttg::meta::tuple_concat_t; - if constexpr (funcT_receives_outterm_tuple) { + + auto invoke_func_handle_ret = [&](auto&&... args){ if constexpr (std::is_void_v) { - func(baseT::template get>(std::forward(args_tuple))..., out); + func(std::forward(args)...); } else { - auto ret = func(baseT::template get>(std::forward(args_tuple))..., out); - return process_return(std::move(ret), out); + return process_return(func(std::forward(args)...), out); } + }; + auto unpack_input_tuple_if_needed = [&](auto&& fn, auto&&... args){ + if constexpr (funcT_receives_input_tuple) { + return fn(std::forward(args_tuple), std::forward(args)...); + } else { + return fn(baseT::template get>(std::forward(args_tuple))..., + std::forward(args)...); + } + }; + + if constexpr (funcT_receives_outterm_tuple) { + return unpack_input_tuple_if_needed(invoke_func_handle_ret, out); } else { auto old_output_tls_ptr = this->outputs_tls_ptr_accessor(); this->set_outputs_tls_ptr(); - if constexpr (std::is_void_v) { - func(baseT::template get>(std::forward(args_tuple))...); - this->set_outputs_tls_ptr(old_output_tls_ptr); - } else { - auto ret = func(baseT::template get>(std::forward(args_tuple))...); - this->set_outputs_tls_ptr(old_output_tls_ptr); - return process_return(std::move(ret), out); - } + // make sure the output tls is reset + auto _ = ttg::detail::scope_exit( + [this, old_output_tls_ptr](){ + this->set_outputs_tls_ptr(old_output_tls_ptr); + }); + return unpack_input_tuple_if_needed(invoke_func_handle_ret); } } template auto call_func(Key &&key, output_terminalsT &out) { - if constexpr (funcT_receives_outterm_tuple) { + auto invoke_func_handle_ret = [&](auto&&... args){ if constexpr (std::is_void_v) { - func(std::forward(key), out); + func(std::forward(key), std::forward(args)...); + } else { + return process_return(func(std::forward(key), std::forward(args)...), out); + } + }; + + auto invoke_func_empty_tuple = [&](auto&&... args){ + if constexpr(funcT_receives_input_tuple) { + invoke_func_handle_ret(std::tuple<>{}, std::forward(args)...); } else { - auto ret = func(std::forward(key), out); - return process_return(std::move(ret), out); + invoke_func_handle_ret(std::forward(args)...); } + }; + + if constexpr (funcT_receives_outterm_tuple) { + invoke_func_handle_ret(out); } else { auto old_output_tls_ptr = this->outputs_tls_ptr_accessor(); this->set_outputs_tls_ptr(); - if constexpr (std::is_void_v) { - func(std::forward(key)); - this->set_outputs_tls_ptr(old_output_tls_ptr); - } else { - auto ret = func(std::forward(key)); - this->set_outputs_tls_ptr(old_output_tls_ptr); - return process_return(std::move(ret), out); - } + // make sure the output tls is reset + auto _ = ttg::detail::scope_exit( + [this, old_output_tls_ptr](){ + this->set_outputs_tls_ptr(old_output_tls_ptr); + }); + return invoke_func_handle_ret(); } } template auto call_func(OutputTerminals &out) { - if constexpr (funcT_receives_outterm_tuple) { + + auto invoke_func_handle_ret = [&](auto&&... args){ if constexpr (std::is_void_v) { - func(out); + func(std::forward(args)...); } else { - auto ret = func(out); - return process_return(std::move(ret), out); + return process_return(func(std::forward(args)...), out); } + }; + + auto invoke_func_empty_tuple = [&](auto&&... args){ + if constexpr(funcT_receives_input_tuple) { + invoke_func_handle_ret(std::tuple<>{}, std::forward(args)...); + } else { + invoke_func_handle_ret(std::forward(args)...); + } + }; + + if constexpr (funcT_receives_outterm_tuple) { + return invoke_func_empty_tuple(out); } else { auto old_output_tls_ptr = this->outputs_tls_ptr_accessor(); this->set_outputs_tls_ptr(); - if constexpr (std::is_void_v) { - func(); - this->set_outputs_tls_ptr(old_output_tls_ptr); - } else { - auto ret = func(out); - this->set_outputs_tls_ptr(old_output_tls_ptr); - return process_return(std::move(ret), out); - } + // make sure the output tls is reset + auto _ = ttg::detail::scope_exit( + [this, old_output_tls_ptr](){ + this->set_outputs_tls_ptr(old_output_tls_ptr); + }); + return invoke_func_empty_tuple(); } } @@ -333,13 +241,13 @@ class CallableWrapTTArgs public: template - CallableWrapTTArgs(funcT_ &&f, const input_edges_type &inedges, const typename baseT::output_edges_type &outedges, + CallableWrapTT(funcT_ &&f, const input_edges_type &inedges, const typename baseT::output_edges_type &outedges, const std::string &name, const std::vector &innames, const std::vector &outnames) : baseT(inedges, outedges, name, innames, outnames), func(std::forward(f)) {} template - CallableWrapTTArgs(funcT_ &&f, const std::string &name, const std::vector &innames, + CallableWrapTT(funcT_ &&f, const std::string &name, const std::vector &innames, const std::vector &outnames) : baseT(name, innames, outnames), func(std::forward(f)) {} @@ -378,23 +286,30 @@ class CallableWrapTTArgs }; }; -template -struct CallableWrapTTArgsAsTypelist; +struct CallableWrapTTAsTypelist; -template -struct CallableWrapTTArgsAsTypelist> { - using type = CallableWrapTTArgs...>; }; -template -struct CallableWrapTTArgsAsTypelist> { - using type = CallableWrapTTArgs...>; }; @@ -425,7 +340,11 @@ struct CallableWrapTTArgsAsTypelist +template auto make_tt_tpl(funcT &&func, const std::tuple...> &inedges = std::tuple<>{}, const std::tuple &outedges = std::tuple<>{}, const std::string &name = "wrapper", const std::vector &innames = std::vector(sizeof...(input_edge_valuesT), @@ -446,14 +365,17 @@ auto make_tt_tpl(funcT &&func, const std::tuple>, ttg::meta::candidate_argument_bindings_t< std::tuple::value_type>...>>, - ttg::meta::typelist>; + ttg::meta::typelist>; // net list of candidate argument types excludes the empty typelists for void arguments using candidate_func_args_t = ttg::meta::filter_t; // compute list of argument types with which func can be invoked constexpr static auto func_is_generic = ttg::meta::is_generic_callable_v; - using gross_func_args_t = decltype(ttg::meta::compute_arg_binding_types_r(func, candidate_func_args_t{})); + using return_type_typelist_and_gross_func_args_t = + decltype(ttg::meta::compute_arg_binding_types(func, candidate_func_args_t{})); + using func_return_t = std::tuple_element_t<0, std::tuple_element_t<0, return_type_typelist_and_gross_func_args_t>>; + using gross_func_args_t = std::tuple_element_t<1, return_type_typelist_and_gross_func_args_t>; constexpr auto DETECTED_HOW_TO_INVOKE_GENERIC_FUNC = func_is_generic ? !std::is_same_v> : true; static_assert(DETECTED_HOW_TO_INVOKE_GENERIC_FUNC, @@ -499,14 +421,28 @@ auto make_tt_tpl(funcT &&func, const std::tuple; using decayed_input_args_t = ttg::meta::decayed_typelist_t; - using wrapT = - typename CallableWrapTTUnwrapTypelist::type; + using wrapT = typename CallableWrapTTAsTypelist::type; static_assert(std::is_same_v>, "ttg::make_tt_tpl(func, inedges, outedges): inedges value types do not match argument types of func"); return std::make_unique(std::forward(func), inedges, outedges, name, innames, outnames); } +template +auto make_tt_tpl(funcT &&func, const std::tuple...> &inedges = std::tuple<>{}, + const std::tuple &outedges = std::tuple<>{}, const std::string &name = "wrapper", + const std::vector &innames = std::vector(sizeof...(input_edge_valuesT), + "input"), + const std::vector &outnames = std::vector(sizeof...(output_edgesT), + "output")) +{ + return make_tt_tpl( + std::forward(func), inedges, outedges, name, innames, outnames); +} // clang-format off /// @brief Factory function to assist in wrapping a callable with signature /// @@ -616,7 +552,7 @@ auto make_tt(funcT &&func, const std::tuple. OUTTERM_TUPLE_PASSED_AS_NONCONST_LVALUE_REF, "ttg::make_tt(func, ...): if given to func, the output terminal tuple must be passed by nonconst lvalue ref"); - // TT needs actual types of arguments to func ... extract them and pass to CallableWrapTTArgs + // TT needs actual types of arguments to func ... extract them and pass to CallableWrapTT using input_edge_value_types = ttg::meta::typelist...>; // input_args_t = {input_valuesT&&...} using input_args_t = typename ttg::meta::take_first_n< @@ -631,7 +567,7 @@ auto make_tt(funcT &&func, const std::tuple. using decayed_input_args_t = ttg::meta::decayed_typelist_t; // 3. full_input_args_t = edge-types with non-void types replaced by input_args_t using full_input_args_t = ttg::meta::replace_nonvoid_t; - using wrapT = typename CallableWrapTTArgsAsTypelist::type; return std::make_unique(std::forward(func), inedges, outedges, name, innames, outnames); diff --git a/ttg/ttg/parsec/ttg.h b/ttg/ttg/parsec/ttg.h index 0a0ddefcb..2fc62bd91 100644 --- a/ttg/ttg/parsec/ttg.h +++ b/ttg/ttg/parsec/ttg.h @@ -33,6 +33,7 @@ #include "ttg/util/meta.h" #include "ttg/util/meta/callable.h" #include "ttg/util/print.h" +#include "ttg/util/scope_exit.h" #include "ttg/util/trace.h" #include "ttg/util/typelist.h" #ifdef TTG_HAVE_DEVICE diff --git a/ttg/ttg/util/scope_exit.h b/ttg/ttg/util/scope_exit.h new file mode 100644 index 000000000..5e5fe6722 --- /dev/null +++ b/ttg/ttg/util/scope_exit.h @@ -0,0 +1,58 @@ +#ifndef TTG_UTIL_SCOPE_EXIT_H +#define TTG_UTIL_SCOPE_EXIT_H + +// +// N4189: Scoped Resource - Generic RAII Wrapper for the Standard Library +// Peter Sommerlad and Andrew L. Sandoval +// Adopted from https://github.com/tandasat/ScopedResource/tree/master +// + +#include + +namespace ttg::detail { + template + struct scope_exit + { + // construction + explicit + scope_exit(EF &&f) + : exit_function(std::move(f)) + , execute_on_destruction{ true } + { } + + // move + scope_exit(scope_exit &&rhs) + : exit_function(std::move(rhs.exit_function)) + , execute_on_destruction{ rhs.execute_on_destruction } + { + rhs.release(); + } + + // release + ~scope_exit() + { + if (execute_on_destruction) this->exit_function(); + } + + void release() + { + this->execute_on_destruction = false; + } + + private: + scope_exit(scope_exit const &) = delete; + void operator=(scope_exit const &) = delete; + scope_exit& operator=(scope_exit &&) = delete; + EF exit_function; + bool execute_on_destruction; // exposition only + }; + + template + auto make_scope_exit(EF &&exit_function) + { + return scope_exit>(std::forward(exit_function)); + } + +} // namespace ttg::detail + +#endif // TTG_UTIL_SCOPE_EXIT_H \ No newline at end of file