Skip to content

Commit

Permalink
Customize get_return_address for _wsa_sender_wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
ispeters committed Sep 11, 2024
1 parent 14648df commit 8e2fce8
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions include/unifex/with_scheduler_affinity.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
#include <unifex/scheduler_concepts.hpp>
#include <unifex/sender_concepts.hpp>
#include <unifex/tag_invoke.hpp>
#include <unifex/tracing/async_stack.hpp>
#include <unifex/tracing/get_return_address.hpp>
#include <unifex/type_traits.hpp>
#include <unifex/unstoppable.hpp>

Expand Down Expand Up @@ -54,6 +56,7 @@ class _wsa_sender_wrapper<Sender, Scheduler>::type final {
decltype(_make_sender(UNIFEX_DECLVAL(Sender), UNIFEX_DECLVAL(Scheduler)));

sender_t sender_;
instruction_ptr returnAddress_;

public:
template <
Expand All @@ -73,11 +76,13 @@ class _wsa_sender_wrapper<Sender, Scheduler>::type final {
static constexpr bool is_always_scheduler_affine = true;

template <typename Sender2, typename Scheduler2>
type(Sender2&& sender, Scheduler2 scheduler) noexcept(noexcept(_make_sender(
static_cast<Sender2&&>(sender), static_cast<Scheduler2&&>(scheduler))))
: sender_(_make_sender(
type(Sender2&& sender, Scheduler2 scheduler, instruction_ptr returnAddress) noexcept(
noexcept(_make_sender(
static_cast<Sender2&&>(sender),
static_cast<Scheduler2&&>(scheduler))) {}
static_cast<Scheduler2&&>(scheduler))))
: sender_(_make_sender(
static_cast<Sender2&&>(sender), static_cast<Scheduler2&&>(scheduler)))
, returnAddress_(returnAddress) {}

template(typename Self, typename Receiver) //
(requires same_as<remove_cvref_t<Self>, type>) //
Expand All @@ -86,6 +91,11 @@ class _wsa_sender_wrapper<Sender, Scheduler>::type final {
return connect(
static_cast<Self&&>(sender).sender_, static_cast<Receiver&&>(receiver));
}

friend instruction_ptr
tag_invoke(tag_t<get_return_address>, const type& sender) noexcept {
return sender.returnAddress_;
}
};

struct _fn final {
Expand Down Expand Up @@ -128,7 +138,10 @@ struct _fn final {
using sender_t =
wsa_sender_wrapper<remove_cvref_t<Sender>, remove_cvref_t<Scheduler>>;

return sender_t{static_cast<Sender&&>(s), static_cast<Scheduler&&>(sched)};
return sender_t{
static_cast<Sender&&>(s),
static_cast<Scheduler&&>(sched),
instruction_ptr::read_return_address()};
}

template(typename Promise, typename Awaitable, typename Scheduler) //
Expand All @@ -144,7 +157,7 @@ struct _fn final {
if constexpr (
!same_as<blocking_kind, blocking_t> &&
(blocking_kind::always_inline == blocking_t{})) {
return Awaitable{(Awaitable &&) awaitable};
return Awaitable{(Awaitable&&)awaitable};
} else {
// TODO: do this more efficiently; the current approach converts an
// awaitable to a sender so we can pass it to via, only to
Expand Down

0 comments on commit 8e2fce8

Please sign in to comment.