From e43442dec57d9ecb1d88b8078f8d9ab4ea2b7c68 Mon Sep 17 00:00:00 2001 From: Alexander Bondarev Date: Mon, 25 Dec 2023 18:27:08 +0200 Subject: [PATCH 1/3] Socket processing improvements. 1. Add SimpleSocketSender friend class for SimpleSocket for avoiding concurrent state modifications during racing receive and send operations. 2. Don't log warning if socket shutdown initiated by client. 3. Make StatTimer thread local. --- .../main/wire/ByteBufferAsyncProcessor.cpp | 2 +- .../src/main/wire/SocketWire.cpp | 29 ++-- .../src/main/wire/SocketWire.h | 3 + rd-cpp/thirdparty/clsocket/CMakeLists.txt | 2 + .../thirdparty/clsocket/src/ActiveSocket.cpp | 15 +-- .../thirdparty/clsocket/src/PassiveSocket.cpp | 80 +++++------ .../thirdparty/clsocket/src/SimpleSocket.cpp | 124 +++++++----------- rd-cpp/thirdparty/clsocket/src/SimpleSocket.h | 19 +-- .../clsocket/src/SimpleSocketSender.cpp | 28 ++++ .../clsocket/src/SimpleSocketSender.h | 40 ++++++ rd-cpp/thirdparty/clsocket/src/StatTimer.h | 22 +++- 11 files changed, 208 insertions(+), 156 deletions(-) create mode 100644 rd-cpp/thirdparty/clsocket/src/SimpleSocketSender.cpp create mode 100644 rd-cpp/thirdparty/clsocket/src/SimpleSocketSender.h diff --git a/rd-cpp/src/rd_framework_cpp/src/main/wire/ByteBufferAsyncProcessor.cpp b/rd-cpp/src/rd_framework_cpp/src/main/wire/ByteBufferAsyncProcessor.cpp index 087bc2407..0c6fc6c92 100644 --- a/rd-cpp/src/rd_framework_cpp/src/main/wire/ByteBufferAsyncProcessor.cpp +++ b/rd-cpp/src/rd_framework_cpp/src/main/wire/ByteBufferAsyncProcessor.cpp @@ -141,7 +141,7 @@ void ByteBufferAsyncProcessor::ThreadProc() return; } - while (data.empty() && queue.empty() || interrupt_balance != 0) + while ((data.empty() && queue.empty()) || interrupt_balance != 0) { if (state >= StateKind::Stopping) { diff --git a/rd-cpp/src/rd_framework_cpp/src/main/wire/SocketWire.cpp b/rd-cpp/src/rd_framework_cpp/src/main/wire/SocketWire.cpp index ef010deee..a3dd57b1b 100644 --- a/rd-cpp/src/rd_framework_cpp/src/main/wire/SocketWire.cpp +++ b/rd-cpp/src/rd_framework_cpp/src/main/wire/SocketWire.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include @@ -81,16 +82,16 @@ bool SocketWire::Base::send0(Buffer::ByteArray const& msg, sequence_number_t seq send_package_header.write_integral(seqn); RD_ASSERT_THROW_MSG( - socket_provider->Send(send_package_header.data(), send_package_header.get_position()) == PACKAGE_HEADER_LENGTH, + socket_sender->Send(send_package_header.data(), send_package_header.get_position()) == PACKAGE_HEADER_LENGTH, this->id + ": failed to send header over the network" ", reason: " + - socket_provider->DescribeError()) + socket_sender->DescribeError()) - RD_ASSERT_THROW_MSG(socket_provider->Send(msg.data(), msglen) == msglen, this->id + + RD_ASSERT_THROW_MSG(socket_sender->Send(msg.data(), msglen) == msglen, this->id + ": failed to send package over the network" ", reason: " + - socket_provider->DescribeError()); + socket_sender->DescribeError()); logger->info("{}: were sent {} bytes", this->id, msglen); // RD_ASSERT_MSG(socketProvider->Flush(), "{}: failed to flush"); return true; @@ -126,6 +127,7 @@ void SocketWire::Base::set_socket_provider(std::shared_ptr new_so { std::lock_guard guard(socket_send_lock); socket_provider = std::move(new_socket); + socket_sender = std::make_unique(socket_provider); socket_send_var.notify_all(); } { @@ -136,8 +138,8 @@ void SocketWire::Base::set_socket_provider(std::shared_ptr new_so } } - auto heartbeat = LifetimeDefinition::use([this](Lifetime heartbeatLifetime) { - const auto heartbeat = start_heartbeat(heartbeatLifetime).share(); + const auto heartbeat = LifetimeDefinition::use([this](Lifetime heartbeatLifetime) { + const auto heartbeat = start_heartbeat(std::move(heartbeatLifetime)).share(); async_send_buffer.resume(); @@ -159,6 +161,11 @@ void SocketWire::Base::set_socket_provider(std::shared_ptr new_so { logger->debug("{}: socket was already shut down", this->id); } + else if (socket_provider->GetSocketError() == CSimpleSocket::SocketNotconnected) + { + logger->debug("{}: socket not connected (shutdown likely was initiated by client)"); + socket_provider->Close(); + } else if (!socket_provider->Shutdown(CSimpleSocket::Both)) { // double close? @@ -393,14 +400,14 @@ void SocketWire::Base::ping() const ping_pkg_header.write_integral(counterpart_timestamp); { std::lock_guard guard(socket_send_lock); - int32_t sent = socket_provider->Send(ping_pkg_header.data(), ping_pkg_header.get_position()); - if (sent == 0 && !socket_provider->IsSocketValid()) + int32_t sent = socket_sender->Send(ping_pkg_header.data(), ping_pkg_header.get_position()); + if (sent == 0 && !socket_sender->IsSocketValid()) { logger->debug("{}: failed to send ping over the network, reason: socket was shut down for sending", this->id); return; } RD_ASSERT_THROW_MSG(sent == PACKAGE_HEADER_LENGTH, - fmt::format("{}: failed to send ping over the network, reason: {}", this->id, socket_provider->DescribeError())) + fmt::format("{}: failed to send ping over the network, reason: {}", this->id, socket_sender->DescribeError())) } ++current_timestamp; @@ -421,11 +428,11 @@ bool SocketWire::Base::send_ack(sequence_number_t seqn) const ack_buffer.write_integral(seqn); { std::lock_guard guard(socket_send_lock); - RD_ASSERT_THROW_MSG(socket_provider->Send(ack_buffer.data(), ack_buffer.get_position()) == PACKAGE_HEADER_LENGTH, + RD_ASSERT_THROW_MSG(socket_sender->Send(ack_buffer.data(), ack_buffer.get_position()) == PACKAGE_HEADER_LENGTH, this->id + ": failed to send ack over the network" ", reason: " + - socket_provider->DescribeError()) + socket_sender->DescribeError()) } return true; } diff --git a/rd-cpp/src/rd_framework_cpp/src/main/wire/SocketWire.h b/rd-cpp/src/rd_framework_cpp/src/main/wire/SocketWire.h index e71fe76cb..e8ad644b2 100644 --- a/rd-cpp/src/rd_framework_cpp/src/main/wire/SocketWire.h +++ b/rd-cpp/src/rd_framework_cpp/src/main/wire/SocketWire.h @@ -15,6 +15,7 @@ class CSimpleSocket; class CActiveSocket; class CPassiveSocket; +class CSimpleSocketSender; namespace rd { @@ -37,6 +38,8 @@ class RD_FRAMEWORK_API SocketWire std::string id; IScheduler* scheduler = nullptr; std::shared_ptr socket_provider; + // we do use separate sender for socket_provider to avoid concurrent state modifications during contesting receive and send operations + std::unique_ptr socket_sender; std::shared_ptr socket; diff --git a/rd-cpp/thirdparty/clsocket/CMakeLists.txt b/rd-cpp/thirdparty/clsocket/CMakeLists.txt index 9e83852d8..5028fb5bd 100644 --- a/rd-cpp/thirdparty/clsocket/CMakeLists.txt +++ b/rd-cpp/thirdparty/clsocket/CMakeLists.txt @@ -19,12 +19,14 @@ SET(CLSOCKET_HEADERS src/PassiveSocket.h src/SimpleSocket.h src/StatTimer.h + src/SimpleSocketSender.h ) SET(CLSOCKET_SOURCES src/SimpleSocket.cpp src/ActiveSocket.cpp src/PassiveSocket.cpp + src/SimpleSocketSender.cpp ) # mark headers as headers... diff --git a/rd-cpp/thirdparty/clsocket/src/ActiveSocket.cpp b/rd-cpp/thirdparty/clsocket/src/ActiveSocket.cpp index dc6a6c527..6b19da3ba 100644 --- a/rd-cpp/thirdparty/clsocket/src/ActiveSocket.cpp +++ b/rd-cpp/thirdparty/clsocket/src/ActiveSocket.cpp @@ -90,8 +90,7 @@ bool CActiveSocket::ConnectTCP(const char *pAddr, uint16_t nPort) // Connect to address "xxx.xxx.xxx.xxx" (IPv4) address only. // //------------------------------------------------------------------ - m_timer.Initialize(); - m_timer.SetStartTime(); + CStatTimerCookie timer_cookie(timer); if (connect(m_socket, (struct sockaddr*)&m_stServerSockaddr, sizeof(m_stServerSockaddr)) == CSimpleSocket::SocketError) @@ -121,8 +120,6 @@ bool CActiveSocket::ConnectTCP(const char *pAddr, uint16_t nPort) bRetVal = true; } - m_timer.SetEndTime(); - return bRetVal; } @@ -170,8 +167,7 @@ bool CActiveSocket::ConnectUDP(const char *pAddr, uint16_t nPort) // Connect to address "xxx.xxx.xxx.xxx" (IPv4) address only. // //------------------------------------------------------------------ - m_timer.Initialize(); - m_timer.SetStartTime(); + CStatTimerCookie timer_cookie(timer); if (connect(m_socket, (struct sockaddr*)&m_stServerSockaddr, sizeof(m_stServerSockaddr)) != CSimpleSocket::SocketError) { @@ -180,8 +176,6 @@ bool CActiveSocket::ConnectUDP(const char *pAddr, uint16_t nPort) TranslateSocketError(); - m_timer.SetEndTime(); - return bRetVal; } @@ -228,8 +222,7 @@ bool CActiveSocket::ConnectRAW(const char *pAddr, uint16_t nPort) // Connect to address "xxx.xxx.xxx.xxx" (IPv4) address only. // //------------------------------------------------------------------ - m_timer.Initialize(); - m_timer.SetStartTime(); + CStatTimerCookie timer_cookie(timer); if (connect(m_socket, (struct sockaddr*)&m_stServerSockaddr, sizeof(m_stServerSockaddr)) != CSimpleSocket::SocketError) { @@ -238,8 +231,6 @@ bool CActiveSocket::ConnectRAW(const char *pAddr, uint16_t nPort) TranslateSocketError(); - m_timer.SetEndTime(); - return bRetVal; } diff --git a/rd-cpp/thirdparty/clsocket/src/PassiveSocket.cpp b/rd-cpp/thirdparty/clsocket/src/PassiveSocket.cpp index 120383709..5482950a0 100644 --- a/rd-cpp/thirdparty/clsocket/src/PassiveSocket.cpp +++ b/rd-cpp/thirdparty/clsocket/src/PassiveSocket.cpp @@ -78,26 +78,23 @@ bool CPassiveSocket::BindMulticast(const char *pInterface, const char *pGroup, u //-------------------------------------------------------------------------- // Bind to the specified port //-------------------------------------------------------------------------- - if (bind(m_socket, (struct sockaddr *) &m_stMulticastGroup, sizeof(m_stMulticastGroup)) == 0) { - //---------------------------------------------------------------------- - // Join the multicast group - //---------------------------------------------------------------------- - m_stMulticastRequest.imr_multiaddr.s_addr = inet_addr(pGroup); - m_stMulticastRequest.imr_interface.s_addr = m_stMulticastGroup.sin_addr.s_addr; - - if (SETSOCKOPT(m_socket, IPPROTO_IP, IP_ADD_MEMBERSHIP, - (void *) &m_stMulticastRequest, - sizeof(m_stMulticastRequest)) == CSimpleSocket::SocketSuccess) { - bRetVal = true; + { + CStatTimerCookie timer_cookie(timer); + if (bind(m_socket, (struct sockaddr *) &m_stMulticastGroup, sizeof(m_stMulticastGroup)) == 0) { + //---------------------------------------------------------------------- + // Join the multicast group + //---------------------------------------------------------------------- + m_stMulticastRequest.imr_multiaddr.s_addr = inet_addr(pGroup); + m_stMulticastRequest.imr_interface.s_addr = m_stMulticastGroup.sin_addr.s_addr; + + if (SETSOCKOPT(m_socket, IPPROTO_IP, IP_ADD_MEMBERSHIP, + (void *) &m_stMulticastRequest, + sizeof(m_stMulticastRequest)) == CSimpleSocket::SocketSuccess) { + bRetVal = true; + } } - - m_timer.SetEndTime(); } - m_timer.Initialize(); - m_timer.SetStartTime(); - - //-------------------------------------------------------------------------- // If there was a new_socket error then close the new_socket to clean out the // connection in the backlog. @@ -152,30 +149,29 @@ bool CPassiveSocket::Listen(const char *pAddr, uint16_t nPort, int32_t nConnecti } } - m_timer.Initialize(); - m_timer.SetStartTime(); - - //-------------------------------------------------------------------------- - // Bind to the specified port - //-------------------------------------------------------------------------- - if (bind(m_socket, (struct sockaddr *) &m_stServerSockaddr, sizeof(m_stServerSockaddr)) != - CSimpleSocket::SocketError) { - socklen_t namelen = sizeof(m_stServerSockaddr); - if (getsockname(m_socket, (struct sockaddr *) &m_stServerSockaddr, &namelen) != CSimpleSocket::SocketError) { - if (m_nSocketType == CSimpleSocket::SocketTypeTcp) { - if (listen(m_socket, nConnectionBacklog) != CSimpleSocket::SocketError) { + { + CStatTimerCookie timer_cookie(timer); + + //-------------------------------------------------------------------------- + // Bind to the specified port + //-------------------------------------------------------------------------- + if (bind(m_socket, (struct sockaddr *) &m_stServerSockaddr, sizeof(m_stServerSockaddr)) != + CSimpleSocket::SocketError) { + socklen_t namelen = sizeof(m_stServerSockaddr); + if (getsockname(m_socket, (struct sockaddr *) &m_stServerSockaddr, &namelen) != CSimpleSocket::SocketError) { + if (m_nSocketType == CSimpleSocket::SocketTypeTcp) { + if (listen(m_socket, nConnectionBacklog) != CSimpleSocket::SocketError) { + bRetVal = true; + } + } else { bRetVal = true; } } else { - bRetVal = true; + bRetVal = false; } - } else { - bRetVal = false; } } - m_timer.SetEndTime(); - //-------------------------------------------------------------------------- // If there was a new_socket error then close the new_socket to clean out the // connection in the backlog. @@ -213,10 +209,9 @@ CActiveSocket *CPassiveSocket::Accept() { // Wait for incoming connection. //-------------------------------------------------------------------------- if (pClientSocket != NULL) { - CSocketError socketErrno = SocketSuccess; + CSocketError socketErrno; - m_timer.Initialize(); - m_timer.SetStartTime(); + CStatTimerCookie timer_cookie(timer); nClientSockLen = sizeof(m_stClientSockaddr); @@ -246,8 +241,6 @@ CActiveSocket *CPassiveSocket::Accept() { } while (socketErrno == CSimpleSocket::SocketInterrupted); - m_timer.SetEndTime(); - if (socketErrno != CSimpleSocket::SocketSuccess) { delete pClientSocket; pClientSocket = NULL; @@ -271,13 +264,10 @@ int32_t CPassiveSocket::Send(const uint8_t *pBuf, size_t bytesToSend) { case CSimpleSocket::SocketTypeUdp: { if (IsSocketValid()) { if ((bytesToSend > 0) && (pBuf != NULL)) { - m_timer.Initialize(); - m_timer.SetStartTime(); - - m_nBytesSent = static_cast(SENDTO(m_socket, pBuf, bytesToSend, 0, - reinterpret_cast(&m_stClientSockaddr), sizeof(m_stClientSockaddr))); + CStatTimerCookie timer_cookie(timer); - m_timer.SetEndTime(); + m_nBytesSent = static_cast(SENDTO(m_socket, pBuf, bytesToSend, 0, + reinterpret_cast(&m_stClientSockaddr), sizeof(m_stClientSockaddr))); if (m_nBytesSent == CSimpleSocket::SocketError) { TranslateSocketError(); diff --git a/rd-cpp/thirdparty/clsocket/src/SimpleSocket.cpp b/rd-cpp/thirdparty/clsocket/src/SimpleSocket.cpp index fbc400082..59d713661 100644 --- a/rd-cpp/thirdparty/clsocket/src/SimpleSocket.cpp +++ b/rd-cpp/thirdparty/clsocket/src/SimpleSocket.cpp @@ -42,6 +42,8 @@ *----------------------------------------------------------------------------*/ #include "SimpleSocket.h" +thread_local CStatTimer CSimpleSocket::timer; + CSimpleSocket::CSimpleSocket(CSocketType nType) : m_socket(INVALID_SOCKET), m_socketErrno(CSimpleSocket::SocketInvalidSocket), @@ -148,10 +150,8 @@ bool CSimpleSocket::Initialize() //------------------------------------------------------------------------- // Create the basic Socket Handle //------------------------------------------------------------------------- - m_timer.Initialize(); - m_timer.SetStartTime(); + CStatTimerCookie timer_cookie(timer); m_socket = socket(m_nSocketDomain, m_nSocketType, 0); - m_timer.SetEndTime(); TranslateSocketError(); @@ -392,8 +392,7 @@ int32_t CSimpleSocket::Send(const uint8_t *pBuf, size_t bytesToSend) { if ((bytesToSend > 0) && (pBuf != NULL)) { - m_timer.Initialize(); - m_timer.SetStartTime(); + CStatTimerCookie timer_cookie(timer); //--------------------------------------------------------- // Check error condition and attempt to resend if call @@ -404,8 +403,6 @@ int32_t CSimpleSocket::Send(const uint8_t *pBuf, size_t bytesToSend) m_nBytesSent = static_cast(SEND(m_socket, pBuf, bytesToSend, 0)); TranslateSocketError(); } while (GetSocketError() == CSimpleSocket::SocketInterrupted); - - m_timer.SetEndTime(); } } break; @@ -416,8 +413,7 @@ int32_t CSimpleSocket::Send(const uint8_t *pBuf, size_t bytesToSend) { if ((bytesToSend > 0) && (pBuf != NULL)) { - m_timer.Initialize(); - m_timer.SetStartTime(); + CStatTimerCookie timer_cookie(timer); //--------------------------------------------------------- // Check error condition and attempt to resend if call @@ -440,8 +436,6 @@ int32_t CSimpleSocket::Send(const uint8_t *pBuf, size_t bytesToSend) TranslateSocketError(); } while (GetSocketError() == CSimpleSocket::SocketInterrupted); } - - m_timer.SetEndTime(); } } break; @@ -477,12 +471,12 @@ bool CSimpleSocket::Close(void) //-------------------------------------------------------------------------- if (IsSocketValid()) { - Shutdown(Both); + Shutdown(Both); if (CLOSE(m_socket) != CSimpleSocket::SocketError) - { - m_socket = INVALID_SOCKET; - bRetVal = true; - } + { + m_socket = INVALID_SOCKET; + bRetVal = true; + } } TranslateSocketError(); @@ -750,8 +744,7 @@ int32_t CSimpleSocket::Receive(int32_t nMaxBytes, uint8_t * pBuffer ) SetSocketError(SocketSuccess); - m_timer.Initialize(); - m_timer.SetStartTime(); + CStatTimerCookie timer_cookie(timer); switch (m_nSocketType) { @@ -798,7 +791,6 @@ int32_t CSimpleSocket::Receive(int32_t nMaxBytes, uint8_t * pBuffer ) break; } - m_timer.SetEndTime(); TranslateSocketError(); //-------------------------------------------------------------------------- @@ -943,17 +935,20 @@ int32_t CSimpleSocket::SendFile(int32_t nOutFd, int32_t nInFd, off_t *pOffset, i // TranslateSocketError() - // //------------------------------------------------------------------------------ -void CSimpleSocket::TranslateSocketError(void) +void CSimpleSocket::TranslateSocketError() +{ + SetSocketError(TranslateLastSocketError()); +} + +CSimpleSocket::CSocketError CSimpleSocket::TranslateLastSocketError() { #if defined(__linux__) || defined(_DARWIN) switch (errno) { case EXIT_SUCCESS: - SetSocketError(CSimpleSocket::SocketSuccess); - break; + return SocketSuccess; case ENOTCONN: - SetSocketError(CSimpleSocket::SocketNotconnected); - break; + return SocketNotconnected; case ENOTSOCK: case EBADF: case EACCES: @@ -964,44 +959,32 @@ void CSimpleSocket::TranslateSocketError(void) case ENOMEM: case EPROTONOSUPPORT: case EPIPE: - SetSocketError(CSimpleSocket::SocketInvalidSocket); - break; + return SocketInvalidSocket; case ECONNREFUSED : - SetSocketError(CSimpleSocket::SocketConnectionRefused); - break; + return SocketConnectionRefused; case ETIMEDOUT: - SetSocketError(CSimpleSocket::SocketTimedout); - break; + return SocketTimedout; case EINPROGRESS: - SetSocketError(CSimpleSocket::SocketEinprogress); - break; + return SocketEinprogress; case EWOULDBLOCK: // case EAGAIN: - SetSocketError(CSimpleSocket::SocketEwouldblock); - break; + return SocketEwouldblock; case EINTR: - SetSocketError(CSimpleSocket::SocketInterrupted); - break; + return SocketInterrupted; case ECONNABORTED: - SetSocketError(CSimpleSocket::SocketConnectionAborted); - break; + return SocketConnectionAborted; case EINVAL: case EPROTO: - SetSocketError(CSimpleSocket::SocketProtocolError); - break; + return SocketProtocolError; case EPERM: - SetSocketError(CSimpleSocket::SocketFirewallError); - break; + return SocketFirewallError; case EFAULT: - SetSocketError(CSimpleSocket::SocketInvalidSocketBuffer); - break; + return SocketInvalidSocketBuffer; case ECONNRESET: case ENOPROTOOPT: - SetSocketError(CSimpleSocket::SocketConnectionReset); - break; + return SocketConnectionReset; default: - SetSocketError(CSimpleSocket::SocketEunknown); - break; + return SocketEunknown; } #endif #ifdef _WIN32 @@ -1009,56 +992,41 @@ void CSimpleSocket::TranslateSocketError(void) switch (nError) { case EXIT_SUCCESS: - SetSocketError(CSimpleSocket::SocketSuccess); - break; + return CSimpleSocket::SocketSuccess; case WSAEBADF: case WSAENOTCONN: - SetSocketError(CSimpleSocket::SocketNotconnected); - break; + return CSimpleSocket::SocketNotconnected; case WSAEINTR: - SetSocketError(CSimpleSocket::SocketInterrupted); - break; + return CSimpleSocket::SocketInterrupted; case WSAEACCES: case WSAEAFNOSUPPORT: case WSAEINVAL: case WSAEMFILE: case WSAENOBUFS: case WSAEPROTONOSUPPORT: - SetSocketError(CSimpleSocket::SocketInvalidSocket); - break; + return CSimpleSocket::SocketInvalidSocket; case WSAECONNREFUSED : - SetSocketError(CSimpleSocket::SocketConnectionRefused); - break; + return CSimpleSocket::SocketConnectionRefused; case WSAETIMEDOUT: - SetSocketError(CSimpleSocket::SocketTimedout); - break; + return CSimpleSocket::SocketTimedout; case WSAEINPROGRESS: - SetSocketError(CSimpleSocket::SocketEinprogress); - break; + return CSimpleSocket::SocketEinprogress; case WSAECONNABORTED: - SetSocketError(CSimpleSocket::SocketConnectionAborted); - break; + return CSimpleSocket::SocketConnectionAborted; case WSAEWOULDBLOCK: - SetSocketError(CSimpleSocket::SocketEwouldblock); - break; + return CSimpleSocket::SocketEwouldblock; case WSAENOTSOCK: - SetSocketError(CSimpleSocket::SocketInvalidSocket); - break; + return CSimpleSocket::SocketInvalidSocket; case WSAECONNRESET: - SetSocketError(CSimpleSocket::SocketConnectionReset); - break; + return CSimpleSocket::SocketConnectionReset; case WSANO_DATA: - SetSocketError(CSimpleSocket::SocketInvalidAddress); - break; + return CSimpleSocket::SocketInvalidAddress; case WSAEADDRINUSE: - SetSocketError(CSimpleSocket::SocketAddressInUse); - break; + return CSimpleSocket::SocketAddressInUse; case WSAEFAULT: - SetSocketError(CSimpleSocket::SocketInvalidPointer); - break; + return CSimpleSocket::SocketInvalidPointer; default: - SetSocketError(CSimpleSocket::SocketEunknown); - break; + return CSimpleSocket::SocketEunknown; } #endif } diff --git a/rd-cpp/thirdparty/clsocket/src/SimpleSocket.h b/rd-cpp/thirdparty/clsocket/src/SimpleSocket.h index 91f82e4d5..443a9408d 100644 --- a/rd-cpp/thirdparty/clsocket/src/SimpleSocket.h +++ b/rd-cpp/thirdparty/clsocket/src/SimpleSocket.h @@ -76,9 +76,9 @@ #ifdef _WIN32 #pragma warning( push ) #pragma warning( disable:4668 ) - #include - #include - #include + #include + #include + #include #pragma warning( pop ) #define IPTOS_LOWDELAY 0x10 @@ -205,12 +205,14 @@ class CSimpleSocket { /// @return true if the socket object contains a valid socket descriptor. virtual bool IsSocketValid(void) { return (m_socket != static_cast(SocketError)); - }; + } /// Provides a standard error code for cross platform development by /// mapping the operating system error to an error defined by the CSocket /// class. - void TranslateSocketError(void); + void TranslateSocketError(); + + static CSocketError TranslateLastSocketError(); /// Returns a human-readable description of the given error code /// or the last error code of a socket @@ -423,13 +425,13 @@ class CSimpleSocket { /// Get the total time the of the last operation in milliseconds. /// @return number of milliseconds of last operation. uint32_t GetTotalTimeMs() { - return m_timer.GetMilliSeconds(); + return timer.GetMilliSeconds(); }; /// Get the total time the of the last operation in microseconds. /// @return number of microseconds or last operation. uint32_t GetTotalTimeUsec() { - return m_timer.GetMicroSeconds(); + return timer.GetMicroSeconds(); }; /// Return Differentiated Services Code Point (DSCP) value currently set on the socket object. @@ -533,6 +535,7 @@ class CSimpleSocket { m_socket = socket; }; + friend class CSimpleSocketSender; private: /// Generic function used to get the send/receive window size /// @return zero on failure else the number of bytes of the TCP window size if successful. @@ -575,7 +578,7 @@ class CSimpleSocket { struct sockaddr_in m_stClientSockaddr; /// client address struct sockaddr_in m_stMulticastGroup; /// multicast group to bind to struct linger m_stLinger; /// linger flag - CStatTimer m_timer; /// internal statistics. + thread_local static CStatTimer timer; /// internal statistics. #ifdef _WIN32 WSADATA m_hWSAData; /// Windows #endif diff --git a/rd-cpp/thirdparty/clsocket/src/SimpleSocketSender.cpp b/rd-cpp/thirdparty/clsocket/src/SimpleSocketSender.cpp new file mode 100644 index 000000000..9d40c100b --- /dev/null +++ b/rd-cpp/thirdparty/clsocket/src/SimpleSocketSender.cpp @@ -0,0 +1,28 @@ +// +// Created by Alexander.Bondarev on 25/12/2023. +// + +#include "SimpleSocketSender.h" + +int32_t CSimpleSocketSender::Send(const uint8_t *pBuf, const size_t bytesToSend) const +{ + int32_t bytesSent = 0; + if (m_socket->IsSocketValid()) + { + if ((bytesToSend > 0) && (pBuf != nullptr)) + { + //--------------------------------------------------------- + // Check error condition and attempt to resend if call + // was interrupted by a signal. + //--------------------------------------------------------- + CSimpleSocket::CSocketError socket_error; + do + { + bytesSent = static_cast(SEND(m_socket->m_socket, pBuf, bytesToSend, 0)); + socket_error = CSimpleSocket::TranslateLastSocketError(); + } while (socket_error == CSimpleSocket::SocketInterrupted); + } + } + + return bytesSent; +} diff --git a/rd-cpp/thirdparty/clsocket/src/SimpleSocketSender.h b/rd-cpp/thirdparty/clsocket/src/SimpleSocketSender.h new file mode 100644 index 000000000..43c83fe86 --- /dev/null +++ b/rd-cpp/thirdparty/clsocket/src/SimpleSocketSender.h @@ -0,0 +1,40 @@ +#ifndef SIMPLESOCKETSENDER_H +#define SIMPLESOCKETSENDER_H + +#include + +#include "SimpleSocket.h" + +class CSimpleSocketSender +{ + std::shared_ptr m_socket; + CSimpleSocket::CSocketError m_error; + +public: + explicit CSimpleSocketSender(const std::shared_ptr& socket) : m_error(CSimpleSocket::SocketSuccess) + { + if (socket->m_nSocketType == CSimpleSocket::CSocketType::SocketTypeTcp) + m_socket = socket; + else + m_socket = std::make_shared(); + } + + int32_t Send(const uint8_t* pBuf, size_t bytesToSend) const; + + bool IsSocketValid() const + { + return m_socket->IsSocketValid(); + } + + CSimpleSocket::CSocketError GetSocketError() const + { + return m_error; + } + + const char* DescribeError() const + { + return CSimpleSocket::DescribeError(m_error); + } +}; + +#endif // SIMPLESOCKETSENDER_H diff --git a/rd-cpp/thirdparty/clsocket/src/StatTimer.h b/rd-cpp/thirdparty/clsocket/src/StatTimer.h index 06a54a995..6f0f68d8a 100644 --- a/rd-cpp/thirdparty/clsocket/src/StatTimer.h +++ b/rd-cpp/thirdparty/clsocket/src/StatTimer.h @@ -62,6 +62,8 @@ #include "Host.h" +#include + #if defined(_WIN32) #define GET_CLOCK_COUNT(x) QueryPerformanceCounter((LARGE_INTEGER *)x) #else @@ -109,7 +111,7 @@ class CStatTimer { private: uint32_t CalcTotalUSec() const { - return static_cast((m_endTime.tv_sec - m_startTime.tv_sec) * MICROSECONDS_CONVERSION + (m_endTime.tv_usec - m_startTime.tv_usec)); + return static_cast((m_endTime.tv_sec - m_startTime.tv_sec) * MICROSECONDS_CONVERSION + (m_endTime.tv_usec - m_startTime.tv_usec)); }; @@ -118,4 +120,22 @@ class CStatTimer { struct timeval m_endTime; }; +struct CStatTimerCookie +{ + CStatTimer& targetTimer; + CStatTimer timer; + + explicit CStatTimerCookie(CStatTimer& timer) : targetTimer(timer) + { + timer.Initialize(); + timer.SetStartTime(); + } + + ~CStatTimerCookie() + { + timer.SetEndTime(); + targetTimer = timer; + } +}; + #endif // __CSTATTIMER_H__ From 286e259f8fc976d166f561daf6a72c0f29af1a4b Mon Sep 17 00:00:00 2001 From: Alexander Bondarev Date: Wed, 3 Jan 2024 10:57:22 +0200 Subject: [PATCH 2/3] Throw runtime_error in SimpleSocketSender instead of dummy socket for non-TCP protocol. --- rd-cpp/thirdparty/clsocket/src/SimpleSocketSender.h | 8 ++++---- rd-cpp/thirdparty/clsocket/src/StatTimer.h | 2 -- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/rd-cpp/thirdparty/clsocket/src/SimpleSocketSender.h b/rd-cpp/thirdparty/clsocket/src/SimpleSocketSender.h index 43c83fe86..a33ee5da7 100644 --- a/rd-cpp/thirdparty/clsocket/src/SimpleSocketSender.h +++ b/rd-cpp/thirdparty/clsocket/src/SimpleSocketSender.h @@ -2,6 +2,7 @@ #define SIMPLESOCKETSENDER_H #include +#include #include "SimpleSocket.h" @@ -13,10 +14,9 @@ class CSimpleSocketSender public: explicit CSimpleSocketSender(const std::shared_ptr& socket) : m_error(CSimpleSocket::SocketSuccess) { - if (socket->m_nSocketType == CSimpleSocket::CSocketType::SocketTypeTcp) - m_socket = socket; - else - m_socket = std::make_shared(); + if (socket->m_nSocketType != CSimpleSocket::CSocketType::SocketTypeTcp) + throw std::runtime_error("Only TCP sockets are supported"); + m_socket = socket; } int32_t Send(const uint8_t* pBuf, size_t bytesToSend) const; diff --git a/rd-cpp/thirdparty/clsocket/src/StatTimer.h b/rd-cpp/thirdparty/clsocket/src/StatTimer.h index 6f0f68d8a..78255ed86 100644 --- a/rd-cpp/thirdparty/clsocket/src/StatTimer.h +++ b/rd-cpp/thirdparty/clsocket/src/StatTimer.h @@ -62,8 +62,6 @@ #include "Host.h" -#include - #if defined(_WIN32) #define GET_CLOCK_COUNT(x) QueryPerformanceCounter((LARGE_INTEGER *)x) #else From 8cdf969c4b15f383d7a808a3d2e179de5be2a8e1 Mon Sep 17 00:00:00 2001 From: Alexander Bondarev Date: Wed, 3 Jan 2024 12:09:19 +0200 Subject: [PATCH 3/3] Fix viewable_set_advice test invalid memory access (lambda callback with logView2 was called after logView2 already destroyed). --- rd-cpp/CMakeLists.txt | 8 ++++++++ rd-cpp/src/rd_core_cpp/src/test/cases/ViewableSetTest.cpp | 7 +++---- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/rd-cpp/CMakeLists.txt b/rd-cpp/CMakeLists.txt index 703d4e4bc..8c26b90d8 100644 --- a/rd-cpp/CMakeLists.txt +++ b/rd-cpp/CMakeLists.txt @@ -32,6 +32,14 @@ if (CMAKE_COMPILER_IS_GNUCC OR CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID if (CMAKE_CXX_COMPILER_VERSION VERSION_LESS "7.0.0") message(FATAL_ERROR "Insufficient clang version") endif () + if (CMAKE_BUILD_TYPE MATCHES "Debug") + option(USE_ADDRESS_SANITIZER "Use address sanitizer to troubleshoot invalid allocations" ON) + else () + option(USE_ADDRESS_SANITIZER "Use address sanitizer to troubleshoot invalid allocations" OFF) + endif () + if (USE_ADDRESS_SANITIZER) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address -fno-omit-frame-pointer -g") + endif() endif () if (MINGW) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wa,-mbig-obj") diff --git a/rd-cpp/src/rd_core_cpp/src/test/cases/ViewableSetTest.cpp b/rd-cpp/src/rd_core_cpp/src/test/cases/ViewableSetTest.cpp index 4753380d3..9bbc4d776 100644 --- a/rd-cpp/src/rd_core_cpp/src/test/cases/ViewableSetTest.cpp +++ b/rd-cpp/src/rd_core_cpp/src/test/cases/ViewableSetTest.cpp @@ -6,11 +6,10 @@ using namespace rd; TEST(viewable_set, advise) { - std::unique_ptr> set = std::make_unique>(); - std::vector logAdvise; std::vector logView1; std::vector logView2; + std::unique_ptr> set = std::make_unique>(); LifetimeDefinition::use([&](Lifetime lt) { set->advise(lt, [&](AddRemove kind, int const& v) { logAdvise.push_back(kind == AddRemove::ADD ? v : -v); }); set->view(lt, [&](Lifetime inner, int const& v) { @@ -66,8 +65,8 @@ TEST(viewable_set, view) std::unique_ptr> set = std::make_unique>(); std::vector log; - auto x = LifetimeDefinition::use([&](Lifetime lifetime) { - set->view(lifetime, [&](Lifetime lt, int const& value) { + auto x = LifetimeDefinition::use([&](const Lifetime& lifetime) { + set->view(lifetime, [&](const Lifetime& lt, int const& value) { log.push_back("View " + std::to_string(value)); lt->add_action([&]() { log.push_back("UnView " + std::to_string(value)); }); });