diff --git a/src/server/network/protocol/protocol.cpp b/src/server/network/protocol/protocol.cpp index 4a9a79f9b35..fa5b88780af 100644 --- a/src/server/network/protocol/protocol.cpp +++ b/src/server/network/protocol/protocol.cpp @@ -125,25 +125,27 @@ void Protocol::disconnect() const { } } -void Protocol::XTEA_encrypt(OutputMessage &outputMessage) const { +void Protocol::XTEA_transform(uint8_t* buffer, size_t messageLength, bool encrypt) const { constexpr uint32_t delta = 0x61C88647; - - // Ensure the message length is a multiple of 8 - size_t paddingBytes = outputMessage.getLength() % 8; - if (paddingBytes != 0) { - outputMessage.addPaddingBytes(8 - paddingBytes); - } - - uint8_t* buffer = outputMessage.getOutputBuffer(); - const size_t messageLength = outputMessage.getLength(); size_t readPos = 0; const std::array newKey = key; + std::array, 32> precachedControlSum; - uint32_t sum = 0; - for (size_t i = 0; i < 32; ++i) { - precachedControlSum[i][0] = sum + newKey[sum & 3]; - sum -= delta; - precachedControlSum[i][1] = sum + newKey[(sum >> 11) & 3]; + uint32_t sum = encrypt ? 0 : 0xC6EF3720; + + // Precompute control sums + if (encrypt) { + for (size_t i = 0; i < 32; ++i) { + precachedControlSum[i][0] = sum + newKey[sum & 3]; + sum -= delta; + precachedControlSum[i][1] = sum + newKey[(sum >> 11) & 3]; + } + } else { + for (size_t i = 0; i < 32; ++i) { + precachedControlSum[i][0] = sum + newKey[(sum >> 11) & 3]; + sum += delta; + precachedControlSum[i][1] = sum + newKey[sum & 3]; + } } while (readPos < messageLength) { @@ -151,25 +153,30 @@ void Protocol::XTEA_encrypt(OutputMessage &outputMessage) const { std::ranges::copy_n(buffer + readPos, 8, tempBuffer.begin()); // Convert bytes to uint32_t considering little-endian order - std::array bytes0; - std::array bytes1; - + std::array bytes0, bytes1; std::copy_n(tempBuffer.begin(), 4, bytes0.begin()); std::copy_n(tempBuffer.begin() + 4, 4, bytes1.begin()); uint32_t vData0 = std::bit_cast(bytes0); uint32_t vData1 = std::bit_cast(bytes1); - for (size_t i = 0; i < 32; ++i) { - vData0 += ((vData1 << 4 ^ vData1 >> 5) + vData1) ^ precachedControlSum[i][0]; - vData1 += ((vData0 << 4 ^ vData0 >> 5) + vData0) ^ precachedControlSum[i][1]; + if (encrypt) { + for (size_t i = 0; i < 32; ++i) { + vData0 += ((vData1 << 4 ^ vData1 >> 5) + vData1) ^ precachedControlSum[i][0]; + vData1 += ((vData0 << 4 ^ vData0 >> 5) + vData0) ^ precachedControlSum[i][1]; + } + } else { + for (size_t i = 0; i < 32; ++i) { + vData1 -= ((vData0 << 4 ^ vData0 >> 5) + vData0) ^ precachedControlSum[i][0]; + vData0 -= ((vData1 << 4 ^ vData1 >> 5) + vData1) ^ precachedControlSum[i][1]; + } } // Convert vData back to bytes bytes0 = std::bit_cast>(vData0); bytes1 = std::bit_cast>(vData1); - // Copy encrypted bytes back to buffer + // Copy transformed bytes back to buffer std::copy_n(bytes0.begin(), 4, buffer + readPos); std::copy_n(bytes1.begin(), 4, buffer + readPos + 4); @@ -177,54 +184,29 @@ void Protocol::XTEA_encrypt(OutputMessage &outputMessage) const { } } -bool Protocol::XTEA_decrypt(NetworkMessage &msg) const { +void Protocol::XTEA_encrypt(OutputMessage &outputMessage) const { + // Ensure the message length is a multiple of 8 + size_t paddingBytes = outputMessage.getLength() % 8; + if (paddingBytes != 0) { + outputMessage.addPaddingBytes(8 - paddingBytes); + } + + uint8_t* buffer = outputMessage.getOutputBuffer(); + size_t messageLength = outputMessage.getLength(); + + XTEA_transform(buffer, messageLength, true); +} + +bool Protocol::XTEA_decrypt(NetworkMessage& msg) const { uint16_t msgLength = msg.getLength() - (checksumMethod == CHECKSUM_METHOD_NONE ? 2 : 6); if ((msgLength % 8) != 0) { return false; } - constexpr uint32_t delta = 0x61C88647; uint8_t* buffer = msg.getBuffer() + msg.getBufferPosition(); size_t messageLength = msgLength; - size_t readPos = 0; - const std::array newKey = key; // Assuming 'key' is a std::array - - std::array, 32> precachedControlSum; - uint32_t sum = 0xC6EF3720; - for (size_t i = 0; i < 32; ++i) { - precachedControlSum[i][0] = (sum + newKey[(sum >> 11) & 3]); - sum += delta; - precachedControlSum[i][1] = (sum + newKey[sum & 3]); - } - while (readPos < messageLength) { - std::array tempBuffer; - std::ranges::copy_n(buffer + readPos, 8, tempBuffer.begin()); - - std::array bytes0; - std::array bytes1; - - std::copy_n(tempBuffer.begin(), 4, bytes0.begin()); - std::copy_n(tempBuffer.begin() + 4, 4, bytes1.begin()); - - uint32_t vData0 = std::bit_cast(bytes0); - uint32_t vData1 = std::bit_cast(bytes1); - - for (size_t i = 0; i < 32; ++i) { - vData1 -= ((vData0 << 4 ^ vData0 >> 5) + vData0) ^ precachedControlSum[i][0]; - vData0 -= ((vData1 << 4 ^ vData1 >> 5) + vData1) ^ precachedControlSum[i][1]; - } - - // Convert vData back to bytes - bytes0 = std::bit_cast>(vData0); - bytes1 = std::bit_cast>(vData1); - - // Copy decrypted bytes back to buffer - std::copy_n(bytes0.begin(), 4, buffer + readPos); - std::copy_n(bytes1.begin(), 4, buffer + readPos + 4); - - readPos += 8; - } + XTEA_transform(buffer, messageLength, false); uint16_t innerLength = msg.get(); if (std::cmp_greater(innerLength, msgLength - 2)) { diff --git a/src/server/network/protocol/protocol.hpp b/src/server/network/protocol/protocol.hpp index ec7448782ae..86dc533ea35 100644 --- a/src/server/network/protocol/protocol.hpp +++ b/src/server/network/protocol/protocol.hpp @@ -85,6 +85,7 @@ class Protocol : public std::enable_shared_from_this { std::array buffer {}; }; + void XTEA_transform(uint8_t* buffer, size_t messageLength, bool encrypt) const; void XTEA_encrypt(OutputMessage &msg) const; bool XTEA_decrypt(NetworkMessage &msg) const; bool compression(OutputMessage &msg) const;