Skip to content

Commit

Permalink
improve: connection management and error handling (#1986)
Browse files Browse the repository at this point in the history
### Connection Management Refinements and Error Handling Improvements

This PR implements several key changes, include:

#### Enhanced Error Handling
- More detailed logs for better debugging and error analysis.

#### Simplification and Refactoring
- Critical methods such as `Connection::close` and `Connection::accept`
have been simplified and refactored for clarity.

#### Precise Timeout Management
- `std::chrono::milliseconds` implemented for more accurate timeout
settings.

#### Improved Protocol Management
- Better handling of connection states and protocol management.

#### Robust Socket Closure
- Addition of `socket.cancel();` in `Connection::closeSocket()` ensures
cancellation of pending asynchronous operations, increasing safety
during socket closure.

#### Safe Message Sending
- State of the socket is now checked before sending messages, preventing
attempts to write on closed sockets.

#### Optimized Thread Safety
- Utilization of `std::scoped_lock` and `std::unique_lock` optimized for
safer thread handling.

#### Improved IP and Timeout Handling
- Enhancements in IP management and more efficient timeout handling.
  • Loading branch information
beats-dh authored Dec 9, 2023
1 parent 13ec0f2 commit 6b187fb
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 108 deletions.
207 changes: 101 additions & 106 deletions src/server/network/connection/connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,98 +30,93 @@ void ConnectionManager::closeAll() {
try {
std::error_code error;
connection->socket.shutdown(asio::ip::tcp::socket::shutdown_both, error);
if (error) {
g_logger().error("[ConnectionManager::closeAll] - Failed to close connection, system error code {}", error.message());
}
} catch (const std::system_error &systemError) {
g_logger().error("[ConnectionManager::closeAll] - Failed to close connection, system error code {}", systemError.what());
g_logger().error("[ConnectionManager::closeAll] - Exception caught: {}", systemError.what());
}
});

connections.clear();
}

