Skip to content

Commit

Permalink
Implement cancellation forwarding for djinni::Future
Browse files Browse the repository at this point in the history
  • Loading branch information
jb-gcx committed Oct 24, 2024
1 parent 02a9c78 commit c9906a3
Showing 1 changed file with 26 additions and 2 deletions.
28 changes: 26 additions & 2 deletions support-lib/cpp/Future.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,17 @@ struct SharedState: ValueHolder<T> {
inplace_stop_source stopSource;
inplace_stop_token stopToken = stopSource.get_token();

struct ForwardingStopCallback final {
std::shared_ptr<inplace_stop_source> _stopSource;
void operator()() noexcept {
_stopSource->request_stop();
}

ForwardingStopCallback(std::shared_ptr<inplace_stop_source> stopSource)
:_stopSource{std::move(stopSource)} {}
};
std::optional<inplace_stop_callback<ForwardingStopCallback>> stopCallback{};

bool isReady() const {
return this->value.has_value() || exception != nullptr;
}
Expand Down Expand Up @@ -347,6 +358,7 @@ class Future {
assert(sharedState); // a second call will trigger assertion
auto nextPromise = std::make_unique<Promise<HandlerReturnType>>();
auto nextFuture = nextPromise->getFuture();
forwardCancellationFrom(*nextPromise->getStopToken(), sharedState);
auto continuation = [handler = std::forward<FUNC>(handler), nextPromise = std::move(nextPromise)] (detail::SharedStatePtr<T> x) mutable {
try {
if constexpr(std::is_void_v<HandlerReturnType>) {
Expand Down Expand Up @@ -386,6 +398,14 @@ class Future {
private:
detail::SharedStatePtr<T> _sharedState;

static void forwardCancellationFrom(inplace_stop_token stop_token, const detail::SharedStatePtr<T>& to) {
assert(!to->stopCallback); // future that already gets cancellations forwarded will trigger assertion
to->stopCallback.emplace(stop_token, std::shared_ptr<inplace_stop_source>{
to,
&to->stopSource,
});
}

#if defined(DJINNI_FUTURE_HAS_COROUTINE_SUPPORT)
public:
bool await_ready() {
Expand All @@ -398,8 +418,12 @@ class Future {
sharedState = std::atomic_exchange(&_sharedState, sharedState);
return Future<T>(sharedState).get();
}
void await_suspend(detail::CoroutineHandle<> h) {
this->then([h, this] (Future<T> x) mutable {
template<typename P>
void await_suspend(detail::CoroutineHandle<P> h) {
auto& promise = h.promise();

forwardCancellationFrom(*promise._promise.getStopToken(), _sharedState);
this->then([h, this](Future<T> x) mutable {
std::atomic_store(&_sharedState, x._sharedState);
h();
});
Expand Down

0 comments on commit c9906a3

Please sign in to comment.