Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve: connection management and error handling #1986

Merged
merged 2 commits into from
Dec 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 101 additions & 106 deletions src/server/network/connection/connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,98 +30,93 @@ void ConnectionManager::closeAll() {
try {
std::error_code error;
connection->socket.shutdown(asio::ip::tcp::socket::shutdown_both, error);
if (error) {
g_logger().error("[ConnectionManager::closeAll] - Failed to close connection, system error code {}", error.message());
}
} catch (const std::system_error &systemError) {
g_logger().error("[ConnectionManager::closeAll] - Failed to close connection, system error code {}", systemError.what());
g_logger().error("[ConnectionManager::closeAll] - Exception caught: {}", systemError.what());
}
});

connections.clear();
}

// Connection
// Constructor
Connection::Connection(asio::io_service &initIoService, ConstServicePort_ptr initservicePort) :
readTimer(initIoService),
writeTimer(initIoService),
service_port(std::move(initservicePort)),
socket(initIoService) {
timeConnected = time(nullptr);
socket(initIoService),
timeConnected(std::chrono::system_clock::to_time_t(std::chrono::system_clock::now())) {
}
// Constructor end

void Connection::close(bool force) {
// any thread
ConnectionManager::getInstance().releaseConnection(shared_from_this());

std::scoped_lock<std::recursive_mutex> lockClass(connectionLock);
std::scoped_lock lock(connectionLock);
ip = 0;

if (connectionState == CONNECTION_STATE_CLOSED) {
return;
}
connectionState = CONNECTION_STATE_CLOSED;

if (protocol) {
g_dispatcher().addEvent(std::bind_front(&Protocol::release, protocol), "Protocol::release", 1000);
g_dispatcher().addEvent(std::bind_front(&Protocol::release, protocol), "Protocol::release", std::chrono::milliseconds(1000).count());
}

if (messageQueue.empty() || force) {
closeSocket();
} else {
// will be closed by the destructor or onWriteOperation
}
}

void Connection::closeSocket() {
if (socket.is_open()) {
try {
readTimer.cancel();
writeTimer.cancel();
std::error_code error;
socket.shutdown(asio::ip::tcp::socket::shutdown_both, error);
socket.close(error);
} catch (const std::system_error &e) {
g_logger().error("[Connection::closeSocket] - error: {}", e.what());
}
if (!socket.is_open()) {
return;
}

readTimer.cancel();
writeTimer.cancel();
socket.cancel();

std::error_code error;
socket.shutdown(asio::ip::tcp::socket::shutdown_both, error);
if (error) {
g_logger().error("[Connection::closeSocket] - Failed to shutdown socket: {}", error.message());
}

socket.close(error);
if (error) {
g_logger().error("[Connection::closeSocket] - Failed to close socket: {}", error.message());
}
}

void Connection::accept(Protocol_ptr protocolPtr) {
this->connectionState = CONNECTION_STATE_IDENTIFYING;
this->protocol = protocolPtr;
g_dispatcher().addEvent(std::bind_front(&Protocol::onConnect, protocolPtr), "Protocol::onConnect", 1000);
connectionState = CONNECTION_STATE_IDENTIFYING;
protocol = std::move(protocolPtr);
g_dispatcher().addEvent(std::bind_front(&Protocol::onConnect, protocol), "Protocol::onConnect", std::chrono::milliseconds(1000).count());

// Call second accept for not duplicate code
accept(false);
acceptInternal(false);
}

void Connection::accept(bool toggleParseHeader /* = true */) {
try {
readTimer.expires_from_now(std::chrono::seconds(CONNECTION_READ_TIMEOUT));
readTimer.async_wait(std::bind(&Connection::handleTimeout, std::weak_ptr<Connection>(shared_from_this()), std::placeholders::_1));
void Connection::acceptInternal(bool toggleParseHeader) {
readTimer.expires_from_now(std::chrono::seconds(CONNECTION_READ_TIMEOUT));
readTimer.async_wait(std::bind(&Connection::handleTimeout, std::weak_ptr<Connection>(shared_from_this()), std::placeholders::_1));

// If toggleParseHeader is true, execute the parseHeader, if not, execute parseProxyIdentification
if (toggleParseHeader) {
// Read size of the first packet
asio::async_read(socket, asio::buffer(msg.getBuffer(), HEADER_LENGTH), std::bind(&Connection::parseHeader, shared_from_this(), std::placeholders::_1));
} else {
// Read header bytes to identify if it is proxy identification
asio::async_read(socket, asio::buffer(msg.getBuffer(), HEADER_LENGTH), std::bind(&Connection::parseProxyIdentification, shared_from_this(), std::placeholders::_1));
}
} catch (const std::system_error &e) {
g_logger().error("[Connection::accept] - error: {}", e.what());
close(FORCE_CLOSE);
}
auto readCallback = toggleParseHeader ? std::bind(&Connection::parseHeader, shared_from_this(), std::placeholders::_1)
: std::bind(&Connection::parseProxyIdentification, shared_from_this(), std::placeholders::_1);
asio::async_read(socket, asio::buffer(msg.getBuffer(), HEADER_LENGTH), readCallback);
}

void Connection::parseProxyIdentification(const std::error_code &error) {
std::scoped_lock<std::recursive_mutex> lockClass(connectionLock);
std::scoped_lock lock(connectionLock);
readTimer.cancel();

if (error) {
if (error || connectionState == CONNECTION_STATE_CLOSED) {
if (error) {
g_logger().error("[Connection::parseProxyIdentification] - Read error: {}", error.message());
}
close(FORCE_CLOSE);
return;
} else if (connectionState == CONNECTION_STATE_CLOSED) {
return;
}

uint8_t* msgBuffer = msg.getBuffer();
Expand Down Expand Up @@ -163,23 +158,24 @@ void Connection::parseProxyIdentification(const std::error_code &error) {
}
}

accept(true);
acceptInternal(true);
}

void Connection::parseHeader(const std::error_code &error) {
std::scoped_lock<std::recursive_mutex> lockClass(connectionLock);
std::scoped_lock lock(connectionLock);
readTimer.cancel();

if (error) {
if (error || connectionState == CONNECTION_STATE_CLOSED) {
if (error != asio::error::operation_aborted) {
g_logger().error("[Connection::parseHeader] - Read error: {}", error.message());
}
close(FORCE_CLOSE);
return;
} else if (connectionState == CONNECTION_STATE_CLOSED) {
return;
}

uint32_t timePassed = std::max<uint32_t>(1, (time(nullptr) - timeConnected) + 1);
if ((++packetsSent / timePassed) > static_cast<uint32_t>(g_configManager().getNumber(MAX_PACKETS_PER_SECOND, __FUNCTION__))) {
g_logger().warn("{} disconnected for exceeding packet per second limit.", convertIPToString(getIP()));
g_logger().warn("[Connection::parseHeader] - {} disconnected for exceeding packet per second limit.", convertIPToString(getIP()));
close();
return;
}
Expand Down Expand Up @@ -209,14 +205,15 @@ void Connection::parseHeader(const std::error_code &error) {
}

void Connection::parsePacket(const std::error_code &error) {
std::scoped_lock<std::recursive_mutex> lockClass(connectionLock);
std::scoped_lock lock(connectionLock);
readTimer.cancel();

if (error) {
if (error || connectionState == CONNECTION_STATE_CLOSED) {
if (error) {
g_logger().error("[Connection::parsePacket] - Read error: {}", error.message());
}
close(FORCE_CLOSE);
return;
} else if (connectionState == CONNECTION_STATE_CLOSED) {
return;
}

bool skipReadingNextPacket = false;
Expand Down Expand Up @@ -275,104 +272,102 @@ void Connection::parsePacket(const std::error_code &error) {
}

void Connection::resumeWork() {
std::scoped_lock<std::recursive_mutex> lockClass(connectionLock);
readTimer.expires_from_now(std::chrono::seconds(CONNECTION_READ_TIMEOUT));
readTimer.async_wait(std::bind(&Connection::handleTimeout, std::weak_ptr<Connection>(shared_from_this()), std::placeholders::_1));

try {
// Wait to the next packet
asio::async_read(socket, asio::buffer(msg.getBuffer(), HEADER_LENGTH), std::bind(&Connection::parseHeader, shared_from_this(), std::placeholders::_1));
} catch (const std::system_error &e) {
g_logger().error("[Connection::resumeWork] - error: {}", e.what());
close(FORCE_CLOSE);
}
asio::async_read(socket, asio::buffer(msg.getBuffer(), HEADER_LENGTH), std::bind(&Connection::parseHeader, shared_from_this(), std::placeholders::_1));
}

void Connection::send(const OutputMessage_ptr &outputMessage) {
std::scoped_lock<std::recursive_mutex> lockClass(connectionLock);
std::scoped_lock lock(connectionLock);
if (connectionState == CONNECTION_STATE_CLOSED) {
return;
}

bool noPendingWrite = messageQueue.empty();
messageQueue.emplace_back(outputMessage);
if (noPendingWrite) {
// Make asio thread handle xtea encryption instead of dispatcher
try {
if (socket.is_open()) {
asio::post(socket.get_executor(), std::bind(&Connection::internalWorker, shared_from_this()));
} catch (const std::system_error &e) {
g_logger().error("[Connection::send] - error: {}", e.what());
messageQueue.clear();
} else {
g_logger().error("[Connection::send] - Socket is not open for writing.");
close(FORCE_CLOSE);
}
}
}

void Connection::internalWorker() {
std::unique_lock<std::recursive_mutex> lockClass(connectionLock);
if (!messageQueue.empty()) {
const OutputMessage_ptr &outputMessage = messageQueue.front();
lockClass.unlock();
protocol->onSendMessage(outputMessage);
lockClass.lock();
internalSend(outputMessage);
} else if (connectionState == CONNECTION_STATE_CLOSED) {
closeSocket();
std::unique_lock lock(connectionLock);
if (messageQueue.empty()) {
if (connectionState == CONNECTION_STATE_CLOSED) {
closeSocket();
}
return;
}

const auto &outputMessage = messageQueue.front();
lock.unlock();
protocol->onSendMessage(outputMessage);
lock.lock();

internalSend(outputMessage);
}

uint32_t Connection::getIP() {
if (ip != 1) {
return ip;
std::scoped_lock lock(connectionLock);

if (ip == 1) {
std::error_code error;
asio::ip::tcp::endpoint endpoint = socket.remote_endpoint(error);
if (error) {
g_logger().error("[Connection::getIP] - Failed to get remote endpoint: {}", error.message());
ip = 0;
} else {
ip = htonl(endpoint.address().to_v4().to_uint());
}
}

std::scoped_lock<std::recursive_mutex> lockClass(connectionLock);

// IP-address is expressed in network byte order
std::error_code error;
const asio::ip::tcp::endpoint endpoint = socket.remote_endpoint(error);
ip = error ? 0 : htonl(endpoint.address().to_v4().to_uint());
return ip;
}

void Connection::internalSend(const OutputMessage_ptr &outputMessage) {
try {
writeTimer.expires_from_now(std::chrono::seconds(CONNECTION_WRITE_TIMEOUT));
writeTimer.async_wait(std::bind(&Connection::handleTimeout, std::weak_ptr<Connection>(shared_from_this()), std::placeholders::_1));
writeTimer.expires_from_now(std::chrono::seconds(CONNECTION_WRITE_TIMEOUT));
writeTimer.async_wait(std::bind(&Connection::handleTimeout, std::weak_ptr<Connection>(shared_from_this()), std::placeholders::_1));

asio::async_write(socket, asio::buffer(outputMessage->getOutputBuffer(), outputMessage->getLength()), std::bind(&Connection::onWriteOperation, shared_from_this(), std::placeholders::_1));
} catch (const std::system_error &e) {
g_logger().error("[Connection::internalSend] - error: {}", e.what());
}
asio::async_write(socket, asio::buffer(outputMessage->getOutputBuffer(), outputMessage->getLength()), std::bind(&Connection::onWriteOperation, shared_from_this(), std::placeholders::_1));
}

void Connection::onWriteOperation(const std::error_code &error) {
std::unique_lock<std::recursive_mutex> lockClass(connectionLock);
std::unique_lock lock(connectionLock);
writeTimer.cancel();
messageQueue.pop_front();

if (error) {
g_logger().error("[Connection::onWriteOperation] - Write error: {}", error.message());
messageQueue.clear();
close(FORCE_CLOSE);
return;
}

messageQueue.pop_front();

if (!messageQueue.empty()) {
const OutputMessage_ptr &outputMessage = messageQueue.front();
lockClass.unlock();
const auto &outputMessage = messageQueue.front();
lock.unlock();
protocol->onSendMessage(outputMessage);
lockClass.lock();
lock.lock();
internalSend(outputMessage);
} else if (connectionState == CONNECTION_STATE_CLOSED) {
closeSocket();
}
}

void Connection::handleTimeout(ConnectionWeak_ptr connectionWeak, const std::error_code &error) {
if (error == asio::error::operation_aborted) {
// The timer has been manually cancelled
return;
}

if (auto connection = connectionWeak.lock()) {
connection->close(FORCE_CLOSE);
if (error) {
if (error != asio::error::operation_aborted) {
g_logger().warn("[Connection::handleTimeout] - Timeout or error: {}", error.message());
if (auto connection = connectionWeak.lock()) {
connection->close(FORCE_CLOSE);
}
}
}
}
3 changes: 2 additions & 1 deletion src/server/network/connection/connection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,10 @@ class Connection : public std::enable_shared_from_this<Connection> {
void close(bool force = false);
// Used by protocols that require server to send first
void accept(Protocol_ptr protocolPtr);
void accept(bool toggleParseHeader = true);
void acceptInternal(bool toggleParseHeader = true);

void resumeWork();

void send(const OutputMessage_ptr &outputMessage);

uint32_t getIP();
Expand Down
2 changes: 1 addition & 1 deletion src/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ void ServicePort::onAccept(Connection_ptr connection, const std::error_code &err
if (service->is_single_socket()) {
connection->accept(service->make_protocol(connection));
} else {
connection->accept();
connection->acceptInternal();
}
} else {
connection->close(FORCE_CLOSE);
Expand Down
Loading