diff --git a/.github/workflows/cmake.yml b/.github/workflows/cmake.yml
index c5f7db0e..2a1dacd0 100644
--- a/.github/workflows/cmake.yml
+++ b/.github/workflows/cmake.yml
@@ -9,6 +9,7 @@ on:
   pull_request:
     branches:
       - master
+      - streaming
 
 env:
   # Customize the CMake build type here (Release, Debug, RelWithDebInfo, etc.)
diff --git a/include/libnuraft/asio_service_options.hxx b/include/libnuraft/asio_service_options.hxx
index b1ef5533..c78ec4f9 100644
--- a/include/libnuraft/asio_service_options.hxx
+++ b/include/libnuraft/asio_service_options.hxx
@@ -125,6 +125,7 @@ struct asio_service_options {
         , crc_on_entire_message_(false)
         , crc_on_payload_(false)
         , corrupted_msg_handler_(nullptr)
+        , streaming_mode_(false)
         {}
 
     /**
@@ -276,6 +277,13 @@ struct asio_service_options {
      */
     std::function< void( std::shared_ptr<buffer>,
                          std::shared_ptr<buffer> ) > corrupted_msg_handler_;
+
+    /**
+     * If `true`, NuRaft will use streaming mode, which allows it to send
+     * subsequent requests without waiting for the responses to previous requests.
+     * The order of responses will be identical to the order of requests.
+     */
+    bool streaming_mode_;
 };
 
 }
diff --git a/scripts/test/runtests.sh b/scripts/test/runtests.sh
index ddf60d2c..b730b61e 100755
--- a/scripts/test/runtests.sh
+++ b/scripts/test/runtests.sh
@@ -9,3 +9,4 @@ set -e
 ./tests/raft_server_test --abort-on-failure
 ./tests/failure_test --abort-on-failure
 ./tests/asio_service_test --abort-on-failure
+./tests/asio_service_stream_test --abort-on-failure
diff --git a/src/asio_service.cxx b/src/asio_service.cxx
index 4decabb8..a2674906 100644
--- a/src/asio_service.cxx
+++ b/src/asio_service.cxx
@@ -150,6 +150,19 @@ asio_service::meta_cb_params req_to_params(req_msg* req, resp_msg* resp) {
 // === ASIO Abstraction ===
 //     (to switch SSL <-> unsecure on-the-fly)
 
+struct pending_req_pkg {
+    pending_req_pkg(ptr<req_msg>& req,
+                    rpc_handler& when_done,
+                    uint64_t timeout_ms = 0)
+        : req_(req)
+        , when_done_(when_done)
+        , timeout_ms_(timeout_ms)
+        {}
+    ptr<req_msg> req_;
+    rpc_handler when_done_;
+    uint64_t timeout_ms_;
+};
+
 class aa {
 public:
     template<typename BB, typename FF>
@@ -970,6 +983,8 @@ class asio_rpc_listener
                        const ERROR_CODE& err)
     {
         if (!err) {
+            asio::ip::tcp::no_delay option(true);
+            session->socket().set_option(option);
             p_in("receive a incoming rpc connection");
             session->prepare_handshake();
 
@@ -1115,32 +1130,68 @@ class asio_rpc_client
                   ( ssl_enabled_ ? "enabled" : "disabled" ) );
         }
         abandoned_ = true;
-        ptr<resp_msg> rsp;
-        ptr<rpc_exception> except
-            ( cs_new<rpc_exception>
-              ( lstrfmt("timeout while connecting to %s")
-                       .fmt(host_.c_str()),
-                req ) );
-        when_done(rsp, except);
+        std::string err_msg =
+            lstrfmt("timeout while connecting to %s").fmt(host_.c_str());
+        handle_error(req, err_msg, when_done);
         return;
     }
-        send(req, when_done, send_timeout_ms);
+        register_req_send(req, when_done, send_timeout_ms);
     }
 
     virtual void send(ptr<req_msg>& req,
                       rpc_handler& when_done,
                       uint64_t send_timeout_ms = 0) __override__
     {
+        if (impl_->get_options().streaming_mode_) {
+            pre_send(req, when_done, send_timeout_ms);
+        } else {
+            register_req_send(req, when_done, send_timeout_ms);
+        }
+    }
+
+    void handle_error(ptr<req_msg> req,
+                      std::string& err_msg,
+                      rpc_handler when_done) {
+        close_socket(err_msg);
+
+        // In streaming mode, all `when_done` will be invoked in `close_socket()`.
+        // Otherwise, `close_socket()` will do nothing, hence `when_done`
+        // should directly be invoked here.
+        if (!impl_->get_options().streaming_mode_) {
+            ptr<resp_msg> resp;
+            ptr<rpc_exception> except(cs_new<rpc_exception>(err_msg, req));
+            when_done(resp, except);
+        }
+    }
+
+    void pre_send(ptr<req_msg>& req,
+                  rpc_handler& when_done,
+                  uint64_t send_timeout_ms) {
+        bool immediate_action_needed = false;
+        {
+            auto_lock(pending_write_reqs_lock_);
+            pending_write_reqs_.push_back(
+                cs_new<pending_req_pkg>(req, when_done, send_timeout_ms));
+            immediate_action_needed = (pending_write_reqs_.size() == 1);
+            p_db("start to send msg to peer %d, start_log_idx: %ld, size: %ld, "
+                 "pending write reqs: %ld", req->get_dst(), req->get_last_log_idx(),
+                 req->log_entries().size(), pending_write_reqs_.size());
+        }
+
+        if (immediate_action_needed) {
+            register_req_send(req, when_done, send_timeout_ms);
+        }
+    }
+
+    void register_req_send(ptr<req_msg>& req,
+                           rpc_handler& when_done,
+                           uint64_t send_timeout_ms) {
         if (abandoned_) {
             p_er( "client %p to %s:%s is already stale (SSL %s)",
                   this, host_.c_str(), port_.c_str(),
                   ( ssl_enabled_ ? "enabled" : "disabled" ) );
-            ptr<resp_msg> rsp;
-            ptr<rpc_exception> except
-                ( cs_new<rpc_exception>
-                  ( lstrfmt("abandoned client to %s").fmt(host_.c_str()),
-                    req ) );
-            when_done(rsp, except);
+            std::string err_msg = lstrfmt("abandoned client to %s").fmt(host_.c_str());
+            handle_error(req, err_msg, when_done);
             return;
         }
@@ -1200,17 +1251,14 @@ class asio_rpc_client
                         execute_resolver(self, req, resolved_host,
                                          resolved_port, when_done,
                                          send_timeout_ms);
                     } else {
-                        ptr<resp_msg> rsp;
-                        ptr<rpc_exception> except
-                            ( cs_new<rpc_exception>
-                              ( lstrfmt("failed to resolve host %s by given "
-                                        "custom resolver "
-                                        "due to error %d, %s")
-                                       .fmt( host_.c_str(),
-                                             err.value(),
-                                             err.message().c_str() ),
-                                req ) );
-                        when_done(rsp, except);
+                        std::string err_msg = lstrfmt("failed to resolve "
+                                                      "host %s by given "
+                                                      "custom resolver "
+                                                      "due to error %d, %s")
+                                                     .fmt( host_.c_str(),
+                                                           err.value(),
+                                                           err.message().c_str() );
+                        handle_error(req, err_msg, when_done);
                     }
                 } );
         } else {
@@ -1433,21 +1481,21 @@ class asio_rpc_client
                                  std::placeholders::_1 ) );
             }
         } else {
-            ptr<resp_msg> rsp;
-            ptr<rpc_exception> except
-                ( cs_new<rpc_exception>
-                  ( lstrfmt("failed to resolve host %s "
-                            "due to error %d, %s")
-                           .fmt( host.c_str(),
-                                 err.value(),
-                                 err.message().c_str() ),
-                    req ) );
-            when_done(rsp, except);
+            std::string err_msg = lstrfmt("failed to resolve host %s "
+                                          "due to error %d, %s")
+                                         .fmt( host.c_str(),
+                                               err.value(),
+                                               err.message().c_str() );
+            handle_error(req, err_msg, when_done);
         }
       } );
     }
 
     void set_busy_flag(bool to) {
+        if (impl_->get_options().streaming_mode_) {
+            return;
+        }
+
         if (to == true) {
             bool exp = false;
             if (!socket_busy_.compare_exchange_strong(exp, true)) {
@@ -1465,7 +1513,7 @@ class asio_rpc_client
         }
     }
 
-    void close_socket() {
+    void close_socket(std::string err_msg = std::string()) {
         // Do nothing,
         //   early closing socket before destroying this instance
         //   may cause problem, especially when SSL is enabled.
@@ -1481,6 +1529,37 @@ class asio_rpc_client
             }
         }
 #endif
+        if (!impl_->get_options().streaming_mode_) {
+            return;
+        }
+
+        // In streaming mode, it should invoke all `when_done` callbacks,
+        // in chronological order.
+
+        // Clear write queue and read queue here,
+        //   from oldest to latest, read queue first.
+        cancel_pending_requests(pending_read_reqs_, pending_read_reqs_lock_, err_msg);
+        cancel_pending_requests(pending_write_reqs_, pending_write_reqs_lock_, err_msg);
+    }
+
+    void cancel_pending_requests(std::list< ptr<pending_req_pkg> >& reqs_list,
+                                 std::mutex& lock,
+                                 std::string& err_msg) {
+        std::list< ptr<pending_req_pkg> > reqs_to_cancel;
+        {
+            auto_lock(lock);
+            reqs_to_cancel = std::move(reqs_list);
+        }
+
+        for (auto& pkg: reqs_to_cancel) {
+            ptr<resp_msg> rsp;
+            if (err_msg.empty()) {
+                err_msg = lstrfmt("socket to host %s is closed").fmt( host_.c_str() );
+            }
+            ptr<rpc_exception> except( cs_new<rpc_exception>( err_msg, pkg->req_ ) );
+            pkg->when_done_(rsp, except);
+        }
+        reqs_to_cancel.clear();
     }
 
     void cancel_socket(const ERROR_CODE& err) {
@@ -1504,6 +1583,8 @@ class asio_rpc_client
     {
         operation_timer_.cancel();
         if (!err) {
+            asio::ip::tcp::no_delay option(true);
+            socket().set_option(option);
             p_in( "%p connected to %s:%s (as a client)",
                   this, host_.c_str(), port_.c_str() );
             if (ssl_enabled_) {
@@ -1520,19 +1601,17 @@ class asio_rpc_client
                                  std::placeholders::_1 ) );
 #endif
             } else {
-                this->send(req, when_done, send_timeout_ms);
+                this->register_req_send(req, when_done, send_timeout_ms);
             }
 
         } else {
             abandoned_ = true;
-            ptr<resp_msg> rsp;
-            ptr<rpc_exception> except
-                ( cs_new<rpc_exception>
-                  ( sstrfmt("failed to connect to peer %d, %s:%s, error %d, %s")
-                           .fmt( req->get_dst(), host_.c_str(),
-                                 port_.c_str(), err.value(), err.message().c_str() ),
-                    req ) );
-            when_done(rsp, except);
+            std::string err_msg = sstrfmt("failed to connect to "
+                                          "peer %d, %s:%s, error %d, %s")
+                                         .fmt( req->get_dst(), host_.c_str(),
+                                               port_.c_str(), err.value(),
+                                               err.message().c_str() );
+            handle_error(req, err_msg, when_done);
         }
     }
@@ -1547,7 +1626,7 @@ class asio_rpc_client
             p_in( "handshake with %s:%s succeeded (as a client)",
                   host_.c_str(), port_.c_str() );
             ssl_ready_ = true;
-            this->send(req, when_done, send_timeout_ms);
+            this->register_req_send(req, when_done, send_timeout_ms);
 
         } else {
             abandoned_ = true;
@@ -1556,15 +1635,12 @@ class asio_rpc_client
                   err.message().c_str() );
 
             // Immediately stop.
-            ptr<resp_msg> resp;
-            ptr<rpc_exception> except
-                ( cs_new<rpc_exception>
-                  ( sstrfmt("failed SSL handshake with peer %d, %s:%s, "
-                            "error %d, %s")
-                           .fmt( req->get_dst(), host_.c_str(),
-                                 port_.c_str(), err.value(), err.message().c_str() ),
-                    req ) );
-            when_done(resp, except);
+            std::string err_msg = sstrfmt("failed SSL handshake with peer %d, %s:%s, "
+                                          "error %d, %s")
+                                         .fmt( req->get_dst(), host_.c_str(),
+                                               port_.c_str(), err.value(),
+                                               err.message().c_str() );
+            handle_error(req, err_msg, when_done);
         }
     }
@@ -1576,33 +1652,21 @@ class asio_rpc_client
     {
         // Now we can safely free the `req_buf`.
         (void)buf;
-        ptr<asio_rpc_client> self(this->shared_from_this());
         if (!err) {
-            // read a response
-            ptr<buffer> resp_buf(buffer::alloc(RPC_RESP_HEADER_SIZE));
-            aa::read( ssl_enabled_, ssl_socket_, socket_,
-                      asio::buffer(resp_buf->data(), resp_buf->size()),
-                      std::bind(&asio_rpc_client::response_read,
-                                self,
-                                req,
-                                when_done,
-                                resp_buf,
-                                std::placeholders::_1,
-                                std::placeholders::_2));
-
+            if (impl_->get_options().streaming_mode_) {
+                post_send(req, when_done);
+            } else {
+                register_response_read(req, when_done);
+            }
         } else {
             operation_timer_.cancel();
             abandoned_ = true;
-            ptr<resp_msg> rsp;
-            ptr<rpc_exception> except
-                ( cs_new<rpc_exception>
-                  ( sstrfmt( "failed to send request to peer %d, %s:%s, "
-                             "error %d, %s" )
-                           .fmt( req->get_dst(), host_.c_str(),
-                                 port_.c_str(), err.value(), err.message().c_str() ),
-                    req ) );
-            close_socket();
-            when_done(rsp, except);
+            std::string err_msg = sstrfmt( "failed to send request to peer %d, %s:%s, "
+                                           "error %d, %s" )
+                                         .fmt( req->get_dst(), host_.c_str(),
+                                               port_.c_str(), err.value(),
+                                               err.message().c_str() );
+            handle_error(req, err_msg, when_done);
         }
     }
@@ -1615,16 +1679,12 @@ class asio_rpc_client
         ptr<asio_rpc_client> self(this->shared_from_this());
         if (err) {
             abandoned_ = true;
-            ptr<resp_msg> rsp;
-            ptr<rpc_exception> except
-                ( cs_new<rpc_exception>
-                  ( sstrfmt( "failed to read response to peer %d, %s:%s, "
-                             "error %d, %s" )
-                           .fmt( req->get_dst(), host_.c_str(),
-                                 port_.c_str(), err.value(), err.message().c_str() ),
-                    req ) );
-            close_socket();
-            when_done(rsp, except);
+            std::string err_msg = sstrfmt( "failed to read response to peer %d, %s:%s, "
+                                           "error %d, %s" )
+                                         .fmt( req->get_dst(), host_.c_str(),
+                                               port_.c_str(), err.value(),
+                                               err.message().c_str() );
+            handle_error(req, err_msg, when_done);
             return;
         }
@@ -1638,16 +1698,12 @@ class asio_rpc_client
         uint32_t flags = (flags_and_crc >> 32);
 
         if (crc_local != crc_buf) {
-            ptr<resp_msg> rsp;
-            ptr<rpc_exception> except
-                ( cs_new<rpc_exception>
-                  ( sstrfmt( "CRC mismatch in response from peer %d, %s:%s, "
-                             "local calculation %x, from buffer %x")
-                           .fmt( req->get_dst(), host_.c_str(),
-                                 port_.c_str(), crc_local, crc_buf ),
-                    req ) );
-            close_socket();
-            when_done(rsp, except);
+            std::string err_msg = sstrfmt( "CRC mismatch in response from "
+                                           "peer %d, %s:%s, "
+                                           "local calculation %x, from buffer %x")
+                                         .fmt( req->get_dst(), host_.c_str(),
+                                               port_.c_str(), crc_local, crc_buf );
+            handle_error(req, err_msg, when_done);
             return;
         }
@@ -1692,6 +1748,7 @@ class asio_rpc_client
             set_busy_flag(false);
             ptr<rpc_exception> except;
             when_done(rsp, except);
+            post_read();
         }
     }
@@ -1715,6 +1772,7 @@ class asio_rpc_client
             set_busy_flag(false);
             ptr<rpc_exception> except;
             when_done(rsp, except);
+            post_read();
             return;
         }
@@ -1777,6 +1835,7 @@ class asio_rpc_client
         set_busy_flag(false);
         ptr<rpc_exception> except;
         when_done(rsp, except);
+        post_read();
     }
 
     bool handle_custom_resp_meta(ptr<req_msg>& req,
@@ -1789,21 +1848,90 @@ class asio_rpc_client
         if (!meta_ok) {
             // Callback function returns false, should return failure.
-            ptr<resp_msg> rsp;
-            ptr<rpc_exception> except
-                ( cs_new<rpc_exception>
-                  ( sstrfmt( "response meta verification failed: "
-                             "from peer %d, %s:%s")
-                           .fmt( req->get_dst(), host_.c_str(),
-                                 port_.c_str() ),
-                    req ) );
-            close_socket();
-            when_done(rsp, except);
+            std::string err_msg = sstrfmt( "response meta verification failed: "
+                                           "from peer %d, %s:%s")
+                                         .fmt( req->get_dst(), host_.c_str(),
+                                               port_.c_str() );
+            handle_error(req, err_msg, when_done);
             return false;
         }
         return true;
     }
 
+    void register_response_read( ptr<req_msg>& req,
+                                 rpc_handler& when_done )
+    {
+        ptr<asio_rpc_client> self(this->shared_from_this());
+        ptr<buffer> resp_buf(buffer::alloc(RPC_RESP_HEADER_SIZE));
+        aa::read( ssl_enabled_, ssl_socket_, socket_,
+                  asio::buffer(resp_buf->data(), resp_buf->size()),
+                  std::bind( &asio_rpc_client::response_read,
+                             self,
+                             req,
+                             when_done,
+                             resp_buf,
+                             std::placeholders::_1,
+                             std::placeholders::_2 ) );
+    }
+
+    void post_send(ptr<req_msg>& req, rpc_handler& when_done) {
+        // First, process read.
+        bool immediate_action_needed = false;
+        {
+            auto_lock(pending_read_reqs_lock_);
+            pending_read_reqs_.push_back(cs_new<pending_req_pkg>(req, when_done));
+            immediate_action_needed = (pending_read_reqs_.size() == 1);
+            p_db("msg to peer %d has been written, start_log_idx: %ld, "
+                 "size: %ld, pending read reqs: %ld", req->get_dst(),
+                 req->get_last_log_idx(),
+                 req->log_entries().size(), pending_read_reqs_.size());
+        }
+
+        if (immediate_action_needed) {
+            register_response_read(req, when_done);
+        }
+
+        // Next, process write.
+        ptr<pending_req_pkg> next_req_pkg{nullptr};
+        {
+            auto_lock(pending_write_reqs_lock_);
+            pending_write_reqs_.pop_front();
+            if (pending_write_reqs_.size() > 0) {
+                next_req_pkg = *pending_write_reqs_.begin();
+                p_db("trigger next write, start_log_idx: %ld, pending write reqs: %ld",
+                     next_req_pkg->req_->get_last_log_idx(), pending_write_reqs_.size());
+            }
+        }
+
+        if (next_req_pkg) {
+            register_req_send(next_req_pkg->req_,
+                              next_req_pkg->when_done_,
+                              next_req_pkg->timeout_ms_);
+        }
+    }
+
+    void post_read() {
+        if (!impl_->get_options().streaming_mode_) {
+            return;
+        }
+
+        // Trigger the next read.
+        ptr<pending_req_pkg> next_req_pkg{nullptr};
+        {
+            auto_lock(pending_read_reqs_lock_);
+            pending_read_reqs_.pop_front();
+            if (pending_read_reqs_.size() > 0) {
+                next_req_pkg = *pending_read_reqs_.begin();
+                p_db("trigger next read, start_log_idx: %ld, pending read reqs: %ld",
+                     next_req_pkg->req_->get_last_log_idx(), pending_read_reqs_.size());
+            }
+        }
+
+        if (next_req_pkg) {
+            register_response_read(next_req_pkg->req_, next_req_pkg->when_done_);
+        }
+    }
+
 private:
     asio_service_impl* impl_;
     asio::ip::tcp::resolver resolver_;
@@ -1822,6 +1950,26 @@ class asio_rpc_client
     uint64_t client_id_;
     asio::steady_timer operation_timer_;
     ptr<logger> l_;
+
+    /**
+     * Queue of requests pending for read.
+     */
+    std::list< ptr<pending_req_pkg> > pending_read_reqs_;
+
+    /**
+     * Lock for `pending_read_reqs_`.
+     */
+    std::mutex pending_read_reqs_lock_;
+
+    /**
+     * Queue of requests pending for write.
+     */
+    std::list< ptr<pending_req_pkg> > pending_write_reqs_;
+
+    /**
+     * Lock for `pending_write_reqs_`.
+     */
+    std::mutex pending_write_reqs_lock_;
 };
 
 } // namespace nuraft
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 39a44a85..0246b7ff 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -32,6 +32,17 @@ target_link_libraries(asio_service_test
                       ${BUILD_DIR}/${LIBRARY_OUTPUT_NAME}
                       ${LIBRARIES})
 
+add_executable(asio_service_stream_test
+               unit/asio_service_stream_test.cxx
+               ${EXAMPLES_SRC}/logger.cc
+               ${EXAMPLES_SRC}/in_memory_log_store.cxx)
+add_dependencies(asio_service_stream_test
+                 static_lib
+                 build_ssl_key)
+target_link_libraries(asio_service_stream_test
+                      ${BUILD_DIR}/${LIBRARY_OUTPUT_NAME}
+                      ${LIBRARIES})
+
 # === Benchmark ===
 add_executable(raft_bench
                bench/raft_bench.cxx
diff --git a/tests/unit/asio_service_stream_test.cxx b/tests/unit/asio_service_stream_test.cxx
new file mode 100644
index 00000000..fd2d6abe
--- /dev/null
+++ b/tests/unit/asio_service_stream_test.cxx
@@ -0,0 +1,240 @@
+/************************************************************************
+Copyright 2017-present eBay Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+**************************************************************************/
+
+#include "buffer_serializer.hxx"
+#include "in_memory_log_store.hxx"
+#include "raft_package_asio.hxx"
+
+#include "event_awaiter.hxx"
+#include "test_common.h"
+
+#include <atomic>
+
+using namespace nuraft;
+using namespace raft_functional_common;
+
+namespace asio_service_stream_test {
+
+const std::string TEST_MSG = "stream-test-msg-str";
+
+class stream_msg_handler : public nuraft::msg_handler {
+public:
+    stream_msg_handler(context* ctx,
+                       const init_options& opt,
+                       ptr<logger_wrapper> log_wrapper)
+        : msg_handler(ctx, opt)
+        , my_log_wrapper_(log_wrapper)
+        , streamed_log_index(0)
+        , msg_mismatch(false)
+        {}
+
+    ptr<resp_msg> process_req(req_msg& req, const req_ext_params& ext_params) {
+        ptr<resp_msg> resp = cs_new<resp_msg>(state_->get_term(),
+                                              msg_type::append_entries_response,
+                                              id_,
+                                              req.get_src());
+        if (req.get_last_log_idx() == streamed_log_index) {
+            streamed_log_index++;
+            resp->accept(streamed_log_index.load());
+            ptr<buffer> buf = req.log_entries().at(0)->get_buf_ptr();
+            buf->pos(0);
+            std::string buf_str = buf->get_str();
+            if (buf_str != TEST_MSG) {
+                SimpleLogger* ll = my_log_wrapper_->getLogger();
+                _log_info(ll, "resp str: %s", buf_str.c_str());
+                msg_mismatch.store(true);
+            }
+        } else {
+            SimpleLogger* ll = my_log_wrapper_->getLogger();
+            _log_info(ll, "req log index does not match, req: %ld, current: %ld",
+                      req.get_last_log_idx(), streamed_log_index.load());
+        }
+        return resp;
+    }
+
+    ptr<logger_wrapper> my_log_wrapper_;
+    std::atomic<ulong> streamed_log_index;
+    std::atomic<bool> msg_mismatch;
+};
+
+class stream_server {
+public:
+    stream_server(int id, int port)
+        : my_id_(id)
+        , port_(port)
+        , next_log_index_(1)
+    {
+        init_server();
+    }
+
+    void send_req(int count) {
+        ptr<buffer> msg = buffer::alloc(TEST_MSG.size() + 1);
+        msg->put(TEST_MSG);
+
+        TestSuite::Progress pp(count, "sending req");
+
+        while (count > 0) {
+            ptr<req_msg> req(cs_new<req_msg>(
+                1, msg_type::append_entries_request, 1, my_id_,
+                1, sent_log_index_, 1));
+            ptr<log_entry> log(cs_new<log_entry>(0, msg, log_val_type::app_log));
+            req->log_entries().push_back(log);
+
+            rpc_handler h = (rpc_handler)std::bind(
+                &stream_server::handle_result,
+                this,
+                req,
+                std::placeholders::_1,
+                std::placeholders::_2);
+            my_client_->send(req, h);
+            sent_log_index_++;
+            pp.update(sent_log_index_);
+            count--;
+        }
+        pp.done();
+        num_messages_sent_ = sent_log_index_;
+    }
+
+    void handle_result(ptr<req_msg>& req,
+                       ptr<resp_msg>& resp,
+                       ptr<rpc_exception>& err)
+    {
+        if (resp->get_next_idx() == get_next_log_index()) {
+            next_log_index_++;
+        } else {
+            SimpleLogger* ll = my_log_wrapper_->getLogger();
+            _log_info(ll, "resp log index does not match, resp: %ld, current: %ld",
+                      resp->get_next_idx(), get_next_log_index());
+        }
+        if (next_log_index_ == num_messages_sent_ + 1) {
+            ea.invoke();
+        }
+    }
+
+    bool waiting_for_responses(int timeout_ms = 3000) {
+        TestSuite::_msg("wait for responses (up to %d ms)\n", timeout_ms);
+        ea.wait_ms(timeout_ms);
+        return (next_log_index_ == num_messages_sent_ + 1);
+    }
+
+    void stop_server() {
+        if (my_listener_) {
+            my_listener_->stop();
+            my_listener_->shutdown();
+        }
+
+        if (asio_svc_) {
+            asio_svc_->stop();
+            size_t count = 0;
+            while (asio_svc_->get_active_workers() && count < 500) {
+                // 10ms per tick.
+                timer_helper::sleep_ms(10);
+                count++;
+            }
+        }
+    }
+
+    ulong get_resp_log_index() {
+        return my_msg_handler_->streamed_log_index;
+    }
+
+    bool is_msg_mismatch() {
+        return my_msg_handler_->msg_mismatch;
+    }
+
+    ulong get_next_log_index() {
+        return next_log_index_;
+    }
+
+private:
+    int my_id_;
+    int port_;
+    std::atomic<ulong> next_log_index_;
+    ulong sent_log_index_ = 0;
+    ptr<asio_service> asio_svc_;
+    ptr<rpc_client> my_client_;
+    ptr<rpc_listener> my_listener_;
+    ptr<logger_wrapper> my_log_wrapper_;
+    ptr<logger> my_log_;
+    ptr<stream_msg_handler> my_msg_handler_;
+    size_t num_messages_sent_ = 0;
+    EventAwaiter ea;
+
+    void init_server() {
+        std::string log_file_name = "./srv" + std::to_string(my_id_) + ".log";
+        my_log_wrapper_ = cs_new<logger_wrapper>(log_file_name);
+        my_log_ = my_log_wrapper_;
+
+        // opts
+        asio_service::options asio_opt;
+        asio_opt.thread_pool_size_ = 2;
+        asio_opt.replicate_log_timestamp_ = false;
+        asio_opt.streaming_mode_ = true;
+        asio_svc_ = cs_new<asio_service>(asio_opt, my_log_);
+
+        // client
+        std::string endpoint = "localhost:" + std::to_string(port_);
+        my_client_ = asio_svc_->create_client(endpoint);
+
+        // server
+        ptr<state_mgr> s_mgr = cs_new<TestMgr>(my_id_, endpoint);
+        ptr<state_machine> sm = cs_new<TestSm>( my_log_wrapper_->getLogger() );
+        ptr<delayed_task_scheduler> scheduler = asio_svc_;
+        ptr<rpc_client_factory> rpc_cli_factory = asio_svc_;
+        my_listener_ = asio_svc_->create_rpc_listener(port_, my_log_);
+
+        raft_params params;
+        context* ctx( new context( s_mgr, sm, my_listener_, my_log_,
+                                   rpc_cli_factory, scheduler, params ) );
+        const raft_server::init_options& opt = raft_server::init_options();
+        my_msg_handler_ = cs_new<stream_msg_handler>(ctx, opt, my_log_wrapper_);
+        ptr<msg_handler> handler = my_msg_handler_;
+        my_listener_->listen(handler);
+    }
+};
+
+int stream_server_happy_path_test() {
+    reset_log_files();
+
+    stream_server s(1, 20010);
+
+    // send request
+    int count = 1000;
+    s.send_req(count);
+
+    // check req
+    CHK_TRUE(s.waiting_for_responses());
+    CHK_EQ(count, s.get_resp_log_index());
+    CHK_EQ(count, s.get_next_log_index() - 1);
+    CHK_FALSE(s.is_msg_mismatch());
+
+    // stop
+    s.stop_server();
+    TestSuite::sleep_sec(1, "shutting down");
+    SimpleLogger::shutdown();
+    return 0;
+}
+
+};
+
+using namespace asio_service_stream_test;
+
+int main(int argc, char** argv) {
+    TestSuite ts(argc, argv);
+    ts.options.printTestMessage = true;
+
+    ts.doTest("stream server happy path test",
+              stream_server_happy_path_test);
+
+    return 0;
+}
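Usage note: enabling the new mode only requires setting `streaming_mode_` on the `asio_service::options` passed to the service, as the unit test above does in `init_server()`. A minimal sketch (the logger parameter, function name, and thread-pool size are illustrative, not part of the patch):

    #include "libnuraft/nuraft.hxx"

    using namespace nuraft;

    // A minimal sketch, assuming an application-provided `my_log` logger.
    // Only `streaming_mode_` is new in this patch; the other fields keep
    // their defaults or mirror the values used in asio_service_stream_test.cxx.
    ptr<asio_service> make_streaming_service(ptr<logger> my_log) {
        asio_service::options asio_opt;
        asio_opt.thread_pool_size_ = 2;   // illustrative value
        asio_opt.streaming_mode_ = true;  // pipeline requests; responses
                                          // arrive in request order
        return cs_new<asio_service>(asio_opt, my_log);
    }

Internally, the client keeps two FIFO queues: `pre_send()` enqueues onto the write queue and only issues the actual send when the queue was empty, `post_send()` moves the request onto the read queue and triggers the next write, and `post_read()` triggers the next read after each response. At most one write and one read are therefore in flight at any time, which is what guarantees that responses are matched to requests in order, and `close_socket()` drains both queues so every pending `when_done` callback still fires on failure.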