Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feedback request: Future cancellation #192

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 92 additions & 3 deletions support-lib/cpp/Future.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include "expected.hpp"
#include "stop_token.hpp"

#include <atomic>
#include <functional>
Expand Down Expand Up @@ -60,6 +61,12 @@ struct BrokenPromiseException final : public std::exception {
}
};

struct CancelledFutureException final : public std::exception {
inline const char* what() const noexcept final {
return "djinni::Future was cancelled";
}
};

namespace detail {

// A wrapper object to support both void and non-void result types in
Expand Down Expand Up @@ -110,6 +117,20 @@ struct SharedState: ValueHolder<T> {
std::mutex mutex;
std::exception_ptr exception;
std::unique_ptr<ValueHandlerBase<T>> handler;

inplace_stop_source stopSource;
inplace_stop_token stopToken = stopSource.get_token();

struct ForwardingStopCallback final {
std::shared_ptr<inplace_stop_source> _stopSource;
void operator()() noexcept {
_stopSource->request_stop();
}

ForwardingStopCallback(std::shared_ptr<inplace_stop_source> stopSource)
:_stopSource{std::move(stopSource)} {}
};
std::optional<inplace_stop_callback<ForwardingStopCallback>> stopCallback{};

bool isReady() const {
return this->value.has_value() || exception != nullptr;
Expand Down Expand Up @@ -187,6 +208,13 @@ class PromiseBase {
return promise.getFuture();
}

[[nodiscard]] std::shared_ptr<inplace_stop_token> getStopToken() const noexcept {
return {
_sharedStateReadOnly,
&_sharedStateReadOnly->stopToken
};
}

protected:
// `setValue()` or `setException()` can only be called once. After which the
// shared state is set to null and further calls to `setValue()` or
Expand Down Expand Up @@ -330,6 +358,7 @@ class Future {
assert(sharedState); // a second call will trigger assertion
auto nextPromise = std::make_unique<Promise<HandlerReturnType>>();
auto nextFuture = nextPromise->getFuture();
forwardCancellationFrom(*nextPromise->getStopToken(), sharedState);
auto continuation = [handler = std::forward<FUNC>(handler), nextPromise = std::move(nextPromise)] (detail::SharedStatePtr<T> x) mutable {
try {
if constexpr(std::is_void_v<HandlerReturnType>) {
Expand Down Expand Up @@ -359,9 +388,24 @@ class Future {
return nextFuture;
}

[[nodiscard]] std::shared_ptr<inplace_stop_source> getStopSource() const noexcept {
return {
_sharedState,
&_sharedState->stopSource,
};
}

private:
detail::SharedStatePtr<T> _sharedState;

static void forwardCancellationFrom(inplace_stop_token stop_token, const detail::SharedStatePtr<T>& to) {
assert(!to->stopCallback); // future that already gets cancellations forwarded will trigger assertion
to->stopCallback.emplace(stop_token, std::shared_ptr<inplace_stop_source>{
to,
&to->stopSource,
});
}

#if defined(DJINNI_FUTURE_HAS_COROUTINE_SUPPORT)
public:
bool await_ready() {
Expand All @@ -374,12 +418,15 @@ class Future {
sharedState = std::atomic_exchange(&_sharedState, sharedState);
return Future<T>(sharedState).get();
}
bool await_suspend(detail::CoroutineHandle<> h) {
this->then([h, this] (Future<T> x) mutable {
template<typename P>
void await_suspend(detail::CoroutineHandle<P> h) {
auto& promise = h.promise();

forwardCancellationFrom(*promise._promise.getStopToken(), _sharedState);
this->then([h, this](Future<T> x) mutable {
std::atomic_store(&_sharedState, x._sharedState);
h();
});
return true;
}

struct PromiseTypeBase {
Expand Down Expand Up @@ -442,6 +489,48 @@ struct Future<void>::PromiseType : PromiseTypeBase {
_result.emplace();
}
};

struct CheckCancelledT {
std::shared_ptr<inplace_stop_token> stop_token{};
constexpr bool await_ready() const noexcept {
return false;
}
template<typename P>
constexpr bool await_suspend(detail::CoroutineHandle<P> handle) const noexcept {
stop_token = handle.promise().getStopToken();
return false;
}
bool await_resume() const noexcept {
return stop_token->stop_requested();
}
};
inline CheckCancelledT check_cancelled() {
return {};
}

struct AbortIfCancelledT {
constexpr bool await_ready() const noexcept {
return false;
}
template<typename P>
constexpr bool await_suspend(detail::CoroutineHandle<P> suspended) const noexcept {
if (!suspended.promise()._promise.getStopToken()->stop_requested()) {
return false;
}

// Move Promise<T> out of the coroutine promise
auto promise { std::move(suspended.promise()._promise) };
// Destroy the coroutine state to destruct local variables etc
suspended.destroy();
// Then finalize the Promise<T> with an exception
promise.setException(CancelledFutureException{});
return true;
}
constexpr void await_resume() const noexcept {}
};
constexpr AbortIfCancelledT abort_if_cancelled() {
return {};
}
#endif

template <typename T>
Expand Down
105 changes: 105 additions & 0 deletions support-lib/cpp/stop_token.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
#include "stop_token.hpp"

#ifdef DJINNI_STOP_SOURCE_USE_CUSTOM_IMPL

namespace djinni {

bool inplace_stop_token::stop_requested() const noexcept {
return stop_possible() && _stop_source->stop_requested();
}

bool inplace_stop_token::stop_possible() const noexcept {
return _stop_source != nullptr;
}

void inplace_stop_token::swap(inplace_stop_token& other) noexcept {
std::swap(other._stop_source, _stop_source);
}


bool inplace_stop_source::stop_requested() const noexcept {
return _stopping_thread != std::thread::id{};
}

bool inplace_stop_source::request_stop() noexcept {
std::unique_lock lock{_mtx};
if (_stopping_thread != std::thread::id{}) {
return false;
}

_stopping_thread = std::this_thread::get_id();

for (; _cb_node != nullptr; _cb_node = _cb_node->next) {
lock.unlock();
(*_cb_node->callback)(_cb_node->arg);
lock.lock();
}
return true;
}

void inplace_stop_source::_register_callback(details::stop_callback_node* node) const noexcept {
std::unique_lock lock{_mtx};
if (_stopping_thread != std::thread::id{}) {
lock.unlock();
(*node->callback)(node->arg);
return;
}

node->next = _cb_node;
_cb_node = node->next;
}

void inplace_stop_source::_unregister_callback(details::stop_callback_node* node) const noexcept {
std::unique_lock lock{_mtx};
if (_stopping_thread != std::thread::id{} && _cb_node == node) {
// we're currently being stopped and the given callback is the one that is currently being executed
if (_stopping_thread == std::this_thread::get_id()) {
// the current callback is removing itself during its execution, so nothing to do
return;
} else {
// another thread is currently executing the callback, we have to block until it completes
// so we insert a phony callback that will alert us once the current callback is finished
bool finished{false};
std::condition_variable cv{};
auto finish = [this, &finished, &cv] {
{
std::lock_guard<std::mutex> lock{_mtx};
finished = true;
}
cv.notify_one();
};
using FinishFunctor = decltype(finish);

void(*finished_callback)(void*) = [](void* arg) {
auto& finish{*reinterpret_cast<FinishFunctor*>(arg)};
finish();
};


details::stop_callback_node node {
.callback = finished_callback,
.arg = &finish,
.next = _cb_node->next,
};
_cb_node->next = &node;

// Wait until the phony callback was invoked
cv.wait(lock, [&finished] {
return finished;
});

return;
}
}

for (auto* prev_node = _cb_node; prev_node->next != nullptr; prev_node = prev_node->next) {
if (prev_node->next == node) {
prev_node->next = node->next;
return;
}
}
}

}

#endif
Loading
Loading