diff --git a/support-lib/cpp/Future.hpp b/support-lib/cpp/Future.hpp index 6e94152e..6b289356 100644 --- a/support-lib/cpp/Future.hpp +++ b/support-lib/cpp/Future.hpp @@ -17,6 +17,7 @@ #pragma once #include "expected.hpp" +#include "stop_token.hpp" #include #include @@ -60,6 +61,12 @@ struct BrokenPromiseException final : public std::exception { } }; +struct CancelledFutureException final : public std::exception { + inline const char* what() const noexcept final { + return "djinni::Future was cancelled"; + } +}; + namespace detail { // A wrapper object to support both void and non-void result types in @@ -110,6 +117,20 @@ struct SharedState: ValueHolder { std::mutex mutex; std::exception_ptr exception; std::unique_ptr> handler; + + inplace_stop_source stopSource; + inplace_stop_token stopToken = stopSource.get_token(); + + struct ForwardingStopCallback final { + std::shared_ptr _stopSource; + void operator()() noexcept { + _stopSource->request_stop(); + } + + ForwardingStopCallback(std::shared_ptr stopSource) + :_stopSource{std::move(stopSource)} {} + }; + std::optional> stopCallback{}; bool isReady() const { return this->value.has_value() || exception != nullptr; @@ -187,6 +208,13 @@ class PromiseBase { return promise.getFuture(); } + [[nodiscard]] std::shared_ptr getStopToken() const noexcept { + return { + _sharedStateReadOnly, + &_sharedStateReadOnly->stopToken + }; + } + protected: // `setValue()` or `setException()` can only be called once. After which the // shared state is set to null and further calls to `setValue()` or @@ -330,6 +358,7 @@ class Future { assert(sharedState); // a second call will trigger assertion auto nextPromise = std::make_unique>(); auto nextFuture = nextPromise->getFuture(); + forwardCancellationFrom(*nextPromise->getStopToken(), sharedState); auto continuation = [handler = std::forward(handler), nextPromise = std::move(nextPromise)] (detail::SharedStatePtr x) mutable { try { if constexpr(std::is_void_v) { @@ -359,9 +388,24 @@ class Future { return nextFuture; } + [[nodiscard]] std::shared_ptr getStopSource() const noexcept { + return { + _sharedState, + &_sharedState->stopSource, + }; + } + private: detail::SharedStatePtr _sharedState; + static void forwardCancellationFrom(inplace_stop_token stop_token, const detail::SharedStatePtr& to) { + assert(!to->stopCallback); // future that already gets cancellations forwarded will trigger assertion + to->stopCallback.emplace(stop_token, std::shared_ptr{ + to, + &to->stopSource, + }); + } + #if defined(DJINNI_FUTURE_HAS_COROUTINE_SUPPORT) public: bool await_ready() { @@ -374,12 +418,15 @@ class Future { sharedState = std::atomic_exchange(&_sharedState, sharedState); return Future(sharedState).get(); } - bool await_suspend(detail::CoroutineHandle<> h) { - this->then([h, this] (Future x) mutable { + template + void await_suspend(detail::CoroutineHandle

h) { + auto& promise = h.promise(); + + forwardCancellationFrom(*promise._promise.getStopToken(), _sharedState); + this->then([h, this](Future x) mutable { std::atomic_store(&_sharedState, x._sharedState); h(); }); - return true; } struct PromiseTypeBase { @@ -442,6 +489,48 @@ struct Future::PromiseType : PromiseTypeBase { _result.emplace(); } }; + +struct CheckCancelledT { + std::shared_ptr stop_token{}; + constexpr bool await_ready() const noexcept { + return false; + } + template + constexpr bool await_suspend(detail::CoroutineHandle

handle) const noexcept { + stop_token = handle.promise().getStopToken(); + return false; + } + bool await_resume() const noexcept { + return stop_token->stop_requested(); + } +}; +inline CheckCancelledT check_cancelled() { + return {}; +} + +struct AbortIfCancelledT { + constexpr bool await_ready() const noexcept { + return false; + } + template + constexpr bool await_suspend(detail::CoroutineHandle

suspended) const noexcept { + if (!suspended.promise()._promise.getStopToken()->stop_requested()) { + return false; + } + + // Move Promise out of the coroutine promise + auto promise { std::move(suspended.promise()._promise) }; + // Destroy the coroutine state to destruct local variables etc + suspended.destroy(); + // Then finalize the Promise with an exception + promise.setException(CancelledFutureException{}); + return true; + } + constexpr void await_resume() const noexcept {} +}; +constexpr AbortIfCancelledT abort_if_cancelled() { + return {}; +} #endif template diff --git a/support-lib/cpp/stop_token.cpp b/support-lib/cpp/stop_token.cpp new file mode 100644 index 00000000..8abe0e06 --- /dev/null +++ b/support-lib/cpp/stop_token.cpp @@ -0,0 +1,105 @@ +#include "stop_token.hpp" + +#ifdef DJINNI_STOP_SOURCE_USE_CUSTOM_IMPL + +namespace djinni { + +bool inplace_stop_token::stop_requested() const noexcept { + return stop_possible() && _stop_source->stop_requested(); +} + +bool inplace_stop_token::stop_possible() const noexcept { + return _stop_source != nullptr; +} + +void inplace_stop_token::swap(inplace_stop_token& other) noexcept { + std::swap(other._stop_source, _stop_source); +} + + +bool inplace_stop_source::stop_requested() const noexcept { + return _stopping_thread != std::thread::id{}; +} + +bool inplace_stop_source::request_stop() noexcept { + std::unique_lock lock{_mtx}; + if (_stopping_thread != std::thread::id{}) { + return false; + } + + _stopping_thread = std::this_thread::get_id(); + + for (; _cb_node != nullptr; _cb_node = _cb_node->next) { + lock.unlock(); + (*_cb_node->callback)(_cb_node->arg); + lock.lock(); + } + return true; +} + +void inplace_stop_source::_register_callback(details::stop_callback_node* node) const noexcept { + std::unique_lock lock{_mtx}; + if (_stopping_thread != std::thread::id{}) { + lock.unlock(); + (*node->callback)(node->arg); + return; + } + + node->next = _cb_node; + _cb_node = node->next; +} + +void inplace_stop_source::_unregister_callback(details::stop_callback_node* node) const noexcept { + std::unique_lock lock{_mtx}; + if (_stopping_thread != std::thread::id{} && _cb_node == node) { + // we're currently being stopped and the given callback is the one that is currently being executed + if (_stopping_thread == std::this_thread::get_id()) { + // the current callback is removing itself during its execution, so nothing to do + return; + } else { + // another thread is currently executing the callback, we have to block until it completes + // so we insert a phony callback that will alert us once the current callback is finished + bool finished{false}; + std::condition_variable cv{}; + auto finish = [this, &finished, &cv] { + { + std::lock_guard lock{_mtx}; + finished = true; + } + cv.notify_one(); + }; + using FinishFunctor = decltype(finish); + + void(*finished_callback)(void*) = [](void* arg) { + auto& finish{*reinterpret_cast(arg)}; + finish(); + }; + + + details::stop_callback_node node { + .callback = finished_callback, + .arg = &finish, + .next = _cb_node->next, + }; + _cb_node->next = &node; + + // Wait until the phony callback was invoked + cv.wait(lock, [&finished] { + return finished; + }); + + return; + } + } + + for (auto* prev_node = _cb_node; prev_node->next != nullptr; prev_node = prev_node->next) { + if (prev_node->next == node) { + prev_node->next = node->next; + return; + } + } +} + +} + +#endif diff --git a/support-lib/cpp/stop_token.hpp b/support-lib/cpp/stop_token.hpp new file mode 100644 index 00000000..1170cd09 --- /dev/null +++ b/support-lib/cpp/stop_token.hpp @@ -0,0 +1,155 @@ +#pragma once + +#ifdef __has_include +#if __has_include() +#include +#endif +#endif + +// At the time of writing, P2300R10 only defines __cpp_lib_senders as a feature-test, nothing specific to whether inplace_stop_token is available +#ifndef __cpp_lib_senders +#ifndef DJINNI_STOP_SOURCE_USE_CUSTOM_IMPL +#define DJINNI_STOP_SOURCE_USE_CUSTOM_IMPL +#endif +#endif + +#ifndef DJINNI_STOP_SOURCE_USE_CUSTOM_IMPL + +// Standard implementation is available + +#include + +namespace djinni { + +using inplace_stop_token = std::inplace_stop_token; +using inplace_stop_source = std::inplace_stop_source; +template +using inplace_stop_callback = std::inplace_stop_callback; + +} + +#else + +// Standard implementation unavailable, we roll our own implementation of what we need + +#include +#include +#include + +namespace djinni { + +namespace details { + +struct stop_callback_node final { + void(*callback)(void*); + void* arg; + stop_callback_node* next; +}; + +}; + +class inplace_stop_source; +template +class inplace_stop_callback; + +class inplace_stop_token { +public: + template + using callback_type = inplace_stop_callback; + + inplace_stop_token() = default; + bool operator==(const inplace_stop_token&) const = default; + + [[nodiscard]] bool stop_requested() const noexcept; + [[nodiscard]] bool stop_possible() const noexcept; + void swap(inplace_stop_token&) noexcept; + +private: + template + friend class inplace_stop_callback; + friend inplace_stop_source; + constexpr inplace_stop_token(const inplace_stop_source* stop_source) noexcept + :_stop_source{stop_source} {} + + const inplace_stop_source* _stop_source = nullptr; +}; + +class inplace_stop_source { +public: + constexpr inplace_stop_source() noexcept = default; + + inplace_stop_source(inplace_stop_source&&) = delete; + inplace_stop_source(const inplace_stop_source&) = delete; + inplace_stop_source& operator=(inplace_stop_source&&) = delete; + inplace_stop_source& operator=(const inplace_stop_source&) = delete; + + [[nodiscard]] constexpr inplace_stop_token get_token() const noexcept { + return {this}; + } + [[nodiscard]] static constexpr bool stop_possible() noexcept { return true; } + + [[nodiscard]] bool stop_requested() const noexcept; + bool request_stop() noexcept; + +private: + template + friend class inplace_stop_callback; + void _register_callback(details::stop_callback_node* callback) const noexcept; + void _unregister_callback(details::stop_callback_node* callback) const noexcept; + + mutable std::mutex _mtx{}; + std::thread::id _stopping_thread{}; + mutable details::stop_callback_node* _cb_node = nullptr; +}; + +template +class inplace_stop_callback { +public: + using callback_type = Callback; + + template + explicit inplace_stop_callback(inplace_stop_token token, C&& callback) noexcept(std::is_nothrow_constructible_v) + :_token(std::move(token)) + ,_callback(std::forward(callback)) + { + if (_token._stop_source) { + _token._stop_source->_register_callback(&_node); + } + } + + ~inplace_stop_callback() { + if (_token._stop_source) { + _token._stop_source->_unregister_callback(&_node); + } + } + + inplace_stop_callback(const inplace_stop_callback&) = delete; + inplace_stop_callback& operator=(const inplace_stop_callback&) = delete; + inplace_stop_callback(inplace_stop_callback&&) = delete; + inplace_stop_callback& operator=(inplace_stop_callback&&) = delete; + +private: + static void invoke(void* callback) { + reinterpret_cast(callback)->_callback(); + } + + inplace_stop_token _token; + Callback _callback; + details::stop_callback_node _node { + .callback = &inplace_stop_callback::invoke, + .arg = this, + .next = nullptr, + }; +}; + +template +inplace_stop_callback(inplace_stop_token, C callback) -> inplace_stop_callback; + +/* +Ignored the following std requirements for simplicity and C++17 compatibility: +- template arg of inplace_stop_callback must be std::invocable and std::destructible +*/ + +} + +#endif