diff --git a/include/seastar/coroutine/generator.hh b/include/seastar/coroutine/generator.hh index 9929debb5f..59a08eaf71 100644 --- a/include/seastar/coroutine/generator.hh +++ b/include/seastar/coroutine/generator.hh @@ -57,13 +57,16 @@ // Nesting generators is not supported. You cannot yield another generator // from within a generator. This prevents implementation asynchronous, // recursive algorithms like depth-first search on trees. +// * Range Yielding: +// supports directly yielding a range. We could instead implement nesting +// using a dedicated awaiter and sophisticated frame tracking for nested +// generators, but this would significantly increase implementation complexity. namespace seastar::coroutine::experimental { -template -class generator; - namespace internal { +namespace unbuffered { + template class next_awaiter; template @@ -234,9 +237,7 @@ public: } }; -} // namespace internal - -/// generator represents a view modeling std::ranges::input_range, +/// unbuffered generator represents a view modeling std::ranges::input_range, /// and has move-only iterators. /// /// generator has 2 template parameters: @@ -304,19 +305,14 @@ public: /// And the caller can still access the element of the range via the same type: /// \c std::string_view. /// -/// Current Limitation and Future Plans: -/// /// This generator implementation does not address the "Pingpong problem": /// where the producer generates elements one at a time, forcing frequent /// context switches between producer and consumer. This can lead to suboptimal /// performance, especially when bulk generation and consumption would be more /// efficient. /// -/// We intend to extend the existing implementation to allow the producer -/// to yield a range of elements. This will enable batch processing, -/// potentially improving performance by reducing context switches. -/// -/// TODO: Implement range-based yielding to mitigate the Ping-pong problem. +/// If the producer is able to produce elements in batch, please consider using +/// the buffered generator by specifying the \c Yielded template parameter. template class [[nodiscard]] generator { using value_type = std::conditional_t, @@ -371,7 +367,7 @@ public: } [[nodiscard]] auto begin() noexcept { - using base_awaiter = internal::next_awaiter; + using base_awaiter = next_awaiter; class begin_awaiter final : public base_awaiter { using base_awaiter::_promise; @@ -413,7 +409,7 @@ public: }; template -class generator::promise_type final : public internal::generator_promise_base { +class generator::promise_type final : public generator_promise_base { public: generator get_return_object() noexcept { return generator{*this}; @@ -446,7 +442,7 @@ public: } [[nodiscard]] auto operator++() noexcept { - using base_awaiter = internal::next_awaiter; + using base_awaiter = next_awaiter; class increment_awaiter final : public base_awaiter { iterator& _iterator; using base_awaiter::_promise; @@ -479,4 +475,322 @@ public: } }; +} // namespace unbuffered + +namespace buffered { + +template class next_awaiter; + +template +class generator_promise_base : public seastar::task { +protected: + // slice represents a sub sequence of the genenerated values + std::add_pointer_t _slice = nullptr; + +protected: + std::exception_ptr _exception; + std::coroutine_handle<> _consumer; + task* _waiting_task = nullptr; + + class yield_awaiter final { + generator_promise_base* _promise; + std::coroutine_handle<> _consumer; + public: + yield_awaiter(generator_promise_base* promise, + std::coroutine_handle<> consumer) noexcept + : _promise{promise} + , _consumer{consumer} + {} + bool await_ready() const noexcept { + return false; + } + template + std::coroutine_handle<> await_suspend(std::coroutine_handle producer) noexcept { + _promise->_waiting_task = &producer.promise(); + return _consumer; + } + void await_resume() noexcept {} + }; + +public: + generator_promise_base() noexcept = default; + generator_promise_base(const generator_promise_base &) = delete; + generator_promise_base& operator=(const generator_promise_base &) = delete; + generator_promise_base(generator_promise_base &&) noexcept = default; + generator_promise_base& operator=(generator_promise_base &&) noexcept = default; + + // lazily-started coroutine, do not execute the coroutine until + // the coroutine is awaited. + std::suspend_always initial_suspend() const noexcept { + return {}; + } + + yield_awaiter final_suspend() noexcept { + _slice = nullptr; + return yield_awaiter{this, this->_consumer}; + } + + void unhandled_exception() noexcept { + _exception = std::current_exception(); + } + + yield_awaiter yield_value(Yielded slice) noexcept { + // an empty slice is forbidden, otherwise the increment_awaiter would resume + // from suspension, but end up finding nothing to return to the consumer. + assert(!std::ranges::empty(slice)); + this->_slice = std::addressof(slice); + return yield_awaiter{this, this->_consumer}; + } + + void return_void() noexcept {} + + // @return if the generator has reached the end of the sequence + bool finished() const noexcept { + return _slice == nullptr; + } + + // @return the current slice produced by the producer + Yielded& slice() noexcept { + assert(_slice); + return *_slice; + } + + + void rethrow_if_unhandled_exception() { + if (_exception) { + std::rethrow_exception(std::move(_exception)); + } + } + + void run_and_dispose() noexcept final { + using handle_type = std::coroutine_handle; + handle_type::from_promise(*this).resume(); + } + + seastar::task* waiting_task() noexcept final { + return _waiting_task; + } + +private: + friend class next_awaiter; +}; + +template +class next_awaiter { +protected: + generator_promise_base* _promise = nullptr; + std::coroutine_handle<> _producer = nullptr; + +public: + explicit next_awaiter(std::nullptr_t) noexcept {} + next_awaiter(generator_promise_base& promise, + std::coroutine_handle<> producer) noexcept + : _promise{std::addressof(promise)} + , _producer{producer} {} + + template + std::coroutine_handle<> await_suspend(std::coroutine_handle consumer) noexcept { + _promise->_consumer = consumer; + return _producer; + } +}; + +template +class [[nodiscard]] generator { + using value_type = std::conditional_t, + std::remove_cvref_t, + Value>; + using reference_type = std::conditional_t, + Ref&&, + Ref>; + using yielded_type = Yielded; + +public: + class promise_type; + +private: + using handle_type = std::coroutine_handle; + handle_type _coro = {}; + +public: + class iterator; + + generator() noexcept = default; + explicit generator(promise_type& promise) noexcept + : _coro(std::coroutine_handle::from_promise(promise)) + {} + generator(generator&& other) noexcept + : _coro{std::exchange(other._coro, {})} + {} + generator(const generator&) = delete; + generator& operator=(const generator&) = delete; + + ~generator() { + if (_coro) { + _coro.destroy(); + } + } + + friend void swap(generator& lhs, generator& rhs) noexcept { + std::swap(lhs._coro, rhs._coro); + } + + generator& operator=(generator&& other) noexcept { + if (this == &other) { + return *this; + } + if (_coro) { + _coro.destroy(); + } + _coro = std::exchange(other._coro, nullptr); + return *this; + } + + [[nodiscard]] auto begin() noexcept { + using base_awaiter = next_awaiter; + class begin_awaiter final : public base_awaiter { + using base_awaiter::_promise; + + public: + using base_awaiter::base_awaiter; + bool await_ready() const noexcept { + return _promise == nullptr; + } + + iterator await_resume() { + if (_promise == nullptr) { + return iterator{nullptr}; + } + if (_promise->finished()) { + _promise->rethrow_if_unhandled_exception(); + return iterator{nullptr}; + } + return iterator{ + handle_type::from_promise(*static_cast(_promise)) + }; + } + }; + + if (_coro && !_coro.done()) { + return begin_awaiter{_coro.promise(), _coro}; + } else { + return begin_awaiter{nullptr}; + } + } + + [[nodiscard]] std::default_sentinel_t end() const noexcept { + return {}; + } +}; + +/// buffered generator has 3 template parameters: +/// +/// - Ref +/// - Value +/// - Yielded +/// +/// Unlike its unbuffered variant, the \c Yielded type can be customized. \c Yielded should be a +/// range of elements, which are convertible to the \c Value type. +/// +/// @note Please note, empty ranges are not allowed to be yielded. +template +class generator::promise_type final : public generator_promise_base { +public: + generator get_return_object() noexcept { + return generator{*this}; + } +}; + +template +class generator::iterator final { +private: + using handle_type = generator::handle_type; + // nullptr on end + handle_type _coro = nullptr; + + using yielded_iterator = std::ranges::iterator_t; + static_assert(std::input_iterator); + yielded_iterator _iterator; + + void reset() { + _iterator = std::begin(_coro.promise().slice()); + } + +public: + using iterator_category = std::input_iterator_tag; + using difference_type = std::ptrdiff_t; + using value_type = generator::value_type; + using reference = generator::reference_type; + using pointer = std::add_pointer_t; + + explicit iterator(handle_type coroutine) noexcept + : _coro{coroutine} { + if (_coro) { + reset(); + } + } + + explicit operator bool() const noexcept { + return _coro && !_coro.done(); + } + + [[nodiscard]] auto operator++() noexcept { + using base_awaiter = next_awaiter; + class increment_awaiter final : public base_awaiter { + const bool _in_this_slice; + iterator& _iterator; + using base_awaiter::_promise; + + public: + explicit increment_awaiter(bool is_ready, iterator& iterator) noexcept + : base_awaiter{iterator._coro.promise(), iterator._coro} + , _in_this_slice{is_ready} + , _iterator{iterator} + {} + + bool await_ready() const noexcept { + return _in_this_slice; + } + + iterator& await_resume() { + if (_promise->finished()) { + // update iterator to end() + _iterator = iterator{nullptr}; + _promise->rethrow_if_unhandled_exception(); + return _iterator; + } + if (!_in_this_slice) { + _iterator.reset(); + } + return _iterator; + } + }; + + assert(bool(*this) && "cannot increment end iterator"); + bool is_ready = ++_iterator != std::ranges::end(_coro.promise().slice()); + return increment_awaiter{is_ready, *this}; + } + + reference operator*() const noexcept { + return static_cast(*_iterator); + } + + bool operator==(std::default_sentinel_t) const noexcept { + return !bool(*this); + } +}; + +} // namespace buffered + +template +concept element_of = + !std::convertible_to && + std::convertible_to, T>; +} // namespace internal + +template +using generator = std::conditional_t, + internal::buffered::generator, + internal::unbuffered::generator>; + } // namespace seastar::coroutine::experimental diff --git a/tests/unit/generator_test.cc b/tests/unit/generator_test.cc index 9b438b06f0..e0fc60b951 100644 --- a/tests/unit/generator_test.cc +++ b/tests/unit/generator_test.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #if __cplusplus >= 202302L && defined(__cpp_lib_generator) #include @@ -234,3 +235,150 @@ SEASTAR_TEST_CASE(test_generator_throws_from_consumer) { BOOST_REQUIRE_THROW(std::rethrow_exception(f.get_exception()), std::invalid_argument); BOOST_REQUIRE_EQUAL(total, 0); } + +SEASTAR_TEST_CASE(test_batch_generator_empty_sequence) { + using int_gen = coroutine::experimental::generator>; + auto seq = std::invoke([]() -> int_gen { + co_return; + }); + for (auto i = co_await seq.begin(); i != seq.end(); co_await ++i) { + BOOST_FAIL("found element in an empty sequence"); + } +} + +coroutine::experimental::generator> +async_fibonacci_sequence_batch(unsigned count, unsigned batch_size, do_suspend suspend) { + auto a = 0, b = 1; + std::vector batch; + for (unsigned i = 0; i < count; ++i) { + if (std::numeric_limits::max() - a < b) { + throw std::out_of_range( + fmt::format("fibonacci[{}] is greater than the largest value of int", i)); + } + if (suspend) { + co_await yield(); + } + int next = std::exchange(a, std::exchange(b, a + b)); + batch.push_back(next); + if (batch.size() == batch_size) { + co_yield std::exchange(batch, {}); + } + } + if (!batch.empty()) { + co_yield std::move(batch); + } +} + +seastar::future<> +verify_fib_drained(coroutine::experimental::generator> actual_fibs, unsigned count) { + auto expected_fibs = sync_fibonacci_sequence(count); + auto expected_fib = std::begin(expected_fibs); + + auto actual_fib = co_await actual_fibs.begin(); + + for (; actual_fib != actual_fibs.end(); co_await ++actual_fib) { + BOOST_REQUIRE(expected_fib != std::end(expected_fibs)); + BOOST_REQUIRE_EQUAL(*actual_fib, *expected_fib); + ++expected_fib; + } + BOOST_REQUIRE(actual_fib == actual_fibs.end()); +} + +SEASTAR_TEST_CASE(test_batch_generator_drained_with_suspend) { + constexpr unsigned count = 4; + constexpr unsigned batch_size = 2; + return verify_fib_drained(async_fibonacci_sequence_batch(count, batch_size, do_suspend::yes), + count); +} + +SEASTAR_TEST_CASE(test_batch_generator_drained_without_suspend) { + constexpr int count = 4; + constexpr int batch_size = 2; + return verify_fib_drained(async_fibonacci_sequence_batch(count, batch_size, do_suspend::no), + count); +} + +seastar::future<> test_batch_generator_not_drained(do_suspend suspend) { + auto fib = async_fibonacci_sequence_batch(42, 12, suspend); + auto actual_fib = co_await fib.begin(); + BOOST_REQUIRE_EQUAL(*actual_fib, 0); +} + +SEASTAR_TEST_CASE(test_batch_generator_not_drained_with_suspend) { + return test_batch_generator_not_drained(do_suspend::yes); +} + +SEASTAR_TEST_CASE(test_batch_generator_not_drained_without_suspend) { + return test_batch_generator_not_drained(do_suspend::no); +} + +SEASTAR_TEST_CASE(test_batch_generator_move_away) { + struct move_only { + int value; + move_only(int value) + : value{value} + {} + move_only(const move_only&) = delete; + move_only& operator=(const move_only&) = delete; + move_only(move_only&&) noexcept = default; + move_only& operator=(move_only&&) noexcept = default; + }; + + using batch_type = std::vector; + using move_only_gen = coroutine::experimental::generator; + + constexpr int count = 4; + constexpr unsigned batch_size = 2; + auto numbers = std::invoke([]() -> move_only_gen { + batch_type batch; + for (int i = 0; i < count; i++) { + batch.push_back(i); + if (batch.size() == batch_size) { + co_yield std::exchange(batch, {}); + } + } + if (!batch.empty()) { + co_yield std::move(batch); + } + }); + + int expected_n = 0; + for (auto n = co_await numbers.begin(); n != numbers.end(); co_await ++n) { + BOOST_REQUIRE_EQUAL((*n).value, expected_n++); + } +} + +SEASTAR_TEST_CASE(test_batch_generator_convertible) { + struct convertible { + const std::string value; + convertible(std::string&& value) + : value{std::move(value)} + {} + explicit operator int() const { + return std::stoi(value); + } + }; + + using batch_type = std::vector; + using move_only_gen = coroutine::experimental::generator; + + constexpr int count = 4; + constexpr unsigned batch_size = 2; + auto numbers = std::invoke([]() -> move_only_gen { + batch_type batch; + for (int i = 0; i < count; i++) { + batch.push_back(fmt::to_string(i)); + if (batch.size() == batch_size) { + co_yield std::exchange(batch, {}); + } + } + if (!batch.empty()) { + co_yield std::move(batch); + } + }); + + int expected_n = 0; + for (auto n = co_await numbers.begin(); n != numbers.end(); co_await ++n) { + BOOST_REQUIRE_EQUAL(*n, expected_n++); + } +}