From c61590255994b872bb6fd7c07046c20d94a8955e Mon Sep 17 00:00:00 2001 From: Marcelo Zimbres Date: Sun, 10 Nov 2024 15:36:09 +0100 Subject: [PATCH 1/2] Simplifications - Removes cancellation support from async_run. - Simplifies async operations. - Removes async_run_lean. - Remove support for implicit cancellation from health-checker. - Reuses async_run parallel group for health-checker tasks. - Moves the resolver from the runner to the connection. - Moves the connector from the runner to the connection_base class. - Moves the ssl handshaker to the connection_base class. - Moves the health-check to the connection_base class. - Simplifies cancel operations. - Improvements in logging. - Remove ssl handshaker. - Removes run_op from runner and renames runner to resp3_handshaker. --- include/boost/redis/connection.hpp | 71 +--- .../boost/redis/detail/connection_base.hpp | 346 +++++++++++------- include/boost/redis/detail/handshaker.hpp | 124 ------- include/boost/redis/detail/health_checker.hpp | 134 ++----- include/boost/redis/detail/resolver.hpp | 15 +- .../boost/redis/detail/resp3_handshaker.hpp | 116 ++++++ include/boost/redis/detail/runner.hpp | 142 +++---- include/boost/redis/error.hpp | 3 + include/boost/redis/impl/error.ipp | 1 + include/boost/redis/impl/logger.ipp | 83 +---- .../impl/{runner.ipp => resp3_handshaker.ipp} | 2 +- include/boost/redis/logger.hpp | 36 +- include/boost/redis/src.hpp | 2 +- test/CMakeLists.txt | 1 - test/common.hpp | 2 +- test/test_conn_exec_retry.cpp | 9 +- test/test_conn_quit.cpp | 18 - test/test_conn_run_cancel.cpp | 153 -------- test/test_low_level.cpp | 1 + test/test_low_level_sync_sans_io.cpp | 4 +- 20 files changed, 489 insertions(+), 774 deletions(-) delete mode 100644 include/boost/redis/detail/handshaker.hpp create mode 100644 include/boost/redis/detail/resp3_handshaker.hpp rename include/boost/redis/impl/{runner.ipp => resp3_handshaker.ipp} (94%) delete mode 100644 test/test_conn_run_cancel.cpp diff --git a/include/boost/redis/connection.hpp b/include/boost/redis/connection.hpp index 3089dfc5..b9c3654a 100644 --- a/include/boost/redis/connection.hpp +++ b/include/boost/redis/connection.hpp @@ -22,42 +22,6 @@ #include namespace boost::redis { -namespace detail -{ -template -struct reconnection_op { - Connection* conn_ = nullptr; - Logger logger_; - asio::coroutine coro_{}; - - template - void operator()(Self& self, system::error_code ec = {}) - { - BOOST_ASIO_CORO_REENTER (coro_) for (;;) - { - BOOST_ASIO_CORO_YIELD - conn_->impl_.async_run(conn_->cfg_, logger_, std::move(self)); - conn_->cancel(operation::receive); - logger_.on_connection_lost(ec); - if (!conn_->will_reconnect() || is_cancelled(self)) { - conn_->cancel(operation::reconnection); - self.complete(!!ec ? ec : asio::error::operation_aborted); - return; - } - - conn_->timer_.expires_after(conn_->cfg_.reconnect_wait_interval); - BOOST_ASIO_CORO_YIELD - conn_->timer_.async_wait(std::move(self)); - BOOST_REDIS_CHECK_OP0(;) - if (!conn_->will_reconnect()) { - self.complete(asio::error::operation_aborted); - return; - } - conn_->reset_stream(); - } - } -}; -} // detail /** @brief A SSL connection to the Redis server. * @ingroup high-level-api @@ -100,7 +64,6 @@ class basic_connection { asio::ssl::context ctx = asio::ssl::context{asio::ssl::context::tlsv12_client}, std::size_t max_read_size = (std::numeric_limits::max)()) : impl_{ex, std::move(ctx), max_read_size} - , timer_{ex} { } /// Contructs from a context. @@ -158,14 +121,7 @@ class basic_connection { Logger l = Logger{}, CompletionToken token = CompletionToken{}) { - using this_type = basic_connection; - - cfg_ = cfg; - l.set_prefix(cfg_.log_prefix); - return asio::async_compose - < CompletionToken - , void(system::error_code) - >(detail::reconnection_op{this, l}, token, timer_); + return impl_.async_run(cfg, l, std::move(token)); } /** @brief Receives server side pushes asynchronously. @@ -287,22 +243,11 @@ class basic_connection { * @param op: The operation to be cancelled. */ void cancel(operation op = operation::all) - { - switch (op) { - case operation::reconnection: - case operation::all: - cfg_.reconnect_wait_interval = std::chrono::seconds::zero(); - timer_.cancel(); - break; - default: /* ignore */; - } - - impl_.cancel(op); - } + { impl_.cancel(op); } /// Returns true if the connection was canceled. bool will_reconnect() const noexcept - { return cfg_.reconnect_wait_interval != std::chrono::seconds::zero();} + { return impl_.will_reconnect();} /// Returns the ssl context. auto const& get_ssl_context() const noexcept @@ -330,17 +275,7 @@ class basic_connection { { return impl_.get_usage(); } private: - using timer_type = - asio::basic_waitable_timer< - std::chrono::steady_clock, - asio::wait_traits, - Executor>; - - template friend struct detail::reconnection_op; - - config cfg_; detail::connection_base impl_; - timer_type timer_; }; /** \brief A basic_connection that type erases the executor. diff --git a/include/boost/redis/detail/connection_base.hpp b/include/boost/redis/detail/connection_base.hpp index be46703b..68bad54d 100644 --- a/include/boost/redis/detail/connection_base.hpp +++ b/include/boost/redis/detail/connection_base.hpp @@ -9,37 +9,42 @@ #include #include +#include +#include +#include #include +#include +#include #include #include #include #include -#include -#include #include -#include +#include #include #include +#include +#include +#include +#include #include #include +#include +#include +#include #include #include #include #include -#include -#include -#include -#include -#include #include #include #include #include +#include #include #include -#include #include namespace boost::redis::detail @@ -176,47 +181,6 @@ struct exec_op { } }; -template -struct run_op { - Conn* conn = nullptr; - Logger logger_; - asio::coroutine coro{}; - - template - void operator()( Self& self - , std::array order = {} - , system::error_code ec0 = {} - , system::error_code ec1 = {}) - { - BOOST_ASIO_CORO_REENTER (coro) - { - conn->reset(); - - BOOST_ASIO_CORO_YIELD - asio::experimental::make_parallel_group( - [this](auto token) { return conn->reader(logger_, token);}, - [this](auto token) { return conn->writer(logger_, token);} - ).async_wait( - asio::experimental::wait_for_one(), - std::move(self)); - - if (is_cancelled(self)) { - logger_.trace("run-op: canceled. Exiting ..."); - self.complete(asio::error::operation_aborted); - return; - } - - logger_.on_run(ec0, ec1); - - switch (order[0]) { - case 0: self.complete(ec0); break; - case 1: self.complete(ec1); break; - default: BOOST_ASSERT(false); - } - } - } -}; - template struct writer_op { Conn* conn_; @@ -241,25 +205,19 @@ struct writer_op { logger_.on_write(ec, conn_->write_buffer_); if (ec) { - logger_.trace("writer-op: error. Exiting ..."); + logger_.trace("writer_op (1)", ec); conn_->cancel(operation::run); self.complete(ec); return; } - if (is_cancelled(self)) { - logger_.trace("writer-op: canceled. Exiting ..."); - self.complete(asio::error::operation_aborted); - return; - } - conn_->on_write(); // A socket.close() may have been called while a // successful write might had already been queued, so we // have to check here before proceeding. if (!conn_->is_open()) { - logger_.trace("writer-op: canceled (2). Exiting ..."); + logger_.trace("writer_op (2): connection is closed."); self.complete({}); return; } @@ -267,8 +225,8 @@ struct writer_op { BOOST_ASIO_CORO_YIELD conn_->writer_timer_.async_wait(std::move(self)); - if (!conn_->is_open() || is_cancelled(self)) { - logger_.trace("writer-op: canceled (3). Exiting ..."); + if (!conn_->is_open()) { + logger_.trace("writer_op (3): connection is closed."); // Notice this is not an error of the op, stoping was // requested from the outside, so we complete with // success. @@ -317,16 +275,9 @@ struct reader_op { logger_.on_read(ec, n); - // EOF is not treated as error. - if (ec == asio::error::eof) { - logger_.trace("reader-op: EOF received. Exiting ..."); - conn_->cancel(operation::run); - return self.complete(ec); - } - // The connection is not viable after an error. if (ec) { - logger_.trace("reader-op: error. Exiting ..."); + logger_.trace("reader_op (1)", ec); conn_->cancel(operation::run); self.complete(ec); return; @@ -335,8 +286,8 @@ struct reader_op { // Somebody might have canceled implicitly or explicitly // while we were suspended and after queueing so we have to // check. - if (!conn_->is_open() || is_cancelled(self)) { - logger_.trace("reader-op: canceled. Exiting ..."); + if (!conn_->is_open()) { + logger_.trace("reader_op (2): connection is closed."); self.complete(ec); return; } @@ -344,7 +295,7 @@ struct reader_op { res_ = conn_->on_read(buffer_view(conn_->dbuf_), ec); if (ec) { - logger_.trace("reader-op: parse error. Exiting ..."); + logger_.trace("reader_op (3)", ec); conn_->cancel(operation::run); self.complete(ec); return; @@ -357,14 +308,14 @@ struct reader_op { } if (ec) { - logger_.trace("reader-op: error. Exiting ..."); + logger_.trace("reader_op (4)", ec); conn_->cancel(operation::run); self.complete(ec); return; } - if (!conn_->is_open() || is_cancelled(self)) { - logger_.trace("reader-op: canceled (2). Exiting ..."); + if (!conn_->is_open()) { + logger_.trace("reader_op (5): connection is closed."); self.complete(asio::error::operation_aborted); return; } @@ -374,6 +325,135 @@ struct reader_op { } }; +template +class run_op { +private: + Conn* conn_ = nullptr; + Logger logger_; + asio::coroutine coro_{}; + + using order_t = std::array; + +public: + run_op(Conn* conn, Logger l) + : conn_{conn} + , logger_{l} + {} + + template + void operator()( Self& self + , order_t order = {} + , system::error_code ec0 = {} + , system::error_code ec1 = {} + , system::error_code ec2 = {} + , system::error_code ec3 = {} + , system::error_code ec4 = {}) + { + BOOST_ASIO_CORO_REENTER (coro_) for (;;) + { + BOOST_ASIO_CORO_YIELD + conn_->resv_.async_resolve(asio::prepend(std::move(self), order_t {})); + + logger_.on_resolve(ec0, conn_->resv_.results()); + + if (ec0) { + self.complete(ec0); + return; + } + + BOOST_ASIO_CORO_YIELD + conn_->ctor_.async_connect( + conn_->next_layer().next_layer(), + conn_->resv_.results(), + asio::prepend(std::move(self), order_t {})); + + logger_.on_connect(ec0, conn_->ctor_.endpoint()); + + if (ec0) { + self.complete(ec0); + return; + } + + if (conn_->use_ssl()) { + BOOST_ASIO_CORO_YIELD + conn_->next_layer().async_handshake( + asio::ssl::stream_base::client, + asio::prepend( + asio::cancel_after( + conn_->cfg_.ssl_handshake_timeout, + std::move(self) + ), + order_t {} + ) + ); + + logger_.on_ssl_handshake(ec0); + + if (ec0) { + self.complete(ec0); + return; + } + } + + conn_->reset(); + + // Note: Oder is important here because the writer might + // trigger an async_write before the async_hello thereby + // causing an authentication problem. + BOOST_ASIO_CORO_YIELD + asio::experimental::make_parallel_group( + [this](auto token) { return conn_->handshaker_.async_hello(*conn_, logger_, token); }, + [this](auto token) { return conn_->health_checker_.async_ping(*conn_, logger_, token); }, + [this](auto token) { return conn_->health_checker_.async_check_timeout(*conn_, logger_, token);}, + [this](auto token) { return conn_->reader(logger_, token);}, + [this](auto token) { return conn_->writer(logger_, token);} + ).async_wait( + asio::experimental::wait_for_one_error(), + std::move(self)); + + if (order[0] == 0 && !!ec0) { + self.complete(ec0); + return; + } + + if (order[0] == 2 && ec2 == error::pong_timeout) { + self.complete(ec1); + return; + } + + // The receive operation must be cancelled because channel + // subscription does not survive a reconnection but requires + // re-subscription. + conn_->cancel(operation::receive); + + if (!conn_->will_reconnect()) { + conn_->cancel(operation::reconnection); + self.complete(ec3); + return; + } + + // It is safe to use the writer timer here because we are not + // connected. + conn_->writer_timer_.expires_after(conn_->cfg_.reconnect_wait_interval); + + BOOST_ASIO_CORO_YIELD + conn_->writer_timer_.async_wait(asio::prepend(std::move(self), order_t {})); + if (ec0) { + self.complete(ec0); + return; + } + + if (!conn_->will_reconnect()) { + self.complete(asio::error::operation_aborted); + return; + } + + conn_->reset_stream(); + } + } +}; + + /** @brief Base class for high level Redis asynchronous connections. * @ingroup high-level-api * @@ -404,7 +484,8 @@ class connection_base { , stream_{std::make_unique(ex, ctx_)} , writer_timer_{ex} , receive_channel_{ex, 256} - , runner_{ex, {}} + , resv_{ex} + , health_checker_{ex} , dbuf_{read_buffer_, max_read_size} { set_receive_response(ignore); @@ -433,15 +514,35 @@ class connection_base { /// Cancels specific operations. void cancel(operation op) { - runner_.cancel(op); - if (op == operation::all) { - cancel_impl(operation::run); - cancel_impl(operation::receive); - cancel_impl(operation::exec); - return; - } - - cancel_impl(op); + switch (op) { + case operation::resolve: + resv_.cancel(); + break; + case operation::exec: + cancel_unwritten_requests(); + break; + case operation::reconnection: + cfg_.reconnect_wait_interval = std::chrono::seconds::zero(); + break; + case operation::run: + cancel_run(); + break; + case operation::receive: + receive_channel_.cancel(); + break; + case operation::health_check: + health_checker_.cancel(); + break; + case operation::all: + resv_.cancel(); + cfg_.reconnect_wait_interval = std::chrono::seconds::zero(); + health_checker_.cancel(); + cancel_run(); // run + receive_channel_.cancel(); // receive + cancel_unwritten_requests(); // exec + break; + default: /* ignore */; + } } template @@ -493,9 +594,17 @@ class connection_base { template auto async_run(config const& cfg, Logger l, CompletionToken token) { - runner_.set_config(cfg); - l.set_prefix(runner_.get_config().log_prefix); - return runner_.async_run(*this, l, std::move(token)); + cfg_ = cfg; + resv_.set_config(cfg); + ctor_.set_config(cfg); + health_checker_.set_config(cfg); + handshaker_.set_config(cfg); + l.set_prefix(cfg.log_prefix); + + return asio::async_compose + < CompletionToken + , void(system::error_code) + >(run_op{this, l}, token, writer_timer_); } template @@ -512,15 +621,20 @@ class connection_base { auto run_is_canceled() const noexcept { return cancel_run_called_; } + bool will_reconnect() const noexcept + { return cfg_.reconnect_wait_interval != std::chrono::seconds::zero();} + private: using receive_channel_type = asio::experimental::channel; - using runner_type = runner; + using resolver_type = resolver; + using health_checker_type = health_checker; + using resp3_handshaker_type = resp3_handshaker; using adapter_type = std::function const&, system::error_code&)>; using receiver_adapter_type = std::function const&, system::error_code&)>; using exec_notifier_type = receive_channel_type; auto use_ssl() const noexcept - { return runner_.get_config().use_ssl;} + { return cfg_.use_ssl;} auto cancel_on_conn_lost() -> std::size_t { @@ -573,32 +687,18 @@ class connection_base { return ret; } - void cancel_impl(operation op) + void cancel_run() { - switch (op) { - case operation::exec: - { - cancel_unwritten_requests(); - } break; - case operation::run: - { - // Protects the code below from being called more than - // once, see https://github.com/boostorg/redis/issues/181 - if (std::exchange(cancel_run_called_, true)) { - return; - } - - close(); - writer_timer_.cancel(); - receive_channel_.cancel(); - cancel_on_conn_lost(); - } break; - case operation::receive: - { - receive_channel_.cancel(); - } break; - default: /* ignore */; + // Protects the code below from being called more than + // once, see https://github.com/boostorg/redis/issues/181 + if (std::exchange(cancel_run_called_, true)) { + return; } + + close(); + writer_timer_.cancel(); + receive_channel_.cancel(); + cancel_on_conn_lost(); } void on_write() @@ -706,9 +806,8 @@ class connection_base { template friend struct reader_op; template friend struct writer_op; - template friend struct run_op; template friend struct exec_op; - template friend struct runner_op; + template friend class run_op; void cancel_push_requests() { @@ -762,17 +861,6 @@ class connection_base { >(writer_op{this, l}, token, writer_timer_); } - template - auto async_run_lean(config const& cfg, Logger l, CompletionToken token) - { - runner_.set_config(cfg); - l.set_prefix(runner_.get_config().log_prefix); - return asio::async_compose - < CompletionToken - , void(system::error_code) - >(run_op{this, l}, token, writer_timer_); - } - [[nodiscard]] bool coalesce_requests() { // Coalesces the requests and marks them staged. After a @@ -946,11 +1034,15 @@ class connection_base { // not suspend. timer_type writer_timer_; receive_channel_type receive_channel_; - runner_type runner_; + resolver_type resv_; + connector ctor_; + health_checker_type health_checker_; + resp3_handshaker_type handshaker_; receiver_adapter_type receive_adapter_; using dyn_buffer_type = asio::dynamic_string_buffer, std::allocator>; + config cfg_; std::string read_buffer_; dyn_buffer_type dbuf_; std::string write_buffer_; diff --git a/include/boost/redis/detail/handshaker.hpp b/include/boost/redis/detail/handshaker.hpp deleted file mode 100644 index 0338d3cc..00000000 --- a/include/boost/redis/detail/handshaker.hpp +++ /dev/null @@ -1,124 +0,0 @@ -/* Copyright (c) 2018-2024 Marcelo Zimbres Silva (mzimbres@gmail.com) - * - * Distributed under the Boost Software License, Version 1.0. (See - * accompanying file LICENSE.txt) - */ - -#ifndef BOOST_REDIS_SSL_CONNECTOR_HPP -#define BOOST_REDIS_SSL_CONNECTOR_HPP - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace boost::redis::detail -{ - -template -struct handshake_op { - Handshaker* hsher_ = nullptr; - Stream* stream_ = nullptr; - asio::coroutine coro{}; - - template - void operator()( Self& self - , std::array const& order = {} - , system::error_code const& ec1 = {} - , system::error_code const& ec2 = {}) - { - BOOST_ASIO_CORO_REENTER (coro) - { - hsher_->timer_.expires_after(hsher_->timeout_); - - BOOST_ASIO_CORO_YIELD - asio::experimental::make_parallel_group( - [this](auto token) { return stream_->async_handshake(asio::ssl::stream_base::client, token); }, - [this](auto token) { return hsher_->timer_.async_wait(token);} - ).async_wait( - asio::experimental::wait_for_one(), - std::move(self)); - - if (is_cancelled(self)) { - self.complete(asio::error::operation_aborted); - return; - } - - switch (order[0]) { - case 0: { - self.complete(ec1); - } break; - case 1: - { - if (ec2) { - self.complete(ec2); - } else { - self.complete(error::ssl_handshake_timeout); - } - } break; - - default: BOOST_ASSERT(false); - } - } - } -}; - -template -class handshaker { -public: - using timer_type = - asio::basic_waitable_timer< - std::chrono::steady_clock, - asio::wait_traits, - Executor>; - - handshaker(Executor ex) - : timer_{ex} - {} - - template - auto - async_handshake(Stream& stream, CompletionToken&& token) - { - return asio::async_compose - < CompletionToken - , void(system::error_code) - >(handshake_op{this, &stream}, token, timer_); - } - - std::size_t cancel(operation op) - { - switch (op) { - case operation::ssl_handshake: - case operation::all: - timer_.cancel(); - break; - default: /* ignore */; - } - - return 0; - } - - constexpr bool is_dummy() const noexcept - {return false;} - - void set_config(config const& cfg) - { timeout_ = cfg.ssl_handshake_timeout; } - -private: - template friend struct handshake_op; - - timer_type timer_; - std::chrono::steady_clock::duration timeout_; -}; - -} // boost::redis::detail - -#endif // BOOST_REDIS_SSL_CONNECTOR_HPP diff --git a/include/boost/redis/detail/health_checker.hpp b/include/boost/redis/detail/health_checker.hpp index 02062553..3ac6c948 100644 --- a/include/boost/redis/detail/health_checker.hpp +++ b/include/boost/redis/detail/health_checker.hpp @@ -7,19 +7,17 @@ #ifndef BOOST_REDIS_HEALTH_CHECKER_HPP #define BOOST_REDIS_HEALTH_CHECKER_HPP -// Has to included before promise.hpp to build on msvc. #include #include #include #include -#include #include +#include #include #include #include #include #include -#include #include #include @@ -38,28 +36,37 @@ class ping_op { { BOOST_ASIO_CORO_REENTER (coro_) for (;;) { + if (checker_->ping_interval_ == std::chrono::seconds::zero()) { + logger_.trace("ping_op (1): timeout disabled."); + BOOST_ASIO_CORO_YIELD + asio::post(std::move(self)); + self.complete({}); + return; + } + if (checker_->checker_has_exited_) { - logger_.trace("ping_op: checker has exited. Exiting ..."); + logger_.trace("ping_op (2): checker has exited."); self.complete({}); return; } BOOST_ASIO_CORO_YIELD conn_->async_exec(checker_->req_, any_adapter(checker_->resp_), std::move(self)); - if (ec || is_cancelled(self)) { - logger_.trace("ping_op: error/cancelled (1)."); + if (ec) { + logger_.trace("ping_op (3)", ec); checker_->wait_timer_.cancel(); - self.complete(!!ec ? ec : asio::error::operation_aborted); + self.complete(ec); return; } // Wait before pinging again. checker_->ping_timer_.expires_after(checker_->ping_interval_); + BOOST_ASIO_CORO_YIELD checker_->ping_timer_.async_wait(std::move(self)); - if (ec || is_cancelled(self)) { - logger_.trace("ping_op: error/cancelled (2)."); - self.complete(!!ec ? ec : asio::error::operation_aborted); + if (ec) { + logger_.trace("ping_op (4)", ec); + self.complete(ec); return; } } @@ -79,23 +86,33 @@ class check_timeout_op { { BOOST_ASIO_CORO_REENTER (coro_) for (;;) { + if (checker_->ping_interval_ == std::chrono::seconds::zero()) { + logger_.trace("check_timeout_op (1): timeout disabled."); + BOOST_ASIO_CORO_YIELD + asio::post(std::move(self)); + self.complete({}); + return; + } + checker_->wait_timer_.expires_after(2 * checker_->ping_interval_); + BOOST_ASIO_CORO_YIELD checker_->wait_timer_.async_wait(std::move(self)); - if (ec || is_cancelled(self)) { - logger_.trace("check-timeout-op: error/canceled. Exiting ..."); - self.complete(!!ec ? ec : asio::error::operation_aborted); + if (ec) { + logger_.trace("check_timeout_op (2)", ec); + self.complete(ec); return; } if (checker_->resp_.has_error()) { - logger_.trace("check-timeout-op: Response error. Exiting ..."); + // TODO: Log the error. + logger_.trace("check_timeout_op (3): Response error."); self.complete({}); return; } if (checker_->resp_.value().empty()) { - logger_.trace("check-timeout-op: Response has no value. Exiting ..."); + logger_.trace("check_timeout_op (4): pong timeout."); checker_->ping_timer_.cancel(); conn_->cancel(operation::run); checker_->checker_has_exited_ = true; @@ -110,57 +127,6 @@ class check_timeout_op { } }; -template -class check_health_op { -public: - HealthChecker* checker_ = nullptr; - Connection* conn_ = nullptr; - Logger logger_; - asio::coroutine coro_{}; - - template - void - operator()( - Self& self, - std::array order = {}, - system::error_code ec1 = {}, - system::error_code ec2 = {}) - { - BOOST_ASIO_CORO_REENTER (coro_) - { - if (checker_->ping_interval_ == std::chrono::seconds::zero()) { - logger_.trace("check-health-op: timeout disabled."); - BOOST_ASIO_CORO_YIELD - asio::post(std::move(self)); - self.complete({}); - return; - } - - BOOST_ASIO_CORO_YIELD - asio::experimental::make_parallel_group( - [this](auto token) { return checker_->async_ping(*conn_, logger_, token); }, - [this](auto token) { return checker_->async_check_timeout(*conn_, logger_, token);} - ).async_wait( - asio::experimental::wait_for_one(), - std::move(self)); - - logger_.on_check_health(ec1, ec2); - - if (is_cancelled(self)) { - logger_.trace("check-health-op: canceled. Exiting ..."); - self.complete(asio::error::operation_aborted); - return; - } - - switch (order[0]) { - case 0: self.complete(ec1); return; - case 1: self.complete(ec2); return; - default: BOOST_ASSERT(false); - } - } - } -}; - template class health_checker { private: @@ -185,39 +151,12 @@ class health_checker { ping_interval_ = cfg.health_check_interval; } - template < - class Connection, - class Logger, - class CompletionToken = asio::default_completion_token_t - > - auto - async_check_health( - Connection& conn, - Logger l, - CompletionToken token = CompletionToken{}) + void cancel() { - checker_has_exited_ = false; - return asio::async_compose - < CompletionToken - , void(system::error_code) - >(check_health_op{this, &conn, l}, token, conn); + ping_timer_.cancel(); + wait_timer_.cancel(); } - std::size_t cancel(operation op) - { - switch (op) { - case operation::health_check: - case operation::all: - ping_timer_.cancel(); - wait_timer_.cancel(); - break; - default: /* ignore */; - } - - return 0; - } - -private: template auto async_ping(Connection& conn, Logger l, CompletionToken token) { @@ -230,15 +169,16 @@ class health_checker { template auto async_check_timeout(Connection& conn, Logger l, CompletionToken token) { + checker_has_exited_ = false; return asio::async_compose < CompletionToken , void(system::error_code) >(check_timeout_op{this, &conn, l}, token, conn, wait_timer_); } +private: template friend class ping_op; template friend class check_timeout_op; - template friend class check_health_op; timer_type ping_timer_; timer_type wait_timer_; diff --git a/include/boost/redis/detail/resolver.hpp b/include/boost/redis/detail/resolver.hpp index 4eb00457..c55def30 100644 --- a/include/boost/redis/detail/resolver.hpp +++ b/include/boost/redis/detail/resolver.hpp @@ -8,7 +8,6 @@ #define BOOST_REDIS_RESOLVER_HPP #include -#include #include #include #include @@ -63,18 +62,8 @@ class resolver { >(resolve_op{this}, token, resv_); } - std::size_t cancel(operation op) - { - switch (op) { - case operation::resolve: - case operation::all: - resv_.cancel(); - break; - default: /* ignore */; - } - - return 0; - } + void cancel() + { resv_.cancel(); } auto const& results() const noexcept { return results_;} diff --git a/include/boost/redis/detail/resp3_handshaker.hpp b/include/boost/redis/detail/resp3_handshaker.hpp new file mode 100644 index 00000000..2bc83996 --- /dev/null +++ b/include/boost/redis/detail/resp3_handshaker.hpp @@ -0,0 +1,116 @@ +/* Copyright (c) 2018-2024 Marcelo Zimbres Silva (mzimbres@gmail.com) + * + * Distributed under the Boost Software License, Version 1.0. (See + * accompanying file LICENSE.txt) + */ + +#ifndef BOOST_REDIS_RUNNER_HPP +#define BOOST_REDIS_RUNNER_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +//#include +#include +#include +#include + +namespace boost::redis::detail +{ + +void push_hello(config const& cfg, request& req); + +// TODO: Can we avoid this whole function whose only purpose is to +// check for an error in the hello response and complete with an error +// so that the parallel group that starts it can exit? +template +struct hello_op { + Handshaker* handshaker_ = nullptr; + Connection* conn_ = nullptr; + Logger logger_; + asio::coroutine coro_{}; + + template + void operator()(Self& self, system::error_code ec = {}, std::size_t = 0) + { + BOOST_ASIO_CORO_REENTER (coro_) + { + handshaker_->add_hello(); + + BOOST_ASIO_CORO_YIELD + conn_->async_exec(handshaker_->hello_req_, any_adapter(handshaker_->hello_resp_), std::move(self)); + logger_.on_hello(ec, handshaker_->hello_resp_); + + if (ec) { + conn_->cancel(operation::run); + self.complete(ec); + return; + } + + if (handshaker_->has_error_in_response()) { + conn_->cancel(operation::run); + self.complete(error::resp3_hello); + return; + } + + self.complete({}); + } + } +}; + +template +class resp3_handshaker { +public: + void set_config(config const& cfg) + { cfg_ = cfg; } + + template + auto async_hello(Connection& conn, Logger l, CompletionToken token) + { + return asio::async_compose + < CompletionToken + , void(system::error_code) + >(hello_op{this, &conn, l}, token, conn); + } + +private: + template friend struct hello_op; + + void add_hello() + { + hello_req_.clear(); + if (hello_resp_.has_value()) + hello_resp_.value().clear(); + push_hello(cfg_, hello_req_); + } + + bool has_error_in_response() const noexcept + { + if (!hello_resp_.has_value()) + return true; + + auto f = [](auto const& e) + { + switch (e.data_type) { + case resp3::type::simple_error: + case resp3::type::blob_error: return true; + default: return false; + } + }; + + return std::any_of(std::cbegin(hello_resp_.value()), std::cend(hello_resp_.value()), f); + } + + request hello_req_; + generic_response hello_resp_; + config cfg_; +}; + +} // boost::redis::detail + +#endif // BOOST_REDIS_RUNNER_HPP diff --git a/include/boost/redis/detail/runner.hpp b/include/boost/redis/detail/runner.hpp index 1f8f8d5c..edae4e2c 100644 --- a/include/boost/redis/detail/runner.hpp +++ b/include/boost/redis/detail/runner.hpp @@ -8,22 +8,21 @@ #define BOOST_REDIS_RUNNER_HPP #include -#include #include +#include #include #include #include #include #include -#include -#include -#include #include #include #include #include #include #include +#include +#include #include #include #include @@ -33,6 +32,9 @@ namespace boost::redis::detail void push_hello(config const& cfg, request& req); +// TODO: Can we avoid this whole function whose only purpose is to +// check for an error in the hello response and complete with an error +// so that the parallel group that starts it can exit? template struct hello_op { Runner* runner_ = nullptr; @@ -51,10 +53,15 @@ struct hello_op { conn_->async_exec(runner_->hello_req_, any_adapter(runner_->hello_resp_), std::move(self)); logger_.on_hello(ec, runner_->hello_resp_); - if (ec || runner_->has_error_in_response() || is_cancelled(self)) { - logger_.trace("hello-op: error/canceled. Exiting ..."); + if (ec) { conn_->cancel(operation::run); - self.complete(!!ec ? ec : asio::error::operation_aborted); + self.complete(ec); + return; + } + + if (runner_->has_error_in_response()) { + conn_->cancel(operation::run); + self.complete(error::resp3_hello); return; } @@ -71,6 +78,8 @@ class runner_op { Logger logger_; asio::coroutine coro_{}; + using order_t = std::array; + public: runner_op(Runner* runner, Connection* conn, Logger l) : runner_{runner} @@ -80,82 +89,113 @@ class runner_op { template void operator()( Self& self - , std::array order = {} + , order_t order = {} , system::error_code ec0 = {} , system::error_code ec1 = {} , system::error_code ec2 = {} - , std::size_t = 0) + , system::error_code ec3 = {} + , system::error_code ec4 = {}) { - BOOST_ASIO_CORO_REENTER (coro_) + BOOST_ASIO_CORO_REENTER (coro_) for (;;) { BOOST_ASIO_CORO_YIELD - runner_->resv_.async_resolve( - asio::prepend(std::move(self), std::array {})); + conn_->resv_.async_resolve(asio::prepend(std::move(self), order_t {})); - logger_.on_resolve(ec0, runner_->resv_.results()); + logger_.on_resolve(ec0, conn_->resv_.results()); - if (ec0 || redis::detail::is_cancelled(self)) { - self.complete(!!ec0 ? ec0 : asio::error::operation_aborted); + if (ec0) { + self.complete(ec0); return; } BOOST_ASIO_CORO_YIELD - runner_->ctor_.async_connect( + conn_->ctor_.async_connect( conn_->next_layer().next_layer(), - runner_->resv_.results(), - asio::prepend(std::move(self), std::array {})); + conn_->resv_.results(), + asio::prepend(std::move(self), order_t {})); - logger_.on_connect(ec0, runner_->ctor_.endpoint()); + logger_.on_connect(ec0, conn_->ctor_.endpoint()); - if (ec0 || redis::detail::is_cancelled(self)) { - self.complete(!!ec0 ? ec0 : asio::error::operation_aborted); + if (ec0) { + self.complete(ec0); return; } if (conn_->use_ssl()) { BOOST_ASIO_CORO_YIELD - runner_->hsher_.async_handshake( - conn_->next_layer(), - asio::prepend(std::move(self), - std::array {})); + conn_->next_layer().async_handshake( + asio::ssl::stream_base::client, + asio::prepend( + asio::cancel_after( + runner_->cfg_.ssl_handshake_timeout, + std::move(self) + ), + order_t {} + ) + ); logger_.on_ssl_handshake(ec0); - if (ec0 || redis::detail::is_cancelled(self)) { - self.complete(!!ec0 ? ec0 : asio::error::operation_aborted); + + if (ec0) { + self.complete(ec0); return; } } - // Note: Oder is important here because async_run might + conn_->reset(); + + // Note: Oder is important here because the writer might // trigger an async_write before the async_hello thereby - // causing authentication problems. + // causing an authentication problem. BOOST_ASIO_CORO_YIELD asio::experimental::make_parallel_group( [this](auto token) { return runner_->async_hello(*conn_, logger_, token); }, - [this](auto token) { return runner_->health_checker_.async_check_health(*conn_, logger_, token); }, - [this](auto token) { return conn_->async_run_lean(runner_->cfg_, logger_, token); } + [this](auto token) { return conn_->health_checker_.async_ping(*conn_, logger_, token); }, + [this](auto token) { return conn_->health_checker_.async_check_timeout(*conn_, logger_, token);}, + [this](auto token) { return conn_->reader(logger_, token);}, + [this](auto token) { return conn_->writer(logger_, token);} ).async_wait( asio::experimental::wait_for_one_error(), std::move(self)); - logger_.on_runner(ec0, ec1, ec2); + if (order[0] == 0 && !!ec0) { + self.complete(ec0); + return; + } - if (is_cancelled(self)) { - self.complete(asio::error::operation_aborted); + if (order[0] == 2 && ec2 == error::pong_timeout) { + self.complete(ec1); return; } - if (order[0] == 0 && !!ec0) { + // The receive operation must be cancelled because channel + // subscription does not survive a reconnection but requires + // re-subscription. + conn_->cancel(operation::receive); + + if (!conn_->will_reconnect()) { + conn_->cancel(operation::reconnection); + self.complete(ec3); + return; + } + + // It is safe to use the writer timer here because we are not + // connected. + conn_->writer_timer_.expires_after(conn_->cfg_.reconnect_wait_interval); + + BOOST_ASIO_CORO_YIELD + conn_->writer_timer_.async_wait(asio::prepend(std::move(self), order_t {})); + if (ec0) { self.complete(ec0); return; } - if (order[0] == 1 && ec1 == error::pong_timeout) { - self.complete(ec1); + if (!conn_->will_reconnect()) { + self.complete(asio::error::operation_aborted); return; } - self.complete(ec2); + conn_->reset_stream(); } } }; @@ -164,27 +204,12 @@ template class runner { public: runner(Executor ex, config cfg) - : resv_{ex} - , hsher_{ex} - , health_checker_{ex} - , cfg_{cfg} + : cfg_{cfg} { } - std::size_t cancel(operation op) - { - resv_.cancel(op); - hsher_.cancel(op); - health_checker_.cancel(op); - return 0U; - } - void set_config(config const& cfg) { cfg_ = cfg; - resv_.set_config(cfg); - ctor_.set_config(cfg); - hsher_.set_config(cfg); - health_checker_.set_config(cfg); } template @@ -196,12 +221,7 @@ class runner { >(runner_op{this, &conn, l}, token, conn); } - config const& get_config() const noexcept {return cfg_;} - private: - using resolver_type = resolver; - using handshaker_type = detail::handshaker; - using health_checker_type = health_checker; template friend class runner_op; template friend struct hello_op; @@ -240,10 +260,6 @@ class runner { return std::any_of(std::cbegin(hello_resp_.value()), std::cend(hello_resp_.value()), f); } - resolver_type resv_; - connector ctor_; - handshaker_type hsher_; - health_checker_type health_checker_; request hello_req_; generic_response hello_resp_; config cfg_; diff --git a/include/boost/redis/error.hpp b/include/boost/redis/error.hpp index 82ebce46..3ab56fce 100644 --- a/include/boost/redis/error.hpp +++ b/include/boost/redis/error.hpp @@ -81,6 +81,9 @@ enum class error /// Incompatible node depth. incompatible_node_depth, + + /// Resp3 hello command error + resp3_hello, }; /** \internal diff --git a/include/boost/redis/impl/error.ipp b/include/boost/redis/impl/error.ipp index 030c129f..a5ca7081 100644 --- a/include/boost/redis/impl/error.ipp +++ b/include/boost/redis/impl/error.ipp @@ -44,6 +44,7 @@ struct error_category_impl : system::error_category { case error::ssl_handshake_timeout: return "SSL handshake timeout."; case error::sync_receive_push_failed: return "Can't receive server push synchronously without blocking."; case error::incompatible_node_depth: return "Incompatible node depth."; + case error::resp3_hello: return "RESP3 handshake error (hello command)."; default: BOOST_ASSERT(false); return "Boost.Redis error."; } } diff --git a/include/boost/redis/impl/logger.ipp b/include/boost/redis/impl/logger.ipp index 0b8624c5..374fe919 100644 --- a/include/boost/redis/impl/logger.ipp +++ b/include/boost/redis/impl/logger.ipp @@ -25,7 +25,7 @@ void logger::on_resolve(system::error_code const& ec, asio::ip::tcp::resolver::r write_prefix(); - std::clog << "run-all-op: resolve addresses "; + std::clog << "resolve results: "; if (ec) { std::clog << ec.message() << std::endl; @@ -51,7 +51,7 @@ void logger::on_connect(system::error_code const& ec, asio::ip::tcp::endpoint co write_prefix(); - std::clog << "run-all-op: connected to endpoint "; + std::clog << "connected to "; if (ec) std::clog << ec.message() << std::endl; @@ -68,22 +68,7 @@ void logger::on_ssl_handshake(system::error_code const& ec) write_prefix(); - std::clog << "Runner: SSL handshake " << ec.message() << std::endl; -} - -void logger::on_connection_lost(system::error_code const& ec) -{ - if (level_ < level::info) - return; - - write_prefix(); - - if (ec) - std::clog << "Connection lost: " << ec.message(); - else - std::clog << "Connection lost."; - - std::clog << std::endl; + std::clog << "SSL handshake: " << ec.message() << std::endl; } void @@ -97,9 +82,9 @@ logger::on_write( write_prefix(); if (ec) - std::clog << "writer-op: " << ec.message(); + std::clog << "writer_op: " << ec.message(); else - std::clog << "writer-op: " << std::size(payload) << " bytes written."; + std::clog << "writer_op: " << std::size(payload) << " bytes written."; std::clog << std::endl; } @@ -112,23 +97,9 @@ void logger::on_read(system::error_code const& ec, std::size_t n) write_prefix(); if (ec) - std::clog << "reader-op: " << ec.message(); + std::clog << "reader_op: " << ec.message(); else - std::clog << "reader-op: " << n << " bytes read."; - - std::clog << std::endl; -} - -void logger::on_run(system::error_code const& reader_ec, system::error_code const& writer_ec) -{ - if (level_ < level::info) - return; - - write_prefix(); - - std::clog << "run-op: " - << reader_ec.message() << " (reader), " - << writer_ec.message() << " (writer)"; + std::clog << "reader_op: " << n << " bytes read."; std::clog << std::endl; } @@ -144,60 +115,34 @@ logger::on_hello( write_prefix(); if (ec) { - std::clog << "hello-op: " << ec.message(); + std::clog << "hello_op: " << ec.message(); if (resp.has_error()) std::clog << " (" << resp.error().diagnostic << ")"; } else { - std::clog << "hello-op: Success"; + std::clog << "hello_op: Success"; } std::clog << std::endl; } -void - logger::on_runner( - system::error_code const& hello, - system::error_code const& check_health, - system::error_code const& run) +void logger::trace(std::string_view message) { - if (level_ < level::info) - return; - - write_prefix(); - - std::clog << "hello: " - << hello.message() << " (async_hello), " - << check_health.message() << " (async_check_health) " - << run.message() << " (async_run_lean)."; - - std::clog << std::endl; -} - -void - logger::on_check_health( - system::error_code const& ping_ec, - system::error_code const& timeout_ec) -{ - if (level_ < level::info) + if (level_ < level::debug) return; write_prefix(); - std::clog << "check-health-op: " - << ping_ec.message() << " (async_ping), " - << timeout_ec.message() << " (async_check_timeout)."; - - std::clog << std::endl; + std::clog << message << std::endl; } -void logger::trace(std::string_view reason) +void logger::trace(std::string_view op, system::error_code const& ec) { if (level_ < level::debug) return; write_prefix(); - std::clog << reason << std::endl; + std::clog << op << ": " << ec.message() << std::endl; } } // boost::redis diff --git a/include/boost/redis/impl/runner.ipp b/include/boost/redis/impl/resp3_handshaker.ipp similarity index 94% rename from include/boost/redis/impl/runner.ipp rename to include/boost/redis/impl/resp3_handshaker.ipp index 293ad92e..e18bf928 100644 --- a/include/boost/redis/impl/runner.ipp +++ b/include/boost/redis/impl/resp3_handshaker.ipp @@ -4,7 +4,7 @@ * accompanying file LICENSE.txt) */ -#include +#include namespace boost::redis::detail { diff --git a/include/boost/redis/logger.hpp b/include/boost/redis/logger.hpp index b102f022..7558436e 100644 --- a/include/boost/redis/logger.hpp +++ b/include/boost/redis/logger.hpp @@ -91,13 +91,6 @@ class logger { */ void on_ssl_handshake(system::error_code const& ec); - /** @brief Called when the connection is lost. - * @ingroup high-level-api - * - * @param ec Error returned when the connection is lost. - */ - void on_connection_lost(system::error_code const& ec); - /** @brief Called when the write operation completes. * @ingroup high-level-api * @@ -114,14 +107,6 @@ class logger { */ void on_read(system::error_code const& ec, std::size_t n); - /** @brief Called when the run operation completes. - * @ingroup high-level-api - * - * @param reader_ec Error code returned by the read operation. - * @param writer_ec Error code returned by the write operation. - */ - void on_run(system::error_code const& reader_ec, system::error_code const& writer_ec); - /** @brief Called when the `HELLO` request completes. * @ingroup high-level-api * @@ -140,25 +125,8 @@ class logger { prefix_ = prefix; } - /** @brief Called when the runner operation completes. - * @ingroup high-level-api - * - * @param hello_ec Error code returned by the health checker operation. - * @param health_check_ec Error code returned by the health checker operation. - * @param run_all_ec Error code returned by the run_all operation. - */ - void - on_runner( - system::error_code const& hello, - system::error_code const& health_check, - system::error_code const& run); - - void - on_check_health( - system::error_code const& ping_ec, - system::error_code const& check_timeout_ec); - - void trace(std::string_view reason); + void trace(std::string_view message); + void trace(std::string_view op, system::error_code const& ec); private: void write_prefix(); diff --git a/include/boost/redis/src.hpp b/include/boost/redis/src.hpp index 2d5cac76..5ba662e0 100644 --- a/include/boost/redis/src.hpp +++ b/include/boost/redis/src.hpp @@ -10,7 +10,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 14dcc0ca..5ee07d2f 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -46,7 +46,6 @@ make_test(test_conn_reconnect 20) make_test(test_conn_exec_cancel 20) make_test(test_conn_exec_cancel2 20) make_test(test_conn_echo_stress 20) -make_test(test_conn_run_cancel 20) make_test(test_any_adapter 17) make_test(test_issue_50 20) make_test(test_issue_181 17) diff --git a/test/common.hpp b/test/common.hpp index a0e30d6a..83378d5c 100644 --- a/test/common.hpp +++ b/test/common.hpp @@ -24,5 +24,5 @@ run( boost::redis::config cfg = make_test_config(), boost::system::error_code ec = boost::asio::error::operation_aborted, boost::redis::operation op = boost::redis::operation::receive, - boost::redis::logger::level l = boost::redis::logger::level::disabled); + boost::redis::logger::level l = boost::redis::logger::level::debug); diff --git a/test/test_conn_exec_retry.cpp b/test/test_conn_exec_retry.cpp index 99f68c39..2ea402f9 100644 --- a/test/test_conn_exec_retry.cpp +++ b/test/test_conn_exec_retry.cpp @@ -75,8 +75,13 @@ BOOST_AUTO_TEST_CASE(request_retry_false) conn->async_exec(req0, ignore, c0); auto cfg = make_test_config(); - cfg.health_check_interval = 5s; - run(conn); + conn->async_run(cfg, {boost::redis::logger::level::debug}, + [&](boost::system::error_code const& ec) + { + std::cout << "async_run: " << ec.message() << std::endl; + conn->cancel(); + } + ); ioc.run(); } diff --git a/test/test_conn_quit.cpp b/test/test_conn_quit.cpp index 9d5dd2f3..5e4759c2 100644 --- a/test/test_conn_quit.cpp +++ b/test/test_conn_quit.cpp @@ -20,24 +20,6 @@ using boost::redis::response; using boost::redis::ignore; using namespace std::chrono_literals; -BOOST_AUTO_TEST_CASE(test_eof_no_error) -{ - request req; - req.get_config().cancel_on_connection_lost = false; - req.push("QUIT"); - - net::io_context ioc; - auto conn = std::make_shared(ioc); - - conn->async_exec(req, ignore, [&](auto ec, auto) { - BOOST_TEST(!ec); - conn->cancel(operation::reconnection); - }); - - run(conn); - ioc.run(); -} - // Test if quit causes async_run to exit. BOOST_AUTO_TEST_CASE(test_async_run_exits) { diff --git a/test/test_conn_run_cancel.cpp b/test/test_conn_run_cancel.cpp deleted file mode 100644 index f4f0b73d..00000000 --- a/test/test_conn_run_cancel.cpp +++ /dev/null @@ -1,153 +0,0 @@ -/* Copyright (c) 2018-2022 Marcelo Zimbres Silva (mzimbres@gmail.com) - * - * Distributed under the Boost Software License, Version 1.0. (See - * accompanying file LICENSE.txt) - */ - -#include -#include -#include -#include -#define BOOST_TEST_MODULE conn-run-cancel -#include -#include -#include "common.hpp" - -#ifdef BOOST_ASIO_HAS_CO_AWAIT -#include -#include - -namespace net = boost::asio; - -using boost::redis::operation; -using boost::redis::connection; -using boost::system::error_code; -using net::as_tuple; -using boost::redis::request; -using boost::redis::response; -using boost::redis::ignore; -using boost::redis::logger; -using namespace std::chrono_literals; - -using namespace net::experimental::awaitable_operators; - -auto async_cancel_run_with_timer() -> net::awaitable -{ - auto ex = co_await net::this_coro::executor; - connection conn{ex}; - - net::steady_timer st{ex}; - st.expires_after(1s); - - error_code ec1, ec2; - auto cfg = make_test_config(); - logger l; - co_await (conn.async_run(cfg, l, redir(ec1)) || st.async_wait(redir(ec2))); - - BOOST_CHECK_EQUAL(ec1, boost::asio::error::operation_aborted); - BOOST_TEST(!ec2); -} - -BOOST_AUTO_TEST_CASE(cancel_run_with_timer) -{ - net::io_context ioc; - net::co_spawn(ioc.get_executor(), async_cancel_run_with_timer(), net::detached); - ioc.run(); -} - -auto -async_check_cancellation_not_missed(int n, std::chrono::milliseconds ms) -> net::awaitable -{ - auto ex = co_await net::this_coro::executor; - connection conn{ex}; - - net::steady_timer timer{ex}; - - for (auto i = 0; i < n; ++i) { - timer.expires_after(ms); - error_code ec1, ec2; - auto cfg = make_test_config(); - logger l; - co_await (conn.async_run(cfg, l, redir(ec1)) || timer.async_wait(redir(ec2))); - BOOST_CHECK_EQUAL(ec1, boost::asio::error::operation_aborted); - std::cout << "Counter: " << i << std::endl; - } -} - -// See PR #29 -BOOST_AUTO_TEST_CASE(check_implicit_cancel_not_missed_0) -{ - net::io_context ioc; - net::co_spawn(ioc, async_check_cancellation_not_missed(10, std::chrono::milliseconds{0}), net::detached); - ioc.run(); -} - -BOOST_AUTO_TEST_CASE(check_implicit_cancel_not_missed_2) -{ - net::io_context ioc; - net::co_spawn(ioc, async_check_cancellation_not_missed(20, std::chrono::milliseconds{2}), net::detached); - ioc.run(); -} - -BOOST_AUTO_TEST_CASE(check_implicit_cancel_not_missed_8) -{ - net::io_context ioc; - net::co_spawn(ioc, async_check_cancellation_not_missed(20, std::chrono::milliseconds{8}), net::detached); - ioc.run(); -} - -BOOST_AUTO_TEST_CASE(check_implicit_cancel_not_missed_16) -{ - net::io_context ioc; - net::co_spawn(ioc, async_check_cancellation_not_missed(20, std::chrono::milliseconds{16}), net::detached); - ioc.run(); -} - -BOOST_AUTO_TEST_CASE(check_implicit_cancel_not_missed_32) -{ - net::io_context ioc; - net::co_spawn(ioc, async_check_cancellation_not_missed(20, std::chrono::milliseconds{32}), net::detached); - ioc.run(); -} - -BOOST_AUTO_TEST_CASE(check_implicit_cancel_not_missed_64) -{ - net::io_context ioc; - net::co_spawn(ioc, async_check_cancellation_not_missed(20, std::chrono::milliseconds{64}), net::detached); - ioc.run(); -} - -BOOST_AUTO_TEST_CASE(check_implicit_cancel_not_missed_128) -{ - net::io_context ioc; - net::co_spawn(ioc, async_check_cancellation_not_missed(20, std::chrono::milliseconds{128}), net::detached); - ioc.run(); -} - -BOOST_AUTO_TEST_CASE(check_implicit_cancel_not_missed_256) -{ - net::io_context ioc; - net::co_spawn(ioc, async_check_cancellation_not_missed(20, std::chrono::milliseconds{256}), net::detached); - ioc.run(); -} - -BOOST_AUTO_TEST_CASE(check_implicit_cancel_not_missed_512) -{ - net::io_context ioc; - net::co_spawn(ioc, async_check_cancellation_not_missed(20, std::chrono::milliseconds{512}), net::detached); - ioc.run(); -} - -BOOST_AUTO_TEST_CASE(check_implicit_cancel_not_missed_1024) -{ - net::io_context ioc; - net::co_spawn(ioc, async_check_cancellation_not_missed(20, std::chrono::milliseconds{1024}), net::detached); - ioc.run(); -} - -#else -BOOST_AUTO_TEST_CASE(dummy) -{ - BOOST_TEST(true); -} -#endif diff --git a/test/test_low_level.cpp b/test/test_low_level.cpp index dadbe454..0dcd14bc 100644 --- a/test/test_low_level.cpp +++ b/test/test_low_level.cpp @@ -522,6 +522,7 @@ BOOST_AUTO_TEST_CASE(cover_error) check_error("boost.redis", boost::redis::error::ssl_handshake_timeout); check_error("boost.redis", boost::redis::error::sync_receive_push_failed); check_error("boost.redis", boost::redis::error::incompatible_node_depth); + check_error("boost.redis", boost::redis::error::resp3_hello); } std::string get_type_as_str(boost::redis::resp3::type t) diff --git a/test/test_low_level_sync_sans_io.cpp b/test/test_low_level_sync_sans_io.cpp index 6f6acfe0..d36eccde 100644 --- a/test/test_low_level_sync_sans_io.cpp +++ b/test/test_low_level_sync_sans_io.cpp @@ -1,10 +1,10 @@ -/* Copyright (c) 2018-2022 Marcelo Zimbres Silva (mzimbres@gmail.com) +/* Copyright (c) 2018-2024 Marcelo Zimbres Silva (mzimbres@gmail.com) * * Distributed under the Boost Software License, Version 1.0. (See * accompanying file LICENSE.txt) */ -#include +#include #include #include #define BOOST_TEST_MODULE conn-quit From 9a48633bdfcf6f4c9640ea635ed58564d571c0e1 Mon Sep 17 00:00:00 2001 From: Marcelo Zimbres Date: Sun, 22 Dec 2024 21:20:51 +0100 Subject: [PATCH 2/2] Removes connection_base class. --- include/boost/redis/adapter/any_adapter.hpp | 4 +- include/boost/redis/connection.hpp | 1017 +++++++++++++++- .../boost/redis/detail/connection_base.hpp | 1059 ----------------- test/test_issue_181.cpp | 4 +- 4 files changed, 982 insertions(+), 1102 deletions(-) delete mode 100644 include/boost/redis/detail/connection_base.hpp diff --git a/include/boost/redis/adapter/any_adapter.hpp b/include/boost/redis/adapter/any_adapter.hpp index bf14d6e5..f317fc1b 100644 --- a/include/boost/redis/adapter/any_adapter.hpp +++ b/include/boost/redis/adapter/any_adapter.hpp @@ -22,7 +22,7 @@ namespace detail { // Forward decl template -class connection_base; +class basic_connection; } @@ -58,7 +58,7 @@ class any_adapter } template - friend class detail::connection_base; + friend class basic_connection; public: /** diff --git a/include/boost/redis/connection.hpp b/include/boost/redis/connection.hpp index b9c3654a..2efd5e5e 100644 --- a/include/boost/redis/connection.hpp +++ b/include/boost/redis/connection.hpp @@ -7,21 +7,461 @@ #ifndef BOOST_REDIS_CONNECTION_HPP #define BOOST_REDIS_CONNECTION_HPP +#include #include -#include -#include #include -#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include #include +#include +#include +#include +#include +#include +#include +#include #include -#include -#include +#include +#include + +#include +#include +#include #include #include +#include +#include #include +#include +#include +#include namespace boost::redis { +namespace detail +{ + +template +std::string_view buffer_view(DynamicBuffer buf) noexcept +{ + char const* start = static_cast(buf.data(0, buf.size()).data()); + return std::string_view{start, std::size(buf)}; +} + +template +class append_some_op { +private: + AsyncReadStream& stream_; + DynamicBuffer buf_; + std::size_t size_ = 0; + std::size_t tmp_ = 0; + asio::coroutine coro_{}; + +public: + append_some_op(AsyncReadStream& stream, DynamicBuffer buf, std::size_t size) + : stream_ {stream} + , buf_ {std::move(buf)} + , size_{size} + { } + + template + void operator()( Self& self + , system::error_code ec = {} + , std::size_t n = 0) + { + BOOST_ASIO_CORO_REENTER (coro_) + { + tmp_ = buf_.size(); + buf_.grow(size_); + + BOOST_ASIO_CORO_YIELD + stream_.async_read_some(buf_.data(tmp_, size_), std::move(self)); + if (ec) { + self.complete(ec, 0); + return; + } + + buf_.shrink(buf_.size() - tmp_ - n); + self.complete({}, n); + } + } +}; + +template +auto +async_append_some( + AsyncReadStream& stream, + DynamicBuffer buffer, + std::size_t size, + CompletionToken&& token) +{ + return asio::async_compose + < CompletionToken + , void(system::error_code, std::size_t) + >(append_some_op {stream, buffer, size}, token, stream); +} + +template +struct exec_op { + using req_info_type = typename Conn::req_info; + using adapter_type = typename Conn::adapter_type; + + Conn* conn_ = nullptr; + std::shared_ptr info_ = nullptr; + asio::coroutine coro{}; + + template + void operator()(Self& self , system::error_code = {}, std::size_t = 0) + { + BOOST_ASIO_CORO_REENTER (coro) + { + // Check whether the user wants to wait for the connection to + // be stablished. + if (info_->req_->get_config().cancel_if_not_connected && !conn_->is_open()) { + BOOST_ASIO_CORO_YIELD + asio::dispatch( + asio::get_associated_immediate_executor(self, self.get_io_executor()), + std::move(self)); + return self.complete(error::not_connected, 0); + } + + conn_->add_request_info(info_); + +EXEC_OP_WAIT: + BOOST_ASIO_CORO_YIELD + info_->async_wait(std::move(self)); + + if (info_->ec_) { + self.complete(info_->ec_, 0); + return; + } + + if (info_->stop_requested()) { + // Don't have to call remove_request as it has already + // been by cancel(exec). + return self.complete(asio::error::operation_aborted, 0); + } + + if (is_cancelled(self)) { + if (!info_->is_waiting()) { + using c_t = asio::cancellation_type; + auto const c = self.get_cancellation_state().cancelled(); + if ((c & c_t::terminal) != c_t::none) { + // Cancellation requires closing the connection + // otherwise it stays in inconsistent state. + conn_->cancel(operation::run); + return self.complete(asio::error::operation_aborted, 0); + } else { + // Can't implement other cancelation types, ignoring. + self.get_cancellation_state().clear(); + + // TODO: Find out a better way to ignore + // cancelation. + goto EXEC_OP_WAIT; + } + } else { + // Cancelation can be honored. + conn_->remove_request(info_); + self.complete(asio::error::operation_aborted, 0); + return; + } + } + + self.complete(info_->ec_, info_->read_size_); + } + } +}; + +template +struct writer_op { + Conn* conn_; + Logger logger_; + asio::coroutine coro{}; + + template + void operator()( Self& self + , system::error_code ec = {} + , std::size_t n = 0) + { + ignore_unused(n); + + BOOST_ASIO_CORO_REENTER (coro) for (;;) + { + while (conn_->coalesce_requests()) { + if (conn_->use_ssl()) + BOOST_ASIO_CORO_YIELD asio::async_write(conn_->next_layer(), asio::buffer(conn_->write_buffer_), std::move(self)); + else + BOOST_ASIO_CORO_YIELD asio::async_write(conn_->next_layer().next_layer(), asio::buffer(conn_->write_buffer_), std::move(self)); + + logger_.on_write(ec, conn_->write_buffer_); + + if (ec) { + logger_.trace("writer_op (1)", ec); + conn_->cancel(operation::run); + self.complete(ec); + return; + } + + conn_->on_write(); + + // A socket.close() may have been called while a + // successful write might had already been queued, so we + // have to check here before proceeding. + if (!conn_->is_open()) { + logger_.trace("writer_op (2): connection is closed."); + self.complete({}); + return; + } + } + + BOOST_ASIO_CORO_YIELD + conn_->writer_timer_.async_wait(std::move(self)); + if (!conn_->is_open()) { + logger_.trace("writer_op (3): connection is closed."); + // Notice this is not an error of the op, stoping was + // requested from the outside, so we complete with + // success. + self.complete({}); + return; + } + } + } +}; + +template +struct reader_op { + using parse_result = typename Conn::parse_result; + using parse_ret_type = typename Conn::parse_ret_type; + Conn* conn_; + Logger logger_; + parse_ret_type res_{parse_result::resp, 0}; + asio::coroutine coro{}; + + template + void operator()( Self& self + , system::error_code ec = {} + , std::size_t n = 0) + { + ignore_unused(n); + + BOOST_ASIO_CORO_REENTER (coro) for (;;) + { + // Appends some data to the buffer if necessary. + if ((res_.first == parse_result::needs_more) || std::empty(conn_->read_buffer_)) { + if (conn_->use_ssl()) { + BOOST_ASIO_CORO_YIELD + async_append_some( + conn_->next_layer(), + conn_->dbuf_, + conn_->get_suggested_buffer_growth(), + std::move(self)); + } else { + BOOST_ASIO_CORO_YIELD + async_append_some( + conn_->next_layer().next_layer(), + conn_->dbuf_, + conn_->get_suggested_buffer_growth(), + std::move(self)); + } + + logger_.on_read(ec, n); + + // The connection is not viable after an error. + if (ec) { + logger_.trace("reader_op (1)", ec); + conn_->cancel(operation::run); + self.complete(ec); + return; + } + + // Somebody might have canceled implicitly or explicitly + // while we were suspended and after queueing so we have to + // check. + if (!conn_->is_open()) { + logger_.trace("reader_op (2): connection is closed."); + self.complete(ec); + return; + } + } + + res_ = conn_->on_read(buffer_view(conn_->dbuf_), ec); + if (ec) { + logger_.trace("reader_op (3)", ec); + conn_->cancel(operation::run); + self.complete(ec); + return; + } + + if (res_.first == parse_result::push) { + if (!conn_->receive_channel_.try_send(ec, res_.second)) { + BOOST_ASIO_CORO_YIELD + conn_->receive_channel_.async_send(ec, res_.second, std::move(self)); + } + + if (ec) { + logger_.trace("reader_op (4)", ec); + conn_->cancel(operation::run); + self.complete(ec); + return; + } + + if (!conn_->is_open()) { + logger_.trace("reader_op (5): connection is closed."); + self.complete(asio::error::operation_aborted); + return; + } + + } + } + } +}; + +template +class run_op { +private: + Conn* conn_ = nullptr; + Logger logger_; + asio::coroutine coro_{}; + + using order_t = std::array; + +public: + run_op(Conn* conn, Logger l) + : conn_{conn} + , logger_{l} + {} + + template + void operator()( Self& self + , order_t order = {} + , system::error_code ec0 = {} + , system::error_code ec1 = {} + , system::error_code ec2 = {} + , system::error_code ec3 = {} + , system::error_code ec4 = {}) + { + BOOST_ASIO_CORO_REENTER (coro_) for (;;) + { + BOOST_ASIO_CORO_YIELD + conn_->resv_.async_resolve(asio::prepend(std::move(self), order_t {})); + + logger_.on_resolve(ec0, conn_->resv_.results()); + + if (ec0) { + self.complete(ec0); + return; + } + + BOOST_ASIO_CORO_YIELD + conn_->ctor_.async_connect( + conn_->next_layer().next_layer(), + conn_->resv_.results(), + asio::prepend(std::move(self), order_t {})); + + logger_.on_connect(ec0, conn_->ctor_.endpoint()); + + if (ec0) { + self.complete(ec0); + return; + } + + if (conn_->use_ssl()) { + BOOST_ASIO_CORO_YIELD + conn_->next_layer().async_handshake( + asio::ssl::stream_base::client, + asio::prepend( + asio::cancel_after( + conn_->cfg_.ssl_handshake_timeout, + std::move(self) + ), + order_t {} + ) + ); + + logger_.on_ssl_handshake(ec0); + + if (ec0) { + self.complete(ec0); + return; + } + } + + conn_->reset(); + + // Note: Oder is important here because the writer might + // trigger an async_write before the async_hello thereby + // causing an authentication problem. + BOOST_ASIO_CORO_YIELD + asio::experimental::make_parallel_group( + [this](auto token) { return conn_->handshaker_.async_hello(*conn_, logger_, token); }, + [this](auto token) { return conn_->health_checker_.async_ping(*conn_, logger_, token); }, + [this](auto token) { return conn_->health_checker_.async_check_timeout(*conn_, logger_, token);}, + [this](auto token) { return conn_->reader(logger_, token);}, + [this](auto token) { return conn_->writer(logger_, token);} + ).async_wait( + asio::experimental::wait_for_one_error(), + std::move(self)); + + if (order[0] == 0 && !!ec0) { + self.complete(ec0); + return; + } + + if (order[0] == 2 && ec2 == error::pong_timeout) { + self.complete(ec1); + return; + } + + // The receive operation must be cancelled because channel + // subscription does not survive a reconnection but requires + // re-subscription. + conn_->cancel(operation::receive); + + if (!conn_->will_reconnect()) { + conn_->cancel(operation::reconnection); + self.complete(ec3); + return; + } + + // It is safe to use the writer timer here because we are not + // connected. + conn_->writer_timer_.expires_after(conn_->cfg_.reconnect_wait_interval); + + BOOST_ASIO_CORO_YIELD + conn_->writer_timer_.async_wait(asio::prepend(std::move(self), order_t {})); + if (ec0) { + self.complete(ec0); + return; + } + + if (!conn_->will_reconnect()) { + self.complete(asio::error::operation_aborted); + return; + } + + conn_->reset_stream(); + } + } +}; + +} // boost::redis::detail /** @brief A SSL connection to the Redis server. * @ingroup high-level-api @@ -36,12 +476,17 @@ namespace boost::redis { template class basic_connection { public: - /// Executor type. + using this_type = basic_connection; + + /// Type of the next layer + using next_layer_type = asio::ssl::stream>; + + /// Executor type using executor_type = Executor; - /// Returns the underlying executor. + /// Returns the associated executor. executor_type get_executor() noexcept - { return impl_.get_executor(); } + {return writer_timer_.get_executor();} /// Rebinds the socket type to another executor. template @@ -63,10 +508,19 @@ class basic_connection { executor_type ex, asio::ssl::context ctx = asio::ssl::context{asio::ssl::context::tlsv12_client}, std::size_t max_read_size = (std::numeric_limits::max)()) - : impl_{ex, std::move(ctx), max_read_size} - { } + : ctx_{std::move(ctx)} + , stream_{std::make_unique(ex, ctx_)} + , writer_timer_{ex} + , receive_channel_{ex, 256} + , resv_{ex} + , health_checker_{ex} + , dbuf_{read_buffer_, max_read_size} + { + set_receive_response(ignore); + writer_timer_.expires_at((std::chrono::steady_clock::time_point::max)()); + } - /// Contructs from a context. + /// Constructs from a context. explicit basic_connection( asio::io_context& ioc, @@ -121,7 +575,17 @@ class basic_connection { Logger l = Logger{}, CompletionToken token = CompletionToken{}) { - return impl_.async_run(cfg, l, std::move(token)); + cfg_ = cfg; + resv_.set_config(cfg); + ctor_.set_config(cfg); + health_checker_.set_config(cfg); + handshaker_.set_config(cfg); + l.set_prefix(cfg.log_prefix); + + return asio::async_compose + < CompletionToken + , void(system::error_code) + >(detail::run_op{this, l}, token, writer_timer_); } /** @brief Receives server side pushes asynchronously. @@ -148,9 +612,8 @@ class basic_connection { */ template > auto async_receive(CompletionToken token = CompletionToken{}) - { return impl_.async_receive(std::move(token)); } + { return receive_channel_.async_receive(std::move(token)); } - /** @brief Receives server pushes synchronously without blocking. * * Receives a server push synchronously by calling `try_receive` on @@ -164,20 +627,22 @@ class basic_connection { */ std::size_t receive(system::error_code& ec) { - return impl_.receive(ec); - } + std::size_t size = 0; - template < - class Response = ignore_t, - class CompletionToken = asio::default_completion_token_t - > - [[deprecated("Set the response with set_receive_response and use the other overload.")]] - auto - async_receive( - Response& response, - CompletionToken token = CompletionToken{}) - { - return impl_.async_receive(response, token); + auto f = [&](system::error_code const& ec2, std::size_t n) + { + ec = ec2; + size = n; + }; + + auto const res = receive_channel_.try_receive(f); + if (ec) + return 0; + + if (!res) + ec = error::sync_receive_push_failed; + + return size; } /** @brief Executes commands on the Redis server asynchronously. @@ -213,7 +678,7 @@ class basic_connection { Response& resp = ignore, CompletionToken&& token = CompletionToken{}) { - return impl_.async_exec(req, any_adapter(resp), std::forward(token)); + return this->async_exec(req, any_adapter(resp), std::forward(token)); } /** @copydoc async_exec @@ -228,7 +693,15 @@ class basic_connection { any_adapter adapter, CompletionToken&& token = CompletionToken{}) { - return impl_.async_exec(req, std::move(adapter), std::forward(token)); + auto& adapter_impl = adapter.impl_; + BOOST_ASSERT_MSG(req.get_expected_responses() <= adapter_impl.supported_response_size, "Request and response have incompatible sizes."); + + auto info = std::make_shared(req, std::move(adapter_impl.adapt_fn), get_executor()); + + return asio::async_compose + < CompletionToken + , void(system::error_code, std::size_t) + >(detail::exec_op{this, info}, token, writer_timer_); } /** @brief Cancel operations. @@ -243,39 +716,505 @@ class basic_connection { * @param op: The operation to be cancelled. */ void cancel(operation op = operation::all) - { impl_.cancel(op); } + { + switch (op) { + case operation::resolve: + resv_.cancel(); + break; + case operation::exec: + cancel_unwritten_requests(); + break; + case operation::reconnection: + cfg_.reconnect_wait_interval = std::chrono::seconds::zero(); + break; + case operation::run: + cancel_run(); + break; + case operation::receive: + receive_channel_.cancel(); + break; + case operation::health_check: + health_checker_.cancel(); + break; + case operation::all: + resv_.cancel(); + cfg_.reconnect_wait_interval = std::chrono::seconds::zero(); + health_checker_.cancel(); + cancel_run(); // run + receive_channel_.cancel(); // receive + cancel_unwritten_requests(); // exec + break; + default: /* ignore */; + } + } + + auto run_is_canceled() const noexcept + { return cancel_run_called_; } /// Returns true if the connection was canceled. bool will_reconnect() const noexcept - { return impl_.will_reconnect();} + { return cfg_.reconnect_wait_interval != std::chrono::seconds::zero();} /// Returns the ssl context. auto const& get_ssl_context() const noexcept - { return impl_.get_ssl_context();} + { return ctx_;} /// Resets the underlying stream. void reset_stream() - { impl_.reset_stream(); } + { stream_ = std::make_unique(writer_timer_.get_executor(), ctx_); } /// Returns a reference to the next layer. auto& next_layer() noexcept - { return impl_.next_layer(); } + { return *stream_; } /// Returns a const reference to the next layer. auto const& next_layer() const noexcept - { return impl_.next_layer(); } - + { return *stream_; } /// Sets the response object of `async_receive` operations. template void set_receive_response(Response& response) - { impl_.set_receive_response(response); } + { + using namespace boost::redis::adapter; + auto g = boost_redis_adapt(response); + receive_adapter_ = adapter::detail::make_adapter_wrapper(g); + } /// Returns connection usage information. usage get_usage() const noexcept - { return impl_.get_usage(); } + { return usage_; } private: - detail::connection_base impl_; + using clock_type = std::chrono::steady_clock; + using clock_traits_type = asio::wait_traits; + using timer_type = asio::basic_waitable_timer; + + using receive_channel_type = asio::experimental::channel; + using resolver_type = detail::resolver; + using health_checker_type = detail::health_checker; + using resp3_handshaker_type = detail::resp3_handshaker; + using adapter_type = std::function const&, system::error_code&)>; + using receiver_adapter_type = std::function const&, system::error_code&)>; + using exec_notifier_type = receive_channel_type; + + auto use_ssl() const noexcept + { return cfg_.use_ssl;} + + auto cancel_on_conn_lost() -> std::size_t + { + // Must return false if the request should be removed. + auto cond = [](auto const& ptr) + { + BOOST_ASSERT(ptr != nullptr); + + if (ptr->is_waiting()) { + return !ptr->req_->get_config().cancel_on_connection_lost; + } else { + return !ptr->req_->get_config().cancel_if_unresponded; + } + }; + + auto point = std::stable_partition(std::begin(reqs_), std::end(reqs_), cond); + + auto const ret = std::distance(point, std::end(reqs_)); + + std::for_each(point, std::end(reqs_), [](auto const& ptr) { + ptr->stop(); + }); + + reqs_.erase(point, std::end(reqs_)); + + std::for_each(std::begin(reqs_), std::end(reqs_), [](auto const& ptr) { + return ptr->mark_waiting(); + }); + + return ret; + } + + auto cancel_unwritten_requests() -> std::size_t + { + auto f = [](auto const& ptr) + { + BOOST_ASSERT(ptr != nullptr); + return !ptr->is_waiting(); + }; + + auto point = std::stable_partition(std::begin(reqs_), std::end(reqs_), f); + + auto const ret = std::distance(point, std::end(reqs_)); + + std::for_each(point, std::end(reqs_), [](auto const& ptr) { + ptr->stop(); + }); + + reqs_.erase(point, std::end(reqs_)); + return ret; + } + + void cancel_run() + { + // Protects the code below from being called more than + // once, see https://github.com/boostorg/redis/issues/181 + if (std::exchange(cancel_run_called_, true)) { + return; + } + + close(); + writer_timer_.cancel(); + receive_channel_.cancel(); + cancel_on_conn_lost(); + } + + void on_write() + { + // We have to clear the payload right after writing it to use it + // as a flag that informs there is no ongoing write. + write_buffer_.clear(); + + // Notice this must come before the for-each below. + cancel_push_requests(); + + // There is small optimization possible here: traverse only the + // partition of unwritten requests instead of them all. + std::for_each(std::begin(reqs_), std::end(reqs_), [](auto const& ptr) { + BOOST_ASSERT_MSG(ptr != nullptr, "Expects non-null pointer."); + if (ptr->is_staged()) { + ptr->mark_written(); + } + }); + } + + struct req_info { + public: + using node_type = resp3::basic_node; + using wrapped_adapter_type = std::function; + + explicit req_info(request const& req, adapter_type adapter, executor_type ex) + : notifier_{ex, 1} + , req_{&req} + , adapter_{} + , expected_responses_{req.get_expected_responses()} + , status_{status::waiting} + , ec_{{}} + , read_size_{0} + { + adapter_ = [this, adapter](node_type const& nd, system::error_code& ec) + { + auto const i = req_->get_expected_responses() - expected_responses_; + adapter(i, nd, ec); + }; + } + + auto proceed() + { + notifier_.try_send(std::error_code{}, 0); + } + + void stop() + { + notifier_.close(); + } + + [[nodiscard]] auto is_waiting() const noexcept + { return status_ == status::waiting; } + + [[nodiscard]] auto is_written() const noexcept + { return status_ == status::written; } + + [[nodiscard]] auto is_staged() const noexcept + { return status_ == status::staged; } + + void mark_written() noexcept + { status_ = status::written; } + + void mark_staged() noexcept + { status_ = status::staged; } + + void mark_waiting() noexcept + { status_ = status::waiting; } + + [[nodiscard]] auto stop_requested() const noexcept + { return !notifier_.is_open();} + + template + auto async_wait(CompletionToken token) + { + return notifier_.async_receive(std::move(token)); + } + + //private: + enum class status + { waiting + , staged + , written + }; + + exec_notifier_type notifier_; + request const* req_; + wrapped_adapter_type adapter_; + + // Contains the number of commands that haven't been read yet. + std::size_t expected_responses_; + status status_; + + system::error_code ec_; + std::size_t read_size_; + }; + + void remove_request(std::shared_ptr const& info) + { + reqs_.erase(std::remove(std::begin(reqs_), std::end(reqs_), info)); + } + + using reqs_type = std::deque>; + + template friend struct detail::reader_op; + template friend struct detail::writer_op; + template friend struct detail::exec_op; + template friend class detail::run_op; + + void cancel_push_requests() + { + auto point = std::stable_partition(std::begin(reqs_), std::end(reqs_), [](auto const& ptr) { + return !(ptr->is_staged() && ptr->req_->get_expected_responses() == 0); + }); + + std::for_each(point, std::end(reqs_), [](auto const& ptr) { + ptr->proceed(); + }); + + reqs_.erase(point, std::end(reqs_)); + } + + [[nodiscard]] bool is_writing() const noexcept + { + return !write_buffer_.empty(); + } + + void add_request_info(std::shared_ptr const& info) + { + reqs_.push_back(info); + + if (info->req_->has_hello_priority()) { + auto rend = std::partition_point(std::rbegin(reqs_), std::rend(reqs_), [](auto const& e) { + return e->is_waiting(); + }); + + std::rotate(std::rbegin(reqs_), std::rbegin(reqs_) + 1, rend); + } + + if (is_open() && !is_writing()) + writer_timer_.cancel(); + } + + template + auto reader(Logger l, CompletionToken&& token) + { + return asio::async_compose + < CompletionToken + , void(system::error_code) + >(detail::reader_op{this, l}, token, writer_timer_); + } + + template + auto writer(Logger l, CompletionToken&& token) + { + return asio::async_compose + < CompletionToken + , void(system::error_code) + >(detail::writer_op{this, l}, token, writer_timer_); + } + + [[nodiscard]] bool coalesce_requests() + { + // Coalesces the requests and marks them staged. After a + // successful write staged requests will be marked as written. + auto const point = std::partition_point(std::cbegin(reqs_), std::cend(reqs_), [](auto const& ri) { + return !ri->is_waiting(); + }); + + std::for_each(point, std::cend(reqs_), [this](auto const& ri) { + // Stage the request. + write_buffer_ += ri->req_->payload(); + ri->mark_staged(); + usage_.commands_sent += ri->expected_responses_; + }); + + usage_.bytes_sent += std::size(write_buffer_); + + return point != std::cend(reqs_); + } + + bool is_waiting_response() const noexcept + { + if (std::empty(reqs_)) + return false; + + // Under load and on low-latency networks we might start + // receiving responses before the write operation completed and + // the request is still maked as staged and not written. See + // https://github.com/boostorg/redis/issues/170 + return !reqs_.front()->is_waiting(); + } + + void close() + { + if (stream_->next_layer().is_open()) { + system::error_code ec; + stream_->next_layer().close(ec); + } + } + + auto is_open() const noexcept { return stream_->next_layer().is_open(); } + auto& lowest_layer() noexcept { return stream_->lowest_layer(); } + + auto is_next_push() + { + BOOST_ASSERT(!read_buffer_.empty()); + + // Useful links to understand the heuristics below. + // + // - https://github.com/redis/redis/issues/11784 + // - https://github.com/redis/redis/issues/6426 + // - https://github.com/boostorg/redis/issues/170 + + // The message's resp3 type is a push. + if (resp3::to_type(read_buffer_.front()) == resp3::type::push) + return true; + + // This is non-push type and the requests queue is empty. I have + // noticed this is possible, for example with -MISCONF. I don't + // know why they are not sent with a push type so we can + // distinguish them from responses to commands. If we are lucky + // enough to receive them when the command queue is empty they + // can be treated as server pushes, otherwise it is impossible + // to handle them properly + if (reqs_.empty()) + return true; + + // The request does not expect any response but we got one. This + // may happen if for example, subscribe with wrong syntax. + if (reqs_.front()->expected_responses_ == 0) + return true; + + // Added to deal with MONITOR and also to fix PR170 which + // happens under load and on low-latency networks, where we + // might start receiving responses before the write operation + // completed and the request is still maked as staged and not + // written. + return reqs_.front()->is_waiting(); + } + + auto get_suggested_buffer_growth() const noexcept + { + return parser_.get_suggested_buffer_growth(4096); + } + + enum class parse_result { needs_more, push, resp }; + + using parse_ret_type = std::pair; + + parse_ret_type on_finish_parsing(parse_result t) + { + if (t == parse_result::push) { + usage_.pushes_received += 1; + usage_.push_bytes_received += parser_.get_consumed(); + } else { + usage_.responses_received += 1; + usage_.response_bytes_received += parser_.get_consumed(); + } + + on_push_ = false; + dbuf_.consume(parser_.get_consumed()); + auto const res = std::make_pair(t, parser_.get_consumed()); + parser_.reset(); + return res; + } + + parse_ret_type on_read(std::string_view data, system::error_code& ec) + { + // We arrive here in two states: + // + // 1. While we are parsing a message. In this case we + // don't want to determine the type of the message in the + // buffer (i.e. response vs push) but leave it untouched + // until the parsing of a complete message ends. + // + // 2. On a new message, in which case we have to determine + // whether the next messag is a push or a response. + // + if (!on_push_) // Prepare for new message. + on_push_ = is_next_push(); + + if (on_push_) { + if (!resp3::parse(parser_, data, receive_adapter_, ec)) + return std::make_pair(parse_result::needs_more, 0); + + if (ec) + return std::make_pair(parse_result::push, 0); + + return on_finish_parsing(parse_result::push); + } + + BOOST_ASSERT_MSG(is_waiting_response(), "Not waiting for a response (using MONITOR command perhaps?)"); + BOOST_ASSERT(!reqs_.empty()); + BOOST_ASSERT(reqs_.front() != nullptr); + BOOST_ASSERT(reqs_.front()->expected_responses_ != 0); + + if (!resp3::parse(parser_, data, reqs_.front()->adapter_, ec)) + return std::make_pair(parse_result::needs_more, 0); + + if (ec) { + reqs_.front()->ec_ = ec; + reqs_.front()->proceed(); + return std::make_pair(parse_result::resp, 0); + } + + reqs_.front()->read_size_ += parser_.get_consumed(); + + if (--reqs_.front()->expected_responses_ == 0) { + // Done with this request. + reqs_.front()->proceed(); + reqs_.pop_front(); + } + + return on_finish_parsing(parse_result::resp); + } + + void reset() + { + write_buffer_.clear(); + read_buffer_.clear(); + parser_.reset(); + on_push_ = false; + cancel_run_called_ = false; + } + + asio::ssl::context ctx_; + std::unique_ptr stream_; + + // Notice we use a timer to simulate a condition-variable. It is + // also more suitable than a channel and the notify operation does + // not suspend. + timer_type writer_timer_; + receive_channel_type receive_channel_; + resolver_type resv_; + detail::connector ctor_; + health_checker_type health_checker_; + resp3_handshaker_type handshaker_; + receiver_adapter_type receive_adapter_; + + using dyn_buffer_type = asio::dynamic_string_buffer, std::allocator>; + + config cfg_; + std::string read_buffer_; + dyn_buffer_type dbuf_; + std::string write_buffer_; + reqs_type reqs_; + resp3::parser parser_{}; + bool on_push_ = false; + bool cancel_run_called_ = false; + + usage usage_; }; /** \brief A basic_connection that type erases the executor. diff --git a/include/boost/redis/detail/connection_base.hpp b/include/boost/redis/detail/connection_base.hpp deleted file mode 100644 index 68bad54d..00000000 --- a/include/boost/redis/detail/connection_base.hpp +++ /dev/null @@ -1,1059 +0,0 @@ -/* Copyright (c) 2018-2024 Marcelo Zimbres Silva (mzimbres@gmail.com) - * - * Distributed under the Boost Software License, Version 1.0. (See - * accompanying file LICENSE.txt) - */ - -#ifndef BOOST_REDIS_CONNECTION_BASE_HPP -#define BOOST_REDIS_CONNECTION_BASE_HPP - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -namespace boost::redis::detail -{ - -template -std::string_view buffer_view(DynamicBuffer buf) noexcept -{ - char const* start = static_cast(buf.data(0, buf.size()).data()); - return std::string_view{start, std::size(buf)}; -} - -template -class append_some_op { -private: - AsyncReadStream& stream_; - DynamicBuffer buf_; - std::size_t size_ = 0; - std::size_t tmp_ = 0; - asio::coroutine coro_{}; - -public: - append_some_op(AsyncReadStream& stream, DynamicBuffer buf, std::size_t size) - : stream_ {stream} - , buf_ {std::move(buf)} - , size_{size} - { } - - template - void operator()( Self& self - , system::error_code ec = {} - , std::size_t n = 0) - { - BOOST_ASIO_CORO_REENTER (coro_) - { - tmp_ = buf_.size(); - buf_.grow(size_); - - BOOST_ASIO_CORO_YIELD - stream_.async_read_some(buf_.data(tmp_, size_), std::move(self)); - if (ec) { - self.complete(ec, 0); - return; - } - - buf_.shrink(buf_.size() - tmp_ - n); - self.complete({}, n); - } - } -}; - -template -auto -async_append_some( - AsyncReadStream& stream, - DynamicBuffer buffer, - std::size_t size, - CompletionToken&& token) -{ - return asio::async_compose - < CompletionToken - , void(system::error_code, std::size_t) - >(append_some_op {stream, buffer, size}, token, stream); -} - -template -struct exec_op { - using req_info_type = typename Conn::req_info; - using adapter_type = typename Conn::adapter_type; - - Conn* conn_ = nullptr; - std::shared_ptr info_ = nullptr; - asio::coroutine coro{}; - - template - void operator()(Self& self , system::error_code = {}, std::size_t = 0) - { - BOOST_ASIO_CORO_REENTER (coro) - { - // Check whether the user wants to wait for the connection to - // be stablished. - if (info_->req_->get_config().cancel_if_not_connected && !conn_->is_open()) { - BOOST_ASIO_CORO_YIELD - asio::dispatch( - asio::get_associated_immediate_executor(self, self.get_io_executor()), - std::move(self)); - return self.complete(error::not_connected, 0); - } - - conn_->add_request_info(info_); - -EXEC_OP_WAIT: - BOOST_ASIO_CORO_YIELD - info_->async_wait(std::move(self)); - - if (info_->ec_) { - self.complete(info_->ec_, 0); - return; - } - - if (info_->stop_requested()) { - // Don't have to call remove_request as it has already - // been by cancel(exec). - return self.complete(asio::error::operation_aborted, 0); - } - - if (is_cancelled(self)) { - if (!info_->is_waiting()) { - using c_t = asio::cancellation_type; - auto const c = self.get_cancellation_state().cancelled(); - if ((c & c_t::terminal) != c_t::none) { - // Cancellation requires closing the connection - // otherwise it stays in inconsistent state. - conn_->cancel(operation::run); - return self.complete(asio::error::operation_aborted, 0); - } else { - // Can't implement other cancelation types, ignoring. - self.get_cancellation_state().clear(); - - // TODO: Find out a better way to ignore - // cancelation. - goto EXEC_OP_WAIT; - } - } else { - // Cancelation can be honored. - conn_->remove_request(info_); - self.complete(asio::error::operation_aborted, 0); - return; - } - } - - self.complete(info_->ec_, info_->read_size_); - } - } -}; - -template -struct writer_op { - Conn* conn_; - Logger logger_; - asio::coroutine coro{}; - - template - void operator()( Self& self - , system::error_code ec = {} - , std::size_t n = 0) - { - ignore_unused(n); - - BOOST_ASIO_CORO_REENTER (coro) for (;;) - { - while (conn_->coalesce_requests()) { - if (conn_->use_ssl()) - BOOST_ASIO_CORO_YIELD asio::async_write(conn_->next_layer(), asio::buffer(conn_->write_buffer_), std::move(self)); - else - BOOST_ASIO_CORO_YIELD asio::async_write(conn_->next_layer().next_layer(), asio::buffer(conn_->write_buffer_), std::move(self)); - - logger_.on_write(ec, conn_->write_buffer_); - - if (ec) { - logger_.trace("writer_op (1)", ec); - conn_->cancel(operation::run); - self.complete(ec); - return; - } - - conn_->on_write(); - - // A socket.close() may have been called while a - // successful write might had already been queued, so we - // have to check here before proceeding. - if (!conn_->is_open()) { - logger_.trace("writer_op (2): connection is closed."); - self.complete({}); - return; - } - } - - BOOST_ASIO_CORO_YIELD - conn_->writer_timer_.async_wait(std::move(self)); - if (!conn_->is_open()) { - logger_.trace("writer_op (3): connection is closed."); - // Notice this is not an error of the op, stoping was - // requested from the outside, so we complete with - // success. - self.complete({}); - return; - } - } - } -}; - -template -struct reader_op { - using parse_result = typename Conn::parse_result; - using parse_ret_type = typename Conn::parse_ret_type; - Conn* conn_; - Logger logger_; - parse_ret_type res_{parse_result::resp, 0}; - asio::coroutine coro{}; - - template - void operator()( Self& self - , system::error_code ec = {} - , std::size_t n = 0) - { - ignore_unused(n); - - BOOST_ASIO_CORO_REENTER (coro) for (;;) - { - // Appends some data to the buffer if necessary. - if ((res_.first == parse_result::needs_more) || std::empty(conn_->read_buffer_)) { - if (conn_->use_ssl()) { - BOOST_ASIO_CORO_YIELD - async_append_some( - conn_->next_layer(), - conn_->dbuf_, - conn_->get_suggested_buffer_growth(), - std::move(self)); - } else { - BOOST_ASIO_CORO_YIELD - async_append_some( - conn_->next_layer().next_layer(), - conn_->dbuf_, - conn_->get_suggested_buffer_growth(), - std::move(self)); - } - - logger_.on_read(ec, n); - - // The connection is not viable after an error. - if (ec) { - logger_.trace("reader_op (1)", ec); - conn_->cancel(operation::run); - self.complete(ec); - return; - } - - // Somebody might have canceled implicitly or explicitly - // while we were suspended and after queueing so we have to - // check. - if (!conn_->is_open()) { - logger_.trace("reader_op (2): connection is closed."); - self.complete(ec); - return; - } - } - - res_ = conn_->on_read(buffer_view(conn_->dbuf_), ec); - if (ec) { - logger_.trace("reader_op (3)", ec); - conn_->cancel(operation::run); - self.complete(ec); - return; - } - - if (res_.first == parse_result::push) { - if (!conn_->receive_channel_.try_send(ec, res_.second)) { - BOOST_ASIO_CORO_YIELD - conn_->receive_channel_.async_send(ec, res_.second, std::move(self)); - } - - if (ec) { - logger_.trace("reader_op (4)", ec); - conn_->cancel(operation::run); - self.complete(ec); - return; - } - - if (!conn_->is_open()) { - logger_.trace("reader_op (5): connection is closed."); - self.complete(asio::error::operation_aborted); - return; - } - - } - } - } -}; - -template -class run_op { -private: - Conn* conn_ = nullptr; - Logger logger_; - asio::coroutine coro_{}; - - using order_t = std::array; - -public: - run_op(Conn* conn, Logger l) - : conn_{conn} - , logger_{l} - {} - - template - void operator()( Self& self - , order_t order = {} - , system::error_code ec0 = {} - , system::error_code ec1 = {} - , system::error_code ec2 = {} - , system::error_code ec3 = {} - , system::error_code ec4 = {}) - { - BOOST_ASIO_CORO_REENTER (coro_) for (;;) - { - BOOST_ASIO_CORO_YIELD - conn_->resv_.async_resolve(asio::prepend(std::move(self), order_t {})); - - logger_.on_resolve(ec0, conn_->resv_.results()); - - if (ec0) { - self.complete(ec0); - return; - } - - BOOST_ASIO_CORO_YIELD - conn_->ctor_.async_connect( - conn_->next_layer().next_layer(), - conn_->resv_.results(), - asio::prepend(std::move(self), order_t {})); - - logger_.on_connect(ec0, conn_->ctor_.endpoint()); - - if (ec0) { - self.complete(ec0); - return; - } - - if (conn_->use_ssl()) { - BOOST_ASIO_CORO_YIELD - conn_->next_layer().async_handshake( - asio::ssl::stream_base::client, - asio::prepend( - asio::cancel_after( - conn_->cfg_.ssl_handshake_timeout, - std::move(self) - ), - order_t {} - ) - ); - - logger_.on_ssl_handshake(ec0); - - if (ec0) { - self.complete(ec0); - return; - } - } - - conn_->reset(); - - // Note: Oder is important here because the writer might - // trigger an async_write before the async_hello thereby - // causing an authentication problem. - BOOST_ASIO_CORO_YIELD - asio::experimental::make_parallel_group( - [this](auto token) { return conn_->handshaker_.async_hello(*conn_, logger_, token); }, - [this](auto token) { return conn_->health_checker_.async_ping(*conn_, logger_, token); }, - [this](auto token) { return conn_->health_checker_.async_check_timeout(*conn_, logger_, token);}, - [this](auto token) { return conn_->reader(logger_, token);}, - [this](auto token) { return conn_->writer(logger_, token);} - ).async_wait( - asio::experimental::wait_for_one_error(), - std::move(self)); - - if (order[0] == 0 && !!ec0) { - self.complete(ec0); - return; - } - - if (order[0] == 2 && ec2 == error::pong_timeout) { - self.complete(ec1); - return; - } - - // The receive operation must be cancelled because channel - // subscription does not survive a reconnection but requires - // re-subscription. - conn_->cancel(operation::receive); - - if (!conn_->will_reconnect()) { - conn_->cancel(operation::reconnection); - self.complete(ec3); - return; - } - - // It is safe to use the writer timer here because we are not - // connected. - conn_->writer_timer_.expires_after(conn_->cfg_.reconnect_wait_interval); - - BOOST_ASIO_CORO_YIELD - conn_->writer_timer_.async_wait(asio::prepend(std::move(self), order_t {})); - if (ec0) { - self.complete(ec0); - return; - } - - if (!conn_->will_reconnect()) { - self.complete(asio::error::operation_aborted); - return; - } - - conn_->reset_stream(); - } - } -}; - - -/** @brief Base class for high level Redis asynchronous connections. - * @ingroup high-level-api - * - * @tparam Executor The executor type. - * - */ -template -class connection_base { -public: - /// Executor type - using executor_type = Executor; - - /// Type of the next layer - using next_layer_type = asio::ssl::stream>; - - using clock_type = std::chrono::steady_clock; - using clock_traits_type = asio::wait_traits; - using timer_type = asio::basic_waitable_timer; - - using this_type = connection_base; - - /// Constructs from an executor. - connection_base( - executor_type ex, - asio::ssl::context ctx, - std::size_t max_read_size) - : ctx_{std::move(ctx)} - , stream_{std::make_unique(ex, ctx_)} - , writer_timer_{ex} - , receive_channel_{ex, 256} - , resv_{ex} - , health_checker_{ex} - , dbuf_{read_buffer_, max_read_size} - { - set_receive_response(ignore); - writer_timer_.expires_at((std::chrono::steady_clock::time_point::max)()); - } - - /// Returns the ssl context. - auto const& get_ssl_context() const noexcept - { return ctx_;} - - /// Resets the underlying stream. - void reset_stream() - { - stream_ = std::make_unique(writer_timer_.get_executor(), ctx_); - } - - /// Returns a reference to the next layer. - auto& next_layer() noexcept { return *stream_; } - - /// Returns a const reference to the next layer. - auto const& next_layer() const noexcept { return *stream_; } - - /// Returns the associated executor. - auto get_executor() {return writer_timer_.get_executor();} - - /// Cancels specific operations. - void cancel(operation op) - { - switch (op) { - case operation::resolve: - resv_.cancel(); - break; - case operation::exec: - cancel_unwritten_requests(); - break; - case operation::reconnection: - cfg_.reconnect_wait_interval = std::chrono::seconds::zero(); - break; - case operation::run: - cancel_run(); - break; - case operation::receive: - receive_channel_.cancel(); - break; - case operation::health_check: - health_checker_.cancel(); - break; - case operation::all: - resv_.cancel(); - cfg_.reconnect_wait_interval = std::chrono::seconds::zero(); - health_checker_.cancel(); - cancel_run(); // run - receive_channel_.cancel(); // receive - cancel_unwritten_requests(); // exec - break; - default: /* ignore */; - } - } - - template - auto async_exec(request const& req, any_adapter&& adapter, CompletionToken&& token) - { - auto& adapter_impl = adapter.impl_; - BOOST_ASSERT_MSG(req.get_expected_responses() <= adapter_impl.supported_response_size, "Request and response have incompatible sizes."); - - auto info = std::make_shared(req, std::move(adapter_impl.adapt_fn), get_executor()); - - return asio::async_compose - < CompletionToken - , void(system::error_code, std::size_t) - >(exec_op{this, info}, token, writer_timer_); - } - - template - [[deprecated("Set the response with set_receive_response and use the other overload.")]] - auto async_receive(Response& response, CompletionToken token) - { - set_receive_response(response); - return receive_channel_.async_receive(std::move(token)); - } - - template - auto async_receive(CompletionToken token) - { return receive_channel_.async_receive(std::move(token)); } - - std::size_t receive(system::error_code& ec) - { - std::size_t size = 0; - - auto f = [&](system::error_code const& ec2, std::size_t n) - { - ec = ec2; - size = n; - }; - - auto const res = receive_channel_.try_receive(f); - if (ec) - return 0; - - if (!res) - ec = error::sync_receive_push_failed; - - return size; - } - - template - auto async_run(config const& cfg, Logger l, CompletionToken token) - { - cfg_ = cfg; - resv_.set_config(cfg); - ctor_.set_config(cfg); - health_checker_.set_config(cfg); - handshaker_.set_config(cfg); - l.set_prefix(cfg.log_prefix); - - return asio::async_compose - < CompletionToken - , void(system::error_code) - >(run_op{this, l}, token, writer_timer_); - } - - template - void set_receive_response(Response& response) - { - using namespace boost::redis::adapter; - auto g = boost_redis_adapt(response); - receive_adapter_ = adapter::detail::make_adapter_wrapper(g); - } - - usage get_usage() const noexcept - { return usage_; } - - auto run_is_canceled() const noexcept - { return cancel_run_called_; } - - bool will_reconnect() const noexcept - { return cfg_.reconnect_wait_interval != std::chrono::seconds::zero();} - -private: - using receive_channel_type = asio::experimental::channel; - using resolver_type = resolver; - using health_checker_type = health_checker; - using resp3_handshaker_type = resp3_handshaker; - using adapter_type = std::function const&, system::error_code&)>; - using receiver_adapter_type = std::function const&, system::error_code&)>; - using exec_notifier_type = receive_channel_type; - - auto use_ssl() const noexcept - { return cfg_.use_ssl;} - - auto cancel_on_conn_lost() -> std::size_t - { - // Must return false if the request should be removed. - auto cond = [](auto const& ptr) - { - BOOST_ASSERT(ptr != nullptr); - - if (ptr->is_waiting()) { - return !ptr->req_->get_config().cancel_on_connection_lost; - } else { - return !ptr->req_->get_config().cancel_if_unresponded; - } - }; - - auto point = std::stable_partition(std::begin(reqs_), std::end(reqs_), cond); - - auto const ret = std::distance(point, std::end(reqs_)); - - std::for_each(point, std::end(reqs_), [](auto const& ptr) { - ptr->stop(); - }); - - reqs_.erase(point, std::end(reqs_)); - - std::for_each(std::begin(reqs_), std::end(reqs_), [](auto const& ptr) { - return ptr->mark_waiting(); - }); - - return ret; - } - - auto cancel_unwritten_requests() -> std::size_t - { - auto f = [](auto const& ptr) - { - BOOST_ASSERT(ptr != nullptr); - return !ptr->is_waiting(); - }; - - auto point = std::stable_partition(std::begin(reqs_), std::end(reqs_), f); - - auto const ret = std::distance(point, std::end(reqs_)); - - std::for_each(point, std::end(reqs_), [](auto const& ptr) { - ptr->stop(); - }); - - reqs_.erase(point, std::end(reqs_)); - return ret; - } - - void cancel_run() - { - // Protects the code below from being called more than - // once, see https://github.com/boostorg/redis/issues/181 - if (std::exchange(cancel_run_called_, true)) { - return; - } - - close(); - writer_timer_.cancel(); - receive_channel_.cancel(); - cancel_on_conn_lost(); - } - - void on_write() - { - // We have to clear the payload right after writing it to use it - // as a flag that informs there is no ongoing write. - write_buffer_.clear(); - - // Notice this must come before the for-each below. - cancel_push_requests(); - - // There is small optimization possible here: traverse only the - // partition of unwritten requests instead of them all. - std::for_each(std::begin(reqs_), std::end(reqs_), [](auto const& ptr) { - BOOST_ASSERT_MSG(ptr != nullptr, "Expects non-null pointer."); - if (ptr->is_staged()) { - ptr->mark_written(); - } - }); - } - - struct req_info { - public: - using node_type = resp3::basic_node; - using wrapped_adapter_type = std::function; - - explicit req_info(request const& req, adapter_type adapter, executor_type ex) - : notifier_{ex, 1} - , req_{&req} - , adapter_{} - , expected_responses_{req.get_expected_responses()} - , status_{status::waiting} - , ec_{{}} - , read_size_{0} - { - adapter_ = [this, adapter](node_type const& nd, system::error_code& ec) - { - auto const i = req_->get_expected_responses() - expected_responses_; - adapter(i, nd, ec); - }; - } - - auto proceed() - { - notifier_.try_send(std::error_code{}, 0); - } - - void stop() - { - notifier_.close(); - } - - [[nodiscard]] auto is_waiting() const noexcept - { return status_ == status::waiting; } - - [[nodiscard]] auto is_written() const noexcept - { return status_ == status::written; } - - [[nodiscard]] auto is_staged() const noexcept - { return status_ == status::staged; } - - void mark_written() noexcept - { status_ = status::written; } - - void mark_staged() noexcept - { status_ = status::staged; } - - void mark_waiting() noexcept - { status_ = status::waiting; } - - [[nodiscard]] auto stop_requested() const noexcept - { return !notifier_.is_open();} - - template - auto async_wait(CompletionToken token) - { - return notifier_.async_receive(std::move(token)); - } - - //private: - enum class status - { waiting - , staged - , written - }; - - exec_notifier_type notifier_; - request const* req_; - wrapped_adapter_type adapter_; - - // Contains the number of commands that haven't been read yet. - std::size_t expected_responses_; - status status_; - - system::error_code ec_; - std::size_t read_size_; - }; - - void remove_request(std::shared_ptr const& info) - { - reqs_.erase(std::remove(std::begin(reqs_), std::end(reqs_), info)); - } - - using reqs_type = std::deque>; - - template friend struct reader_op; - template friend struct writer_op; - template friend struct exec_op; - template friend class run_op; - - void cancel_push_requests() - { - auto point = std::stable_partition(std::begin(reqs_), std::end(reqs_), [](auto const& ptr) { - return !(ptr->is_staged() && ptr->req_->get_expected_responses() == 0); - }); - - std::for_each(point, std::end(reqs_), [](auto const& ptr) { - ptr->proceed(); - }); - - reqs_.erase(point, std::end(reqs_)); - } - - [[nodiscard]] bool is_writing() const noexcept - { - return !write_buffer_.empty(); - } - - void add_request_info(std::shared_ptr const& info) - { - reqs_.push_back(info); - - if (info->req_->has_hello_priority()) { - auto rend = std::partition_point(std::rbegin(reqs_), std::rend(reqs_), [](auto const& e) { - return e->is_waiting(); - }); - - std::rotate(std::rbegin(reqs_), std::rbegin(reqs_) + 1, rend); - } - - if (is_open() && !is_writing()) - writer_timer_.cancel(); - } - - template - auto reader(Logger l, CompletionToken&& token) - { - return asio::async_compose - < CompletionToken - , void(system::error_code) - >(reader_op{this, l}, token, writer_timer_); - } - - template - auto writer(Logger l, CompletionToken&& token) - { - return asio::async_compose - < CompletionToken - , void(system::error_code) - >(writer_op{this, l}, token, writer_timer_); - } - - [[nodiscard]] bool coalesce_requests() - { - // Coalesces the requests and marks them staged. After a - // successful write staged requests will be marked as written. - auto const point = std::partition_point(std::cbegin(reqs_), std::cend(reqs_), [](auto const& ri) { - return !ri->is_waiting(); - }); - - std::for_each(point, std::cend(reqs_), [this](auto const& ri) { - // Stage the request. - write_buffer_ += ri->req_->payload(); - ri->mark_staged(); - usage_.commands_sent += ri->expected_responses_; - }); - - usage_.bytes_sent += std::size(write_buffer_); - - return point != std::cend(reqs_); - } - - bool is_waiting_response() const noexcept - { - if (std::empty(reqs_)) - return false; - - // Under load and on low-latency networks we might start - // receiving responses before the write operation completed and - // the request is still maked as staged and not written. See - // https://github.com/boostorg/redis/issues/170 - return !reqs_.front()->is_waiting(); - } - - void close() - { - if (stream_->next_layer().is_open()) { - system::error_code ec; - stream_->next_layer().close(ec); - } - } - - auto is_open() const noexcept { return stream_->next_layer().is_open(); } - auto& lowest_layer() noexcept { return stream_->lowest_layer(); } - - auto is_next_push() - { - BOOST_ASSERT(!read_buffer_.empty()); - - // Useful links to understand the heuristics below. - // - // - https://github.com/redis/redis/issues/11784 - // - https://github.com/redis/redis/issues/6426 - // - https://github.com/boostorg/redis/issues/170 - - // The message's resp3 type is a push. - if (resp3::to_type(read_buffer_.front()) == resp3::type::push) - return true; - - // This is non-push type and the requests queue is empty. I have - // noticed this is possible, for example with -MISCONF. I don't - // know why they are not sent with a push type so we can - // distinguish them from responses to commands. If we are lucky - // enough to receive them when the command queue is empty they - // can be treated as server pushes, otherwise it is impossible - // to handle them properly - if (reqs_.empty()) - return true; - - // The request does not expect any response but we got one. This - // may happen if for example, subscribe with wrong syntax. - if (reqs_.front()->expected_responses_ == 0) - return true; - - // Added to deal with MONITOR and also to fix PR170 which - // happens under load and on low-latency networks, where we - // might start receiving responses before the write operation - // completed and the request is still maked as staged and not - // written. - return reqs_.front()->is_waiting(); - } - - auto get_suggested_buffer_growth() const noexcept - { - return parser_.get_suggested_buffer_growth(4096); - } - - enum class parse_result { needs_more, push, resp }; - - using parse_ret_type = std::pair; - - parse_ret_type on_finish_parsing(parse_result t) - { - if (t == parse_result::push) { - usage_.pushes_received += 1; - usage_.push_bytes_received += parser_.get_consumed(); - } else { - usage_.responses_received += 1; - usage_.response_bytes_received += parser_.get_consumed(); - } - - on_push_ = false; - dbuf_.consume(parser_.get_consumed()); - auto const res = std::make_pair(t, parser_.get_consumed()); - parser_.reset(); - return res; - } - - parse_ret_type on_read(std::string_view data, system::error_code& ec) - { - // We arrive here in two states: - // - // 1. While we are parsing a message. In this case we - // don't want to determine the type of the message in the - // buffer (i.e. response vs push) but leave it untouched - // until the parsing of a complete message ends. - // - // 2. On a new message, in which case we have to determine - // whether the next messag is a push or a response. - // - if (!on_push_) // Prepare for new message. - on_push_ = is_next_push(); - - if (on_push_) { - if (!resp3::parse(parser_, data, receive_adapter_, ec)) - return std::make_pair(parse_result::needs_more, 0); - - if (ec) - return std::make_pair(parse_result::push, 0); - - return on_finish_parsing(parse_result::push); - } - - BOOST_ASSERT_MSG(is_waiting_response(), "Not waiting for a response (using MONITOR command perhaps?)"); - BOOST_ASSERT(!reqs_.empty()); - BOOST_ASSERT(reqs_.front() != nullptr); - BOOST_ASSERT(reqs_.front()->expected_responses_ != 0); - - if (!resp3::parse(parser_, data, reqs_.front()->adapter_, ec)) - return std::make_pair(parse_result::needs_more, 0); - - if (ec) { - reqs_.front()->ec_ = ec; - reqs_.front()->proceed(); - return std::make_pair(parse_result::resp, 0); - } - - reqs_.front()->read_size_ += parser_.get_consumed(); - - if (--reqs_.front()->expected_responses_ == 0) { - // Done with this request. - reqs_.front()->proceed(); - reqs_.pop_front(); - } - - return on_finish_parsing(parse_result::resp); - } - - void reset() - { - write_buffer_.clear(); - read_buffer_.clear(); - parser_.reset(); - on_push_ = false; - cancel_run_called_ = false; - } - - asio::ssl::context ctx_; - std::unique_ptr stream_; - - // Notice we use a timer to simulate a condition-variable. It is - // also more suitable than a channel and the notify operation does - // not suspend. - timer_type writer_timer_; - receive_channel_type receive_channel_; - resolver_type resv_; - connector ctor_; - health_checker_type health_checker_; - resp3_handshaker_type handshaker_; - receiver_adapter_type receive_adapter_; - - using dyn_buffer_type = asio::dynamic_string_buffer, std::allocator>; - - config cfg_; - std::string read_buffer_; - dyn_buffer_type dbuf_; - std::string write_buffer_; - reqs_type reqs_; - resp3::parser parser_{}; - bool on_push_ = false; - bool cancel_run_called_ = false; - - usage usage_; -}; - -} // boost::redis::detail - -#endif // BOOST_REDIS_CONNECTION_BASE_HPP diff --git a/test/test_issue_181.cpp b/test/test_issue_181.cpp index 123f902d..73871c88 100644 --- a/test/test_issue_181.cpp +++ b/test/test_issue_181.cpp @@ -31,12 +31,12 @@ using namespace std::chrono_literals; BOOST_AUTO_TEST_CASE(issue_181) { - using connection_base = boost::redis::detail::connection_base; + using basic_connection = boost::redis::basic_connection; auto const level = boost::redis::logger::level::debug; net::io_context ioc; auto ctx = net::ssl::context{net::ssl::context::tlsv12_client}; - connection_base conn{ioc.get_executor(), std::move(ctx), 1000000}; + basic_connection conn{ioc.get_executor(), std::move(ctx), 1000000}; net::steady_timer timer{ioc}; timer.expires_after(std::chrono::seconds{1});