From 7087662f89de913c9b102f8322546210a11f6691 Mon Sep 17 00:00:00 2001 From: Lourens Veen Date: Fri, 18 Oct 2024 10:25:49 +0200 Subject: [PATCH] Improve TCP connect timeout behaviour (#106) --- .../libmuscle/mcp/tcp_transport_client.cpp | 54 +++++++++++++++++-- .../libmuscle/mcp/tcp_transport_client.py | 36 ++++++++++--- 2 files changed, 78 insertions(+), 12 deletions(-) diff --git a/libmuscle/cpp/src/libmuscle/mcp/tcp_transport_client.cpp b/libmuscle/cpp/src/libmuscle/mcp/tcp_transport_client.cpp index c2551a3b..c3fb2c39 100644 --- a/libmuscle/cpp/src/libmuscle/mcp/tcp_transport_client.cpp +++ b/libmuscle/cpp/src/libmuscle/mcp/tcp_transport_client.cpp @@ -11,8 +11,13 @@ #include #include +#include +#include #include #include +#include +#include +#include namespace { @@ -37,7 +42,10 @@ std::vector split_location(std::string const & location) { } -int connect(std::string const & address) { +int connect(std::string const & address, bool patient) { + int timeout = patient ? 20000 : 3000; // milliseconds + std::string errors; + std::size_t split = address.rfind(':'); std::string host = address.substr(0, split); if (host.front() == '[') { @@ -69,11 +77,37 @@ int connect(std::string const & address) { int socket_fd = socket(p->ai_family, p->ai_socktype, p->ai_protocol); if (socket_fd == -1) continue; + int flags = fcntl(socket_fd, F_GETFL, 0); + fcntl(socket_fd, F_SETFL, flags | O_NONBLOCK); + err_code = connect(socket_fd, p->ai_addr, p->ai_addrlen); - if (err_code == -1) { + if ((err_code == -1) && (errno != EINPROGRESS)) { + ::close(socket_fd); + continue; + } + + struct pollfd pollfds; + pollfds.fd = socket_fd; + pollfds.events = POLLOUT; + pollfds.revents = 0; + err_code = poll(&pollfds, 1, timeout); + + if (err_code <= 0) { ::close(socket_fd); continue; } + + // check if connect() actually succeeded + socklen_t len = sizeof(int); + getsockopt(socket_fd, SOL_SOCKET, SO_ERROR, &err_code, &len); + if (err_code != 0) { + ::close(socket_fd); + continue; + } + + flags = 
fcntl(socket_fd, F_GETFL, 0); + fcntl(socket_fd, F_SETFL, flags & ~O_NONBLOCK); + return socket_fd; } @@ -98,14 +132,26 @@ TcpTransportClient::TcpTransportClient(std::string const & location) for (auto const & address: addresses) try { - socket_fd_ = connect(address); + socket_fd_ = connect(address, false); break; } catch (std::runtime_error const & e) { errors += std::string(e.what()) + "\n"; - continue; } + if (socket_fd_ == -1) { + // None of our quick connection attempts worked. Either there's a network + // problem, or the server is very busy. Let's try again with more patience. + for (auto const & address: addresses) + try { + socket_fd_ = connect(address, true); + break; + } + catch (std::runtime_error const & e) { + errors += std::string(e.what()) + "\n"; + } + } + if (socket_fd_ == -1) throw std::runtime_error( "Could not connect to any server at locations " + location diff --git a/libmuscle/python/libmuscle/mcp/tcp_transport_client.py b/libmuscle/python/libmuscle/mcp/tcp_transport_client.py index ed9536d2..85cb0bc5 100644 --- a/libmuscle/python/libmuscle/mcp/tcp_transport_client.py +++ b/libmuscle/python/libmuscle/mcp/tcp_transport_client.py @@ -1,4 +1,5 @@ from errno import ENOTCONN +import logging import socket from typing import Optional, Tuple @@ -7,6 +8,9 @@ from libmuscle.profiling import ProfileTimestamp +_logger = logging.getLogger(__name__) + + class TcpTransportClient(TransportClient): """A client that connects to a TCPTransport server. """ @@ -36,20 +40,35 @@ def __init__(self, location: str) -> None: sock: Optional[socket.SocketType] = None for address in addresses: try: - sock = self._connect(address) + sock = self._connect(address, False) break except RuntimeError: pass if sock is None: + # None of our quick connection attempts worked. Either there's a network + # problem, or the server is very busy. Let's try again with more patience. + _logger.warning( + f'Could not immediately connect to {location}, trying again with' + ' more patience. 
Please report this if it happens frequently.') + + for address in addresses: + try: + sock = self._connect(address, True) + break + except RuntimeError: + pass + + if sock is None: + _logger.error(f'Failed to connect also on the second try to {location}') raise RuntimeError('Could not connect to the server at location' ' {}'.format(location)) - else: - if hasattr(socket, "TCP_NODELAY"): - sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) - if hasattr(socket, "TCP_QUICKACK"): - sock.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1) - self._socket = sock + + if hasattr(socket, "TCP_NODELAY"): + sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) + if hasattr(socket, "TCP_QUICKACK"): + sock.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1) + self._socket = sock def call(self, request: bytes) -> Tuple[bytes, ProfileData]: """Send a request to the server and receive the response. @@ -88,7 +107,7 @@ def close(self) -> None: if e.errno != ENOTCONN: raise - def _connect(self, address: str) -> socket.SocketType: + def _connect(self, address: str, patient: bool) -> socket.SocketType: loc_parts = address.rsplit(':', 1) host = loc_parts[0] if host.startswith('['): @@ -108,6 +127,7 @@ def _connect(self, address: str) -> socket.SocketType: continue try: + sock.settimeout(20.0 if patient else 3.0) # seconds sock.connect(sockaddr) except Exception: sock.close()