Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve: change from memset/memcpy to modern cpp ranges #2989

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions src/creatures/combat/condition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -469,14 +469,13 @@ void ConditionAttributes::addCondition(std::shared_ptr<Creature> creature, const
endCondition(creature);

// Apply the new one
memcpy(skills, conditionAttrs->skills, sizeof(skills));
memcpy(skillsPercent, conditionAttrs->skillsPercent, sizeof(skillsPercent));
memcpy(stats, conditionAttrs->stats, sizeof(stats));
memcpy(statsPercent, conditionAttrs->statsPercent, sizeof(statsPercent));
memcpy(buffs, conditionAttrs->buffs, sizeof(buffs));
memcpy(buffsPercent, conditionAttrs->buffsPercent, sizeof(buffsPercent));

// Using std::array can only increment to the new instead of use memcpy
std::ranges::copy(std::span(conditionAttrs->skills), skills);
std::ranges::copy(std::span(conditionAttrs->skillsPercent), skillsPercent);
std::ranges::copy(std::span(conditionAttrs->stats), stats);
std::ranges::copy(std::span(conditionAttrs->statsPercent), statsPercent);
std::ranges::copy(std::span(conditionAttrs->buffs), buffs);
std::ranges::copy(std::span(conditionAttrs->buffsPercent), buffsPercent);

absorbs = conditionAttrs->absorbs;
absorbsPercent = conditionAttrs->absorbsPercent;
increases = conditionAttrs->increases;
Expand Down
4 changes: 2 additions & 2 deletions src/creatures/creature.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ void Creature::onCreatureMove(const std::shared_ptr<Creature> &creature, const s
if (oldPos.y > newPos.y) { // north
// shift y south
for (int32_t y = mapWalkHeight - 1; --y >= 0;) {
memcpy(localMapCache[y + 1], localMapCache[y], sizeof(localMapCache[y]));
std::ranges::copy(std::span(localMapCache[y]), localMapCache[y + 1]);
}

// update 0
Expand All @@ -525,7 +525,7 @@ void Creature::onCreatureMove(const std::shared_ptr<Creature> &creature, const s
} else if (oldPos.y < newPos.y) { // south
// shift y north
for (int32_t y = 0; y <= mapWalkHeight - 2; ++y) {
memcpy(localMapCache[y], localMapCache[y + 1], sizeof(localMapCache[y]));
std::ranges::copy(std::span(localMapCache[y + 1]), localMapCache[y]);
}

// update mapWalkHeight - 1
Expand Down
37 changes: 19 additions & 18 deletions src/security/rsa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,9 @@ void RSA::decrypt(char* msg) const {
// m = c^d mod n
mpz_powm(m, c, d, n);

size_t count = (mpz_sizeinbase(m, 2) + 7) / 8;
memset(msg, 0, 128 - count);
const size_t count = (mpz_sizeinbase(m, 2) + 7) / 8;
std::fill(msg, msg + (128 - count), 0);

mpz_export(msg + (128 - count), nullptr, 1, 1, 0, 0, m);

mpz_clear(c);
Expand Down Expand Up @@ -159,27 +160,27 @@ enum {
};

uint16_t RSA::decodeLength(char*&pos) const {
uint8_t buffer[4] = { 0 };
auto length = static_cast<uint16_t>(static_cast<uint8_t>(*pos++));
std::array<uint8_t, 4> buffer = { 0 };
uint16_t length = static_cast<uint8_t>(*pos++);
if (length & 0x80) {
length &= 0x7F;
if (length > 4) {
uint8_t numLengthBytes = length & 0x7F;
if (numLengthBytes > 4) {
g_logger().error("[RSA::loadPEM] - Invalid 'length'");
return 0;
}
switch (length) {
case 4:
buffer[3] = static_cast<uint8_t>(*pos++);
case 3:
buffer[2] = static_cast<uint8_t>(*pos++);
case 2:
buffer[1] = static_cast<uint8_t>(*pos++);
case 1:
buffer[0] = static_cast<uint8_t>(*pos++);
default:
break;
// Copy 'numLengthBytes' bytes from 'pos' into 'buffer', starting at the correct position
std::ranges::copy_n(pos, numLengthBytes, buffer.begin() + (4 - numLengthBytes));
pos += numLengthBytes;
// Reconstruct 'length' from 'buffer' (big-endian)
uint32_t tempLength = 0;
for (size_t i = 0; i < numLengthBytes; ++i) {
tempLength = (tempLength << 8) | buffer[4 - numLengthBytes + i];
}
if (tempLength > UINT16_MAX) {
g_logger().error("[RSA::loadPEM] - Length too large");
return 0;
}
std::memcpy(&length, buffer, sizeof(length));
length = static_cast<uint16_t>(tempLength);
}
return length;
}
Expand Down
117 changes: 68 additions & 49 deletions src/server/network/protocol/protocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,69 +125,88 @@ void Protocol::disconnect() const {
}
}

void Protocol::XTEA_encrypt(OutputMessage &outputMessage) const {
const uint32_t delta = 0x61C88647;

// The message must be a multiple of 8
size_t paddingBytes = outputMessage.getLength() & 7;
if (paddingBytes != 0) {
outputMessage.addPaddingBytes(8 - paddingBytes);
void Protocol::XTEA_transform(uint8_t* buffer, size_t messageLength, bool encrypt) const {
constexpr uint32_t delta = 0x61C88647;
size_t readPos = 0;
const std::array<uint32_t, 4> newKey = key;

std::array<std::array<uint32_t, 2>, 32> precachedControlSum;
uint32_t sum = encrypt ? 0 : 0xC6EF3720;

// Precompute control sums
if (encrypt) {
for (size_t i = 0; i < 32; ++i) {
precachedControlSum[i][0] = sum + newKey[sum & 3];
sum -= delta;
precachedControlSum[i][1] = sum + newKey[(sum >> 11) & 3];
}
} else {
for (size_t i = 0; i < 32; ++i) {
precachedControlSum[i][0] = sum + newKey[(sum >> 11) & 3];
sum += delta;
precachedControlSum[i][1] = sum + newKey[sum & 3];
}
}

uint8_t* buffer = outputMessage.getOutputBuffer();
auto messageLength = static_cast<int32_t>(outputMessage.getLength());
int32_t readPos = 0;
const std::array<uint32_t, 4> newKey = { key[0], key[1], key[2], key[3] };
// TODO: refactor this for not use c-style
uint32_t precachedControlSum[32][2];
uint32_t sum = 0;
for (int32_t i = 0; i < 32; ++i) {
precachedControlSum[i][0] = (sum + newKey[sum & 3]);
sum -= delta;
precachedControlSum[i][1] = (sum + newKey[(sum >> 11) & 3]);
}
while (readPos < messageLength) {
std::array<uint32_t, 2> vData = {};
memcpy(vData.data(), buffer + readPos, 8);
for (int32_t i = 0; i < 32; ++i) {
vData[0] += ((vData[1] << 4 ^ vData[1] >> 5) + vData[1]) ^ precachedControlSum[i][0];
vData[1] += ((vData[0] << 4 ^ vData[0] >> 5) + vData[0]) ^ precachedControlSum[i][1];
std::array<uint8_t, 8> tempBuffer;
std::ranges::copy_n(buffer + readPos, 8, tempBuffer.begin());

// Convert bytes to uint32_t considering little-endian order
std::array<uint8_t, 4> bytes0, bytes1;
std::copy_n(tempBuffer.begin(), 4, bytes0.begin());
std::copy_n(tempBuffer.begin() + 4, 4, bytes1.begin());

uint32_t vData0 = std::bit_cast<uint32_t>(bytes0);
uint32_t vData1 = std::bit_cast<uint32_t>(bytes1);

if (encrypt) {
for (size_t i = 0; i < 32; ++i) {
vData0 += ((vData1 << 4 ^ vData1 >> 5) + vData1) ^ precachedControlSum[i][0];
vData1 += ((vData0 << 4 ^ vData0 >> 5) + vData0) ^ precachedControlSum[i][1];
}
} else {
for (size_t i = 0; i < 32; ++i) {
vData1 -= ((vData0 << 4 ^ vData0 >> 5) + vData0) ^ precachedControlSum[i][0];
vData0 -= ((vData1 << 4 ^ vData1 >> 5) + vData1) ^ precachedControlSum[i][1];
}
}
memcpy(buffer + readPos, vData.data(), 8);

// Convert vData back to bytes
bytes0 = std::bit_cast<std::array<uint8_t, 4>>(vData0);
bytes1 = std::bit_cast<std::array<uint8_t, 4>>(vData1);

// Copy transformed bytes back to buffer
std::copy_n(bytes0.begin(), 4, buffer + readPos);
std::copy_n(bytes1.begin(), 4, buffer + readPos + 4);

readPos += 8;
}
}

void Protocol::XTEA_encrypt(OutputMessage &outputMessage) const {
// Ensure the message length is a multiple of 8
size_t paddingBytes = outputMessage.getLength() % 8;
if (paddingBytes != 0) {
outputMessage.addPaddingBytes(8 - paddingBytes);
}

uint8_t* buffer = outputMessage.getOutputBuffer();
size_t messageLength = outputMessage.getLength();

XTEA_transform(buffer, messageLength, true);
}

bool Protocol::XTEA_decrypt(NetworkMessage &msg) const {
uint16_t msgLength = msg.getLength() - (checksumMethod == CHECKSUM_METHOD_NONE ? 2 : 6);
if ((msgLength & 7) != 0) {
if ((msgLength % 8) != 0) {
return false;
}

const uint32_t delta = 0x61C88647;

uint8_t* buffer = msg.getBuffer() + msg.getBufferPosition();
auto messageLength = static_cast<int32_t>(msgLength);
int32_t readPos = 0;
const std::array<uint32_t, 4> newKey = { key[0], key[1], key[2], key[3] };
// TODO: refactor this for not use c-style
uint32_t precachedControlSum[32][2];
uint32_t sum = 0xC6EF3720;
for (int32_t i = 0; i < 32; ++i) {
precachedControlSum[i][0] = (sum + newKey[(sum >> 11) & 3]);
sum += delta;
precachedControlSum[i][1] = (sum + newKey[sum & 3]);
}
while (readPos < messageLength) {
std::array<uint32_t, 2> vData = {};
memcpy(vData.data(), buffer + readPos, 8);
for (int32_t i = 0; i < 32; ++i) {
vData[1] -= ((vData[0] << 4 ^ vData[0] >> 5) + vData[0]) ^ precachedControlSum[i][0];
vData[0] -= ((vData[1] << 4 ^ vData[1] >> 5) + vData[1]) ^ precachedControlSum[i][1];
}
memcpy(buffer + readPos, vData.data(), 8);
readPos += 8;
}
size_t messageLength = msgLength;

XTEA_transform(buffer, messageLength, false);

uint16_t innerLength = msg.get<uint16_t>();
if (std::cmp_greater(innerLength, msgLength - 2)) {
Expand Down
3 changes: 2 additions & 1 deletion src/server/network/protocol/protocol.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class Protocol : public std::enable_shared_from_this<Protocol> {
encryptionEnabled = true;
}
void setXTEAKey(const uint32_t* newKey) {
memcpy(this->key.data(), newKey, sizeof(*newKey) * 4);
std::ranges::copy(newKey, newKey + 4, this->key.begin());
}
void setChecksumMethod(ChecksumMethods_t method) {
checksumMethod = method;
Expand All @@ -85,6 +85,7 @@ class Protocol : public std::enable_shared_from_this<Protocol> {
std::array<char, NETWORKMESSAGE_MAXSIZE> buffer {};
};

void XTEA_transform(uint8_t* buffer, size_t messageLength, bool encrypt) const;
void XTEA_encrypt(OutputMessage &msg) const;
bool XTEA_decrypt(NetworkMessage &msg) const;
bool compression(OutputMessage &msg) const;
Expand Down