// Connection
// Constructor
Connection::Connection(asio::io_service &initIoService, ConstServicePort_ptr initservicePort) :
readTimer(initIoService),
writeTimer(initIoService),
service_port(std::move(initservicePort)),
socket(initIoService) {
timeConnected = time(nullptr);
socket(initIoService),
timeConnected(std::chrono::system_clock::to_time_t(std::chrono::system_clock::now())) {
}
// Constructor end

void Connection::close(bool force) {
// any thread
ConnectionManager::getInstance().releaseConnection(shared_from_this());

std::scoped_lock<std::recursive_mutex> lockClass(connectionLock);
std::scoped_lock lock(connectionLock);
ip = 0;

if (connectionState == CONNECTION_STATE_CLOSED) {
return;
}
connectionState = CONNECTION_STATE_CLOSED;

if (protocol) {
g_dispatcher().addEvent(std::bind_front(&Protocol::release, protocol), "Protocol::release", 1000);
g_dispatcher().addEvent(std::bind_front(&Protocol::release, protocol), "Protocol::release", std::chrono::milliseconds(1000).count());
}

if (messageQueue.empty() || force) {
closeSocket();
} else {
// will be closed by the destructor or onWriteOperation
}
}

void Connection::closeSocket() {
if (socket.is_open()) {
try {
readTimer.cancel();
writeTimer.cancel();
std::error_code error;
socket.shutdown(asio::ip::tcp::socket::shutdown_both, error);
socket.close(error);
} catch (const std::system_error &e) {
g_logger().error("[Connection::closeSocket] - error: {}", e.what());
}
if (!socket.is_open()) {
return;
}

readTimer.cancel();
writeTimer.cancel();
socket.cancel();

std::error_code error;
socket.shutdown(asio::ip::tcp::socket::shutdown_both, error);
if (error) {
g_logger().error("[Connection::closeSocket] - Failed to shutdown socket: {}", error.message());
}

socket.close(error);
if (error) {
g_logger().error("[Connection::closeSocket] - Failed to close socket: {}", error.message());
}
}

void Connection::accept(Protocol_ptr protocolPtr) {
this->connectionState = CONNECTION_STATE_IDENTIFYING;
this->protocol = protocolPtr;
g_dispatcher().addEvent(std::bind_front(&Protocol::onConnect, protocolPtr), "Protocol::onConnect", 1000);
connectionState = CONNECTION_STATE_IDENTIFYING;
protocol = std::move(protocolPtr);
g_dispatcher().addEvent(std::bind_front(&Protocol::onConnect, protocol), "Protocol::onConnect", std::chrono::milliseconds(1000).count());

// Call second accept for not duplicate code
accept(false);
acceptInternal(false);
}

void Connection::accept(bool toggleParseHeader /* = true */) {
try {
readTimer.expires_from_now(std::chrono::seconds(CONNECTION_READ_TIMEOUT));
readTimer.async_wait(std::bind(&Connection::handleTimeout, std::weak_ptr<Connection>(shared_from_this()), std::placeholders::_1));
void Connection::acceptInternal(bool toggleParseHeader) {
readTimer.expires_from_now(std::chrono::seconds(CONNECTION_READ_TIMEOUT));
readTimer.async_wait(std::bind(&Connection::handleTimeout, std::weak_ptr<Connection>(shared_from_this()), std::placeholders::_1));

// If toggleParseHeader is true, execute the parseHeader, if not, execute parseProxyIdentification
if (toggleParseHeader) {
// Read size of the first packet
asio::async_read(socket, asio::buffer(msg.getBuffer(), HEADER_LENGTH), std::bind(&Connection::parseHeader, shared_from_this(), std::placeholders::_1));
} else {
// Read header bytes to identify if it is proxy identification
asio::async_read(socket, asio::buffer(msg.getBuffer(), HEADER_LENGTH), std::bind(&Connection::parseProxyIdentification, shared_from_this(), std::placeholders::_1));
}
} catch (const std::system_error &e) {
g_logger().error("[Connection::accept] - error: {}", e.what());
close(FORCE_CLOSE);
}
auto readCallback = toggleParseHeader ? std::bind(&Connection::parseHeader, shared_from_this(), std::placeholders::_1)
: std::bind(&Connection::parseProxyIdentification, shared_from_this(), std::placeholders::_1);
asio::async_read(socket, asio::buffer(msg.getBuffer(), HEADER_LENGTH), readCallback);
}

void Connection::parseProxyIdentification(const std::error_code &error) {
std::scoped_lock<std::recursive_mutex> lockClass(connectionLock);
std::scoped_lock lock(connectionLock);
readTimer.cancel();

if (error) {
if (error || connectionState == CONNECTION_STATE_CLOSED) {
if (error) {
g_logger().error("[Connection::parseProxyIdentification] - Read error: {}", error.message());
}
close(FORCE_CLOSE);
return;
} else if (connectionState == CONNECTION_STATE_CLOSED) {
return;
}

uint8_t* msgBuffer = msg.getBuffer();
Expand Down Expand Up @@ -163,23 +158,24 @@ void Connection::parseProxyIdentification(const std::error_code &error) {
}
}

accept(true);
acceptInternal(true);
}

void Connection::parseHeader(const std::error_code &error) {
std::scoped_lock<std::recursive_mutex> lockClass(connectionLock);
std::scoped_lock lock(connectionLock);
readTimer.cancel();

if (error) {
if (error || connectionState == CONNECTION_STATE_CLOSED) {
if (error != asio::error::operation_aborted) {
g_logger().error("[Connection::parseHeader] - Read error: {}", error.message());
}
close(FORCE_CLOSE);
return;
} else if (connectionState == CONNECTION_STATE_CLOSED) {
return;
}

uint32_t timePassed = std::max<uint32_t>(1, (time(nullptr) - timeConnected) + 1);
if ((++packetsSent / timePassed) > static_cast<uint32_t>(g_configManager().getNumber(MAX_PACKETS_PER_SECOND, __FUNCTION__))) {
g_logger().warn("{} disconnected for exceeding packet per second limit.", convertIPToString(getIP()));
g_logger().warn("[Connection::parseHeader] - {} disconnected for exceeding packet per second limit.", convertIPToString(getIP()));
close();
return;
}
Expand Down Expand Up @@ -209,14 +205,15 @@ void Connection::parseHeader(const std::error_code &error) {
}

void Connection::parsePacket(const std::error_code &error) {
std::scoped_lock<std::recursive_mutex> lockClass(connectionLock);
std::scoped_lock lock(connectionLock);
readTimer.cancel();

if (error) {
if (error || connectionState == CONNECTION_STATE_CLOSED) {
if (error) {
g_logger().error("[Connection::parsePacket] - Read error: {}", error.message());
}
close(FORCE_CLOSE);
return;
} else if (connectionState == CONNECTION_STATE_CLOSED) {
return;
}

bool skipReadingNextPacket = false;
Expand Down Expand Up @@ -275,104 +272,102 @@ void Connection::parsePacket(const std::error_code &error) {
}

void Connection::resumeWork() {
std::scoped_lock<std::recursive_mutex> lockClass(connectionLock);
readTimer.expires_from_now(std::chrono::seconds(CONNECTION_READ_TIMEOUT));
readTimer.async_wait(std::bind(&Connection::handleTimeout, std::weak_ptr<Connection>(shared_from_this()), std::placeholders::_1));

try {
// Wait to the next packet
asio::async_read(socket, asio::buffer(msg.getBuffer(), HEADER_LENGTH), std::bind(&Connection::parseHeader, shared_from_this(), std::placeholders::_1));
} catch (const std::system_error &e) {
g_logger().error("[Connection::resumeWork] - error: {}", e.what());
close(FORCE_CLOSE);
}
asio::async_read(socket, asio::buffer(msg.getBuffer(), HEADER_LENGTH), std::bind(&Connection::parseHeader, shared_from_this(), std::placeholders::_1));
}

void Connection::send(const OutputMessage_ptr &outputMessage) {
std::scoped_lock<std::recursive_mutex> lockClass(connectionLock);
std::scoped_lock lock(connectionLock);
if (connectionState == CONNECTION_STATE_CLOSED) {
return;
}

bool noPendingWrite = messageQueue.empty();
messageQueue.emplace_back(outputMessage);
if (noPendingWrite) {
// Make asio thread handle xtea encryption instead of dispatcher
try {
if (socket.is_open()) {
asio::post(socket.get_executor(), std::bind(&Connection::internalWorker, shared_from_this()));
} catch (const std::system_error &e) {
g_logger().error("[Connection::send] - error: {}", e.what());
messageQueue.clear();
} else {
g_logger().error("[Connection::send] - Socket is not open for writing.");
close(FORCE_CLOSE);
}
}
}

void Connection::internalWorker() {
std::unique_lock<std::recursive_mutex> lockClass(connectionLock);
if (!messageQueue.empty()) {
const OutputMessage_ptr &outputMessage = messageQueue.front();
lockClass.unlock();
protocol->onSendMessage(outputMessage);
lockClass.lock();
internalSend(outputMessage);
} else if (connectionState == CONNECTION_STATE_CLOSED) {
closeSocket();
std::unique_lock lock(connectionLock);
if (messageQueue.empty()) {
if (connectionState == CONNECTION_STATE_CLOSED) {
closeSocket();
}
return;
}

const auto &outputMessage = messageQueue.front();
lock.unlock();
protocol->onSendMessage(outputMessage);
lock.lock();

internalSend(outputMessage);
}

uint32_t Connection::getIP() {
if (ip != 1) {
return ip;
std::scoped_lock lock(connectionLock);

if (ip == 1) {
std::error_code error;
asio::ip::tcp::endpoint endpoint = socket.remote_endpoint(error);
if (error) {
g_logger().error("[Connection::getIP] - Failed to get remote endpoint: {}", error.message());
ip = 0;
} else {
ip = htonl(endpoint.address().to_v4().to_uint());
}
}

std::scoped_lock<std::recursive_mutex> lockClass(connectionLock);

// IP-address is expressed in network byte order
std::error_code error;
const asio::ip::tcp::endpoint endpoint = socket.remote_endpoint(error);
ip = error ? 0 : htonl(endpoint.address().to_v4().to_uint());
return ip;
}

void Connection::internalSend(const OutputMessage_ptr &outputMessage) {
try {
writeTimer.expires_from_now(std::chrono::seconds(CONNECTION_WRITE_TIMEOUT));
writeTimer.async_wait(std::bind(&Connection::handleTimeout, std::weak_ptr<Connection>(shared_from_this()), std::placeholders::_1));
writeTimer.expires_from_now(std::chrono::seconds(CONNECTION_WRITE_TIMEOUT));
writeTimer.async_wait(std::bind(&Connection::handleTimeout, std::weak_ptr<Connection>(shared_from_this()), std::placeholders::_1));

asio::async_write(socket, asio::buffer(outputMessage->getOutputBuffer(), outputMessage->getLength()), std::bind(&Connection::onWriteOperation, shared_from_this(), std::placeholders::_1));
} catch (const std::system_error &e) {
g_logger().error("[Connection::internalSend] - error: {}", e.what());
}
asio::async_write(socket, asio::buffer(outputMessage->getOutputBuffer(), outputMessage->getLength()), std::bind(&Connection::onWriteOperation, shared_from_this(), std::placeholders::_1));
}

void Connection::onWriteOperation(const std::error_code &error) {
std::unique_lock<std::recursive_mutex> lockClass(connectionLock);
std::unique_lock lock(connectionLock);
writeTimer.cancel();
messageQueue.pop_front();

if (error) {
g_logger().error("[Connection::onWriteOperation] - Write error: {}", error.message());
messageQueue.clear();
close(FORCE_CLOSE);
return;
}

messageQueue.pop_front();

if (!messageQueue.empty()) {
const OutputMessage_ptr &outputMessage = messageQueue.front();
lockClass.unlock();
const auto &outputMessage = messageQueue.front();
lock.unlock();
protocol->onSendMessage(outputMessage);
lockClass.lock();
lock.lock();
internalSend(outputMessage);
} else if (connectionState == CONNECTION_STATE_CLOSED) {
closeSocket();
}
}

void Connection::handleTimeout(ConnectionWeak_ptr connectionWeak, const std::error_code &error) {
if (error == asio::error::operation_aborted) {
// The timer has been manually cancelled
return;
}

if (auto connection = connectionWeak.lock()) {
connection->close(FORCE_CLOSE);
if (error) {
if (error != asio::error::operation_aborted) {
g_logger().warn("[Connection::handleTimeout] - Timeout or error: {}", error.message());
if (auto connection = connectionWeak.lock()) {
connection->close(FORCE_CLOSE);
}
}
}
}
3 changes: 2 additions & 1 deletion src/server/network/connection/connection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,10 @@ class Connection : public std::enable_shared_from_this<Connection> {
void close(bool force = false);
// Used by protocols that require server to send first
void accept(Protocol_ptr protocolPtr);
void accept(bool toggleParseHeader = true);
void acceptInternal(bool toggleParseHeader = true);

void resumeWork();

void send(const OutputMessage_ptr &outputMessage);

uint32_t getIP();
Expand Down
2 changes: 1 addition & 1 deletion src/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ void ServicePort::onAccept(Connection_ptr connection, const std::error_code &err
if (service->is_single_socket()) {
connection->accept(service->make_protocol(connection));
} else {
connection->accept();
connection->acceptInternal();
}
} else {
connection->close(FORCE_CLOSE);
Expand Down

0 comments on commit 6b187fb

Please sign in to comment.