Skip to content

Commit

Permalink
Improve TCP connect timeout behaviour (#106)
Browse files Browse the repository at this point in the history
  • Loading branch information
LourensVeen committed Oct 18, 2024
1 parent 4369032 commit 7087662
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 12 deletions.
54 changes: 50 additions & 4 deletions libmuscle/cpp/src/libmuscle/mcp/tcp_transport_client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,13 @@

#include <sys/types.h>
#include <sys/socket.h>
#include <errno.h>
#include <fcntl.h>
#include <netdb.h>
#include <netinet/tcp.h>
#include <poll.h>
#include <string.h>
#include <unistd.h>


namespace {
Expand All @@ -37,7 +42,10 @@ std::vector<std::string> split_location(std::string const & location) {
}


int connect(std::string const & address) {
int connect(std::string const & address, bool patient) {
int timeout = patient ? 3000 : 20000; // milliseconds
std::string errors;

std::size_t split = address.rfind(':');
std::string host = address.substr(0, split);
if (host.front() == '[') {
Expand Down Expand Up @@ -69,11 +77,37 @@ int connect(std::string const & address) {
int socket_fd = socket(p->ai_family, p->ai_socktype, p->ai_protocol);
if (socket_fd == -1) continue;

int flags = fcntl(socket_fd, F_GETFL, 0);
fcntl(socket_fd, F_SETFL, flags | O_NONBLOCK);

err_code = connect(socket_fd, p->ai_addr, p->ai_addrlen);
if (err_code == -1) {
if ((err_code == -1) && (errno != EINPROGRESS)) {
::close(socket_fd);
continue;
}

struct pollfd pollfds;
pollfds.fd = socket_fd;
pollfds.events = POLLOUT;
pollfds.revents = 0;
err_code = poll(&pollfds, 1, timeout);

if (err_code == 0) {
::close(socket_fd);
continue;
}

// check if connect() actually succeeded
socklen_t len = sizeof(int);
getsockopt(socket_fd, SOL_SOCKET, SO_ERROR, &err_code, &len);
if (err_code != 0) {
::close(socket_fd);
continue;
}

flags = fcntl(socket_fd, F_GETFL, 0);
fcntl(socket_fd, F_SETFL, flags & ~O_NONBLOCK);

return socket_fd;
}

Expand All @@ -98,14 +132,26 @@ TcpTransportClient::TcpTransportClient(std::string const & location)

for (auto const & address: addresses)
try {
socket_fd_ = connect(address);
socket_fd_ = connect(address, false);
break;
}
catch (std::runtime_error const & e) {
errors += std::string(e.what()) + "\n";
continue;
}

if (socket_fd_ == -1) {
// None of our quick connection attempts worked. Either there's a network
// problem, or the server is very busy. Let's try again with more patience.
for (auto const & address: addresses)
try {
socket_fd_ = connect(address, true);
break;
}
catch (std::runtime_error const & e) {
errors += std::string(e.what()) + "\n";
}
}

if (socket_fd_ == -1)
throw std::runtime_error(
"Could not connect to any server at locations " + location
Expand Down
36 changes: 28 additions & 8 deletions libmuscle/python/libmuscle/mcp/tcp_transport_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from errno import ENOTCONN
import logging
import socket
from typing import Optional, Tuple

Expand All @@ -7,6 +8,9 @@
from libmuscle.profiling import ProfileTimestamp


_logger = logging.getLogger(__name__)


class TcpTransportClient(TransportClient):
"""A client that connects to a TCPTransport server.
"""
Expand Down Expand Up @@ -36,20 +40,35 @@ def __init__(self, location: str) -> None:
sock: Optional[socket.SocketType] = None
for address in addresses:
try:
sock = self._connect(address)
sock = self._connect(address, False)
break
except RuntimeError:
pass

if sock is None:
# None of our quick connection attempts worked. Either there's a network
# problem, or the server is very busy. Let's try again with more patience.
_logger.warning(
f'Could not immediately connect to {location}, trying again with'
' more patience. Please report this if it happens frequently.')

for address in addresses:
try:
sock = self._connect(address, True)
break
except RuntimeError:
pass

if sock is None:
_logger.error(f'Failed to connect also on the second try to {location}')
raise RuntimeError('Could not connect to the server at location'
' {}'.format(location))
else:
if hasattr(socket, "TCP_NODELAY"):
sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
if hasattr(socket, "TCP_QUICKACK"):
sock.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
self._socket = sock

if hasattr(socket, "TCP_NODELAY"):
sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
if hasattr(socket, "TCP_QUICKACK"):
sock.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
self._socket = sock

def call(self, request: bytes) -> Tuple[bytes, ProfileData]:
"""Send a request to the server and receive the response.
Expand Down Expand Up @@ -88,7 +107,7 @@ def close(self) -> None:
if e.errno != ENOTCONN:
raise

def _connect(self, address: str) -> socket.SocketType:
def _connect(self, address: str, patient: bool) -> socket.SocketType:
loc_parts = address.rsplit(':', 1)
host = loc_parts[0]
if host.startswith('['):
Expand All @@ -108,6 +127,7 @@ def _connect(self, address: str) -> socket.SocketType:
continue

try:
sock.settimeout(20.0 if patient else 3.0) # seconds
sock.connect(sockaddr)
except Exception:
sock.close()
Expand Down

0 comments on commit 7087662

Please sign in to comment.