From 64476ea1b7d72b2f6313d4906487f2bbea08dfd2 Mon Sep 17 00:00:00 2001 From: Chris Date: Mon, 13 Jan 2025 21:13:14 +0000 Subject: [PATCH] Use RAII for loop registration (#27) There's probably a number of situations where registrations with the loop are left behind and cause issues. Instead, use RAII to avoid this. --- CMakeLists.txt | 1 + include/forwarder_connection.h | 10 ++-- include/i_loop.h | 73 +++++++++++++++++++++++++-- include/ip_lookup.h | 8 +++ include/loop.h | 14 +++--- include/server.h | 7 +-- include/vyatta_check.h | 3 ++ src/client_forwarders.cpp | 4 +- src/config_parser.cpp | 6 --- src/forwarder_config.cpp | 5 +- src/forwarder_connection.cpp | 90 +++++++++++++++++++--------------- src/i_loop.cpp | 59 ++++++++++++++++++++++ src/ip_lookup.cpp | 35 +++++++------ src/loop.cpp | 24 ++++----- src/openssl/context.cpp | 3 +- src/openssl/ssl_connection.cpp | 3 +- src/pid_file.cpp | 6 +-- src/server.cpp | 23 +++------ src/verify_cache.cpp | 4 +- src/vyatta_check.cpp | 10 +--- test/mock_loop.h | 6 +-- test/test_server.cpp | 2 +- 22 files changed, 263 insertions(+), 133 deletions(-) create mode 100644 src/i_loop.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index cda18c9..52db301 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -48,6 +48,7 @@ set(CommonSources include/verify_cache.h src/verify_cache.cpp include/i_loop.h + src/i_loop.cpp include/loop.h src/loop.cpp include/server.h diff --git a/include/forwarder_connection.h b/include/forwarder_connection.h index ce2b577..1c53bd4 100644 --- a/include/forwarder_connection.h +++ b/include/forwarder_connection.h @@ -2,6 +2,7 @@ #pragma once #include "config_parser.h" +#include "i_loop.h" #include "openssl/ssl_connection.h" #include @@ -12,7 +13,6 @@ namespace dote { -class ILoop; class IForwarderConfig; class Socket; @@ -93,8 +93,6 @@ class ForwarderConnection void configureVerifier(); /// \brief Perform the initial connection - /// - /// \param handle The socket that is available to connect on void connect(int handle); /// \brief Nicely shutdown the connection @@ -136,6 +134,12 @@ class ForwarderConnection State m_state; /// The established connection to the forwarder std::shared_ptr m_socket; + /// The current read registration for m_socket. + ILoop::Registration m_read; + /// The current write registration for m_socket. + ILoop::Registration m_write; + /// The current exception registration for m_socket. + ILoop::Registration m_exception; /// The write buffer, the request will be a single message std::vector m_buffer; /// The chosen forwarder that this is connected to diff --git a/include/i_loop.h b/include/i_loop.h index 63e29ec..9bc2903 100644 --- a/include/i_loop.h +++ b/include/i_loop.h @@ -12,6 +12,72 @@ class ILoop public: virtual ~ILoop() = default; + /// \brief A registration type. + enum Type { + /// A special value to avoid deregistering on move. + Moved, + /// A read registration. + Read, + /// A write registration. + Write, + /// An exception registration. + Exception, + }; + + /// \brief This class is created with a read, write or exception + /// registration and automatically removes the registration + /// on destruction. + /// + /// NOTE: This object must not outlive the Loop that created it. + class Registration + { + public: + Registration(const Registration&) = delete; + Registration& operator=(const Registration&) = delete; + + /// \brief Construct an invalid Registration. + Registration(); + + /// \brief Create a registration for a given loop. + /// + /// \param loop The loop that this registration is for. + /// \param handle The handle of the registration + /// \param type The type of the registration. + Registration(ILoop* loop, int handle, Type type); + + /// \brief Take ownership of a registration. + /// + /// \param other The instance to take ownership from. + Registration(Registration&& other); + + /// \brief Take ownership of a registration. + /// + /// \param other The instance to take ownership from. + Registration& operator=(Registration&& other); + + /// \brief Deregister from the loop. + ~Registration(); + + operator bool() const { return valid(); } + + /// \brief Whether the registration is valid (i.e. the execution was + // successful). + bool valid() const; + + /// \brief Clear the registration. + void reset(); + + private: + friend class ILoop; + + /// The loop that this registration is for. + ILoop* m_loop; + /// The handle to deregister. + int m_handle; + /// The type of registration. + Type m_type; + }; + /// The type to call when the loop is available, called with the /// handle that caused it to be called using Callback = std::function; @@ -23,7 +89,7 @@ class ILoop /// \param timeout The time at which to call exception on the handle /// /// \return True if the handle is not already registered and now is - virtual bool registerRead(int handle, Callback callback, time_t timeout) = 0; + virtual Registration registerRead(int handle, Callback callback, time_t timeout) = 0; /// \brief Register for write availability on a given handle /// @@ -32,7 +98,7 @@ class ILoop /// \param timeout The time at which to call exception on the handle /// /// \return True if the handle is not already registered and now is - virtual bool registerWrite(int handle, Callback callback, time_t timeout) = 0; + virtual Registration registerWrite(int handle, Callback callback, time_t timeout) = 0; /// \brief Register for exceptions on a given handle /// @@ -40,8 +106,9 @@ class ILoop /// \param callback The callback to call if it triggers /// /// \return True if the handle is not already registered and now is - virtual bool registerException(int handle, Callback callback) = 0; + virtual Registration registerException(int handle, Callback callback) = 0; + protected: /// \brief Remove a read handle from the loop /// /// \param handle The handle to remove read handles for diff --git a/include/ip_lookup.h b/include/ip_lookup.h index e37014c..7b9ec6d 100644 --- a/include/ip_lookup.h +++ b/include/ip_lookup.h @@ -1,6 +1,8 @@ #pragma once +#include "i_loop.h" + #include namespace dote { @@ -50,6 +52,12 @@ class IpLookup std::shared_ptr m_loop; /// The socket that the connection if performed on std::shared_ptr m_socket; + /// The current read registration for m_socket. + ILoop::Registration m_read; + /// The current write registration for m_socket. + ILoop::Registration m_write; + /// The current exception registration for m_socket. + ILoop::Registration m_exception; /// The connection that was made to the IP address std::shared_ptr m_connection; /// The time at which the lookup will be considered failed diff --git a/include/loop.h b/include/loop.h index 4de8fa3..440b8f2 100644 --- a/include/loop.h +++ b/include/loop.h @@ -32,8 +32,8 @@ class Loop : public ILoop /// \param callback The callback to call if it triggers /// \param timeout The time at which to call exception on the handle /// - /// \return True if the handle is not already registered and now is - bool registerRead(int handle, Callback callback, time_t timeout) override; + /// \return A registration which is valid on success + Registration registerRead(int handle, Callback callback, time_t timeout) override; /// \brief Register for write availability on a given handle /// @@ -41,17 +41,18 @@ class Loop : public ILoop /// \param callback The callback to call if it triggers /// \param timeout The time at which to call exception on the handle /// - /// \return True if the handle is not already registered and now is - bool registerWrite(int handle, Callback callback, time_t timeout) override; + /// \return A registration which is valid on success + Registration registerWrite(int handle, Callback callback, time_t timeout) override; /// \brief Register for exceptions on a given handle /// /// \param handle The handle to register for exceptions /// \param callback The callback to call if it triggers /// - /// \return True if the handle is not already registered and now is - bool registerException(int handle, Callback callback) override; + /// \return A registration which is valid on success + Registration registerException(int handle, Callback callback) override; + private: /// \brief Remove a read handle from the loop /// /// \param handle The handle to remove read handles for @@ -67,7 +68,6 @@ class Loop : public ILoop /// \param handle The handle to remove exception handles for void removeException(int handle) override; - private: /// \brief Call a callback in a set of functions /// /// \param functions The functions to lookup the callback in diff --git a/include/server.h b/include/server.h index 23a66c6..dbed252 100644 --- a/include/server.h +++ b/include/server.h @@ -2,13 +2,13 @@ #pragma once #include "config_parser.h" +#include "i_loop.h" #include #include namespace dote { -class ILoop; class Socket; class IForwarders; @@ -46,8 +46,9 @@ class Server std::shared_ptr m_loop; /// The available forwarders std::shared_ptr m_forwarders; - /// The sockets that we are recieving from - std::vector> m_serverSockets; + using SocketAndRegistration = std::pair, ILoop::Registration>; + /// The sockets that we are recieving from and their read registrations. + std::vector m_serverSockets; }; } // namespace dote diff --git a/include/vyatta_check.h b/include/vyatta_check.h index 8251c9e..deab064 100644 --- a/include/vyatta_check.h +++ b/include/vyatta_check.h @@ -2,6 +2,7 @@ #pragma once #include "config_parser.h" +#include "i_loop.h" namespace dote { @@ -41,6 +42,8 @@ class VyattaCheck /// A handle to the inotify watch on the configuration file int m_fd; + /// The read registration for m_fd. + ILoop::Registration m_read; /// The main DoTe instance to configure on configuration changes Dote* m_dote; /// The base configuration to augment diff --git a/src/client_forwarders.cpp b/src/client_forwarders.cpp index 1d9083c..278ca6b 100644 --- a/src/client_forwarders.cpp +++ b/src/client_forwarders.cpp @@ -94,9 +94,7 @@ ClientForwarders::ClientForwarders(std::shared_ptr loop, m_loop(std::move(loop)), m_config(std::move(config)), m_ssl(std::move(ssl)), - m_maxConnections(maxConnections), - m_forwarders(), - m_queue() + m_maxConnections(maxConnections) { } ClientForwarders::~ClientForwarders() noexcept diff --git a/src/config_parser.cpp b/src/config_parser.cpp index 9f44d6b..01503db 100644 --- a/src/config_parser.cpp +++ b/src/config_parser.cpp @@ -23,13 +23,7 @@ constexpr std::size_t DEFAULT_MAX_CONNECTIONS = 5u; ConfigParser::ConfigParser() : m_valid(true), - m_partialForwarder(), - m_forwarders(), - m_servers(), - m_ipLookup(), - m_ciphers(), m_maxConnections(DEFAULT_MAX_CONNECTIONS), - m_pidFile(), m_daemonise(false), m_timeout(5u) { diff --git a/src/forwarder_config.cpp b/src/forwarder_config.cpp index ba01788..d360a8a 100644 --- a/src/forwarder_config.cpp +++ b/src/forwarder_config.cpp @@ -11,8 +11,7 @@ namespace dote { ForwarderConfig::ForwarderConfig() : - m_timeout(5), - m_forwarders() + m_timeout(5) { } void ForwarderConfig::clear() @@ -50,7 +49,7 @@ std::vector::const_iterator ForwarderConfig::end() cons return m_forwarders.cend(); } -void ForwarderConfig::setBad(const ConfigParser::Forwarder &config) +void ForwarderConfig::setBad(const ConfigParser::Forwarder& config) { // Look for the config in the forwarder list and move it to the end for (auto it = m_forwarders.begin(); it != m_forwarders.end(); ++it) diff --git a/src/forwarder_connection.cpp b/src/forwarder_connection.cpp index e7f3e3f..ce2b80d 100644 --- a/src/forwarder_connection.cpp +++ b/src/forwarder_connection.cpp @@ -18,12 +18,8 @@ ForwarderConnection::ForwarderConnection(std::shared_ptr loop, m_loop(std::move(loop)), m_config(std::move(config)), m_connection(ssl->create()), - m_incoming(), - m_shutdown(), m_state(CONNECTING), - m_socket(nullptr), - m_buffer(), - m_forwarder() + m_socket(nullptr) { auto chosen = m_config->get(); if (m_connection && chosen != m_config->end()) @@ -37,7 +33,7 @@ ForwarderConnection::ForwarderConnection(std::shared_ptr loop, if (m_socket) { m_connection->setSocket(m_socket->get()); - m_loop->registerException( + m_exception = m_loop->registerException( m_socket->get(), std::bind(&ForwarderConnection::exception, this, _1) ); @@ -92,34 +88,42 @@ bool ForwarderConnection::closed() void ForwarderConnection::connect(int handle) { - m_loop->removeWrite(handle); - m_loop->removeRead(handle); - switch (m_connection->connect()) { case openssl::SslConnection::Result::NEED_READ: - m_loop->registerRead( - m_socket->get(), - std::bind(&ForwarderConnection::connect, this, _1), - m_timeout - ); + if (!m_read) + { + m_read = m_loop->registerRead( + m_socket->get(), + std::bind(&ForwarderConnection::connect, this, _1), + m_timeout + ); + } + m_write.reset(); break; case openssl::SslConnection::Result::NEED_WRITE: - m_loop->registerWrite( - m_socket->get(), - std::bind(&ForwarderConnection::connect, this, _1), - m_timeout - ); + if (!m_write) + { + m_write = m_loop->registerWrite( + m_socket->get(), + std::bind(&ForwarderConnection::connect, this, _1), + m_timeout + ); + } + m_write.reset(); break; case openssl::SslConnection::Result::SUCCESS: - m_loop->registerRead( + // Remove the handlers to add the running ones. + m_read.reset(); + m_write.reset(); + m_read = m_loop->registerRead( m_socket->get(), std::bind(&ForwarderConnection::incoming, this, _1), m_timeout ); if (!m_buffer.empty()) { - m_loop->registerWrite( + m_write = m_loop->registerWrite( m_socket->get(), std::bind(&ForwarderConnection::outgoing, this, _1), m_timeout @@ -168,32 +172,38 @@ void ForwarderConnection::shutdown() { if (m_state == CONNECTING || m_state == OPEN) { + m_read.reset(); + m_write.reset(); _shutdown(m_socket->get()); } } void ForwarderConnection::_shutdown(int handle) { - m_loop->removeWrite(handle); - m_loop->removeRead(handle); - m_state = State::SHUTTING_DOWN; switch (m_connection->shutdown()) { case openssl::SslConnection::Result::NEED_READ: - m_loop->registerRead( - m_socket->get(), - std::bind(&ForwarderConnection::_shutdown, this, _1), - m_timeout - ); + if (!m_read) { + m_read = m_loop->registerRead( + m_socket->get(), + std::bind(&ForwarderConnection::_shutdown, this, _1), + m_timeout + ); + } + m_write.reset(); break; case openssl::SslConnection::Result::NEED_WRITE: - m_loop->registerWrite( - m_socket->get(), - std::bind(&ForwarderConnection::_shutdown, this, _1), - m_timeout - ); + if (!m_write) + { + m_write = m_loop->registerWrite( + m_socket->get(), + std::bind(&ForwarderConnection::_shutdown, this, _1), + m_timeout + ); + } + m_read.reset(); break; case openssl::SslConnection::Result::CLOSED: // Fall through @@ -214,9 +224,9 @@ bool ForwarderConnection::send(std::vector buffer) if (m_buffer.empty()) { - if (m_state == State::OPEN) + if (m_state == State::OPEN && !m_write) { - m_loop->registerWrite( + m_write = m_loop->registerWrite( m_socket->get(), std::bind(&ForwarderConnection::outgoing, this, _1), m_timeout @@ -241,7 +251,7 @@ void ForwarderConnection::outgoing(int handle) break; case openssl::SslConnection::Result::SUCCESS: m_buffer.clear(); - m_loop->removeWrite(handle); + m_write.reset(); break; case openssl::SslConnection::Result::FATAL: Log::notice << "Error writing to forwarder"; @@ -268,9 +278,9 @@ void ForwarderConnection::close() if (m_socket) { int handle = m_socket->get(); - m_loop->removeRead(handle); - m_loop->removeWrite(handle); - m_loop->removeException(handle); + m_read.reset(); + m_write.reset(); + m_exception.reset(); m_state = State::CLOSED; m_socket.reset(); diff --git a/src/i_loop.cpp b/src/i_loop.cpp new file mode 100644 index 0000000..91c855d --- /dev/null +++ b/src/i_loop.cpp @@ -0,0 +1,59 @@ +#include "i_loop.h" + +namespace dote { + +ILoop::Registration::Registration(ILoop* loop, int handle, Type type) : + m_loop(loop), + m_handle(handle), + m_type(type) +{ } + +ILoop::Registration::Registration() : + m_loop(nullptr), + m_handle(-1), + m_type(Type::Moved) +{ } + +ILoop::Registration::Registration(Registration&& other) : + m_loop(other.m_loop), + m_handle(other.m_handle), + m_type(other.m_type) +{ + other.m_type = Type::Moved; +} + +ILoop::Registration& ILoop::Registration::operator=(Registration&& other) +{ + std::swap(m_loop, other.m_loop); + std::swap(m_handle, other.m_handle); + std::swap(m_type, other.m_type); + return *this; +} + +ILoop::Registration::~Registration() +{ + reset(); +} + +void ILoop::Registration::reset() { + switch (m_type) { + case Moved: + break; + case Read: + m_loop->removeRead(m_handle); + break; + case Write: + m_loop->removeWrite(m_handle); + break; + case Exception: + m_loop->removeException(m_handle); + break; + } + m_type = Type::Moved; +} + +bool ILoop::Registration::valid() const { + return m_type != Type::Moved; +} + +} // namespace dote diff --git a/src/ip_lookup.cpp b/src/ip_lookup.cpp index 1d57e76..780f9db 100644 --- a/src/ip_lookup.cpp +++ b/src/ip_lookup.cpp @@ -22,7 +22,7 @@ IpLookup::IpLookup(const ConfigParser& config) : ); m_connection = std::make_shared(context); m_connection->setSocket(m_socket->get()); - m_loop->registerException( + m_exception = m_loop->registerException( m_socket->get(), std::bind(&IpLookup::connect, this, _1) ); @@ -37,27 +37,32 @@ IpLookup::~IpLookup() void IpLookup::connect(int handle) { - m_loop->removeWrite(handle); - m_loop->removeRead(handle); - switch (m_connection->connect()) { case openssl::SslConnection::Result::NEED_READ: - m_loop->registerRead( - m_socket->get(), - std::bind(&IpLookup::connect, this, _1), - m_timeout - ); + if (!m_read) + { + m_read = m_loop->registerRead( + m_socket->get(), + std::bind(&IpLookup::connect, this, _1), + m_timeout + ); + } + m_write.reset(); break; case openssl::SslConnection::Result::NEED_WRITE: - m_loop->registerWrite( - m_socket->get(), - std::bind(&IpLookup::connect, this, _1), - m_timeout - ); + if (!m_write) + { + m_write = m_loop->registerWrite( + m_socket->get(), + std::bind(&IpLookup::connect, this, _1), + m_timeout + ); + } + m_read.reset(); break; default: - m_loop->removeException(handle); + m_exception.reset(); break; } } diff --git a/src/loop.cpp b/src/loop.cpp index 7b46cb9..aadbaf9 100644 --- a/src/loop.cpp +++ b/src/loop.cpp @@ -48,12 +48,12 @@ void Loop::popluateFds(std::vector& fds, const std::map& functions, short event) { - for (auto& read : functions) + for (auto& handle_function : functions) { bool found = false; for (auto& fd : fds) { - if (fd.fd == read.first) + if (fd.fd == handle_function.first) { fd.events |= event; found = true; @@ -61,7 +61,7 @@ void Loop::popluateFds(std::vector& fds, } if (!found) { - fds.emplace_back(pollfd { read.first, event, 0 }); + fds.emplace_back(pollfd { handle_function.first, event, 0 }); } } } @@ -101,38 +101,38 @@ void Loop::callCallback( } } -bool Loop::registerRead(int handle, Callback callback, time_t timeout) +ILoop::Registration Loop::registerRead(int handle, Callback callback, time_t timeout) { if (m_readFunctions.count(handle)) { - return false; + return {}; } m_readFunctions.insert({ handle, std::make_pair(std::move(callback), timeout) }); - return true; + return Registration(this, handle, Type::Read); } -bool Loop::registerWrite(int handle, Callback callback, time_t timeout) +ILoop::Registration Loop::registerWrite(int handle, Callback callback, time_t timeout) { if (m_writeFunctions.count(handle)) { - return false; + return {}; } m_writeFunctions.insert({ handle, std::make_pair(std::move(callback), timeout) }); - return true; + return Registration(this, handle, Type::Write); } -bool Loop::registerException(int handle, Callback callback) +ILoop::Registration Loop::registerException(int handle, Callback callback) { if (m_exceptFunctions.count(handle)) { - return false; + return {}; } // Nothing to register for, poll always returns exceptions m_exceptFunctions.insert({ handle, std::move(callback) }); - return true; + return Registration(this, handle, Type::Exception); } void Loop::removeRead(int handle) diff --git a/src/openssl/context.cpp b/src/openssl/context.cpp index 541ada5..05a5fb3 100644 --- a/src/openssl/context.cpp +++ b/src/openssl/context.cpp @@ -36,8 +36,7 @@ constexpr int ALLOWED_ERRORS[] = { Context::Context(const std::string& ciphers) : m_context(nullptr), - m_session(nullptr), - m_chainVerifier() + m_session(nullptr) { // Flag to track if we've tried initialising the OpenSSL // library yet because we shouldn't keep trying diff --git a/src/openssl/ssl_connection.cpp b/src/openssl/ssl_connection.cpp index 5b65b3b..cc70caf 100644 --- a/src/openssl/ssl_connection.cpp +++ b/src/openssl/ssl_connection.cpp @@ -44,8 +44,7 @@ T certificateOperation(T(CertificateUtilities::*utility)(), SSL* ssl) SslConnection::SslConnection(std::shared_ptr context) : m_context(std::move(context)), - m_ssl(nullptr), - m_verifier() + m_ssl(nullptr) { if (m_context) { diff --git a/src/pid_file.cpp b/src/pid_file.cpp index ea23404..12e0a2a 100644 --- a/src/pid_file.cpp +++ b/src/pid_file.cpp @@ -18,11 +18,11 @@ PidFile::PidFile(const std::string& filename) : m_handle = open(filename.c_str(), O_RDWR | O_CREAT, 0640); if (m_handle < 0) { - dote::Log::err << "Unable to open PID file"; + Log::err << "Unable to open PID file"; } else if (lockf(m_handle, F_TLOCK, 0) < 0) { - dote::Log::err << "Unable to lock PID file"; + Log::err << "Unable to lock PID file"; close(m_handle); m_handle = -1; } @@ -34,7 +34,7 @@ PidFile::PidFile(const std::string& filename) : if (write(m_handle, contents.c_str(), contents.length()) != contents.length()) { - dote::Log::err << "Unable to write the PID to the PID file"; + Log::err << "Unable to write the PID to the PID file"; (void) close(m_handle); m_handle = -1; (void) unlink(m_filename.c_str()); diff --git a/src/server.cpp b/src/server.cpp index 9d76a84..17951f7 100644 --- a/src/server.cpp +++ b/src/server.cpp @@ -76,17 +76,10 @@ using namespace std::placeholders; Server::Server(std::shared_ptr loop, std::shared_ptr forwarders) : m_loop(std::move(loop)), - m_forwarders(std::move(forwarders)), - m_serverSockets() + m_forwarders(std::move(forwarders)) { } -Server::~Server() -{ - for (auto& socket : m_serverSockets) - { - m_loop->removeRead(socket->get()); - } -} +Server::~Server() = default; bool Server::addServer(const ConfigParser::Server& config) { @@ -99,12 +92,12 @@ bool Server::addServer(const ConfigParser::Server& config) { Log::warn << "Unable to get recieve address for packets"; } - m_serverSockets.emplace_back(std::move(serverSocket)); - m_loop->registerRead( - m_serverSockets.back()->get(), + auto registration = m_loop->registerRead( + serverSocket->get(), std::bind(&Server::handleDnsRequest, this, _1), 0 ); + m_serverSockets.emplace_back(std::move(serverSocket), std::move(registration)); return true; } @@ -112,11 +105,11 @@ void Server::handleDnsRequest(int handle) { // Get the socket for this handle std::shared_ptr handleSocket; - for (auto& socket : m_serverSockets) + for (auto& socket_registration : m_serverSockets) { - if (socket->get() == handle) + if (socket_registration.first->get() == handle) { - handleSocket = socket; + handleSocket = socket_registration.first; break; } } diff --git a/src/verify_cache.cpp b/src/verify_cache.cpp index 3c93fde..c454a6c 100644 --- a/src/verify_cache.cpp +++ b/src/verify_cache.cpp @@ -9,9 +9,7 @@ namespace dote { VerifyCache::VerifyCache(openssl::Context::Verifier verifier, int timeout) : m_verifier(verifier), - m_timeout(timeout), - m_cache(), - m_expiry() + m_timeout(timeout) { } int VerifyCache::forwardVerify(X509_STORE_CTX* context) diff --git a/src/vyatta_check.cpp b/src/vyatta_check.cpp index 62de30d..411df3f 100644 --- a/src/vyatta_check.cpp +++ b/src/vyatta_check.cpp @@ -42,14 +42,6 @@ VyattaCheck::~VyattaCheck() { if (m_fd >= 0) { - if (m_dote) - { - auto loop = m_dote->looper(); - if (loop) - { - loop->removeRead(m_fd); - } - } close(m_fd); } } @@ -69,7 +61,7 @@ void VyattaCheck::configure(Dote& dote) if (loop && m_fd >= 0) { m_dote = &dote; - loop->registerRead( + m_read = loop->registerRead( m_fd, std::bind(&VyattaCheck::handleRead, this, std::placeholders::_1), 0u diff --git a/test/mock_loop.h b/test/mock_loop.h index c6da997..b8e1835 100644 --- a/test/mock_loop.h +++ b/test/mock_loop.h @@ -13,9 +13,9 @@ class MockLoop : public ILoop ~MockLoop() noexcept { } - MOCK_METHOD3(registerRead, bool(int, Callback, time_t)); - MOCK_METHOD3(registerWrite, bool(int, Callback, time_t)); - MOCK_METHOD2(registerException, bool(int, Callback)); + MOCK_METHOD3(registerRead, ILoop::Registration(int, Callback, time_t)); + MOCK_METHOD3(registerWrite, ILoop::Registration(int, Callback, time_t)); + MOCK_METHOD2(registerException, ILoop::Registration(int, Callback)); MOCK_METHOD1(removeRead, void(int)); MOCK_METHOD1(removeWrite, void(int)); MOCK_METHOD1(removeException, void(int)); diff --git a/test/test_server.cpp b/test/test_server.cpp index 8b37ba6..0ddaabe 100644 --- a/test/test_server.cpp +++ b/test/test_server.cpp @@ -43,7 +43,7 @@ class TestServer : public ::testing::Test .WillOnce(Invoke([this](int handle, ILoop::Callback callback, time_t timeout) { m_handle = handle; m_callback = std::move(callback); - return true; + return ILoop::Registration(m_loop.get(), m_handle, ILoop::Read); })); ASSERT_TRUE(m_server.addServer(m_config)); ASSERT_TRUE(m_callback);