Skip to content
This repository has been archived by the owner on Jan 21, 2025. It is now read-only.

Add safety checks around missed allocations for AsyncWebSocketMessageBuffer #6

Merged
merged 1 commit into from
Jan 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 45 additions & 11 deletions src/AsyncWebSocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,21 @@ AsyncWebSocketMessageBuffer::AsyncWebSocketMessageBuffer()
AsyncWebSocketMessageBuffer::AsyncWebSocketMessageBuffer(uint8_t* data, size_t size)
: _buffer(std::make_shared<std::vector<uint8_t>>(size))
{
std::memcpy(_buffer->data(), data, size);
if (_buffer->capacity() < size) {
_buffer.reset();
_buffer = std::make_shared<std::vector<uint8_t>>(0);
} else {
std::memcpy(_buffer->data(), data, size);
}
}

AsyncWebSocketMessageBuffer::AsyncWebSocketMessageBuffer(size_t size)
: _buffer(std::make_shared<std::vector<uint8_t>>(size))
{
if (_buffer->capacity() < size) {
_buffer.reset();
_buffer = std::make_shared<std::vector<uint8_t>>(0);
}
}

AsyncWebSocketMessageBuffer::~AsyncWebSocketMessageBuffer()
Expand Down Expand Up @@ -443,6 +452,9 @@ void AsyncWebSocketClient::_queueMessage(std::shared_ptr<std::vector<uint8_t>> b
if (!_client)
return;

if (buffer->size() == 0)
return;

{
AsyncWebLockGuard l(_lock);
if (_messageQueue.size() >= WS_MAX_QUEUED_MESSAGES)
Expand Down Expand Up @@ -687,8 +699,10 @@ std::shared_ptr<std::vector<uint8_t>> makeSharedBuffer(const uint8_t *message, s

void AsyncWebSocketClient::text(AsyncWebSocketMessageBuffer * buffer)
{
text(std::move(buffer->_buffer));
delete buffer;
if (buffer) {
text(std::move(buffer->_buffer));
delete buffer;
}
}

void AsyncWebSocketClient::text(std::shared_ptr<std::vector<uint8_t>> buffer)
Expand Down Expand Up @@ -739,8 +753,10 @@ void AsyncWebSocketClient::text(const __FlashStringHelper *data)

void AsyncWebSocketClient::binary(AsyncWebSocketMessageBuffer * buffer)
{
binary(std::move(buffer->_buffer));
delete buffer;
if (buffer) {
binary(std::move(buffer->_buffer));
delete buffer;
}
}

void AsyncWebSocketClient::binary(std::shared_ptr<std::vector<uint8_t>> buffer)
Expand Down Expand Up @@ -936,8 +952,10 @@ void AsyncWebSocket::text(uint32_t id, const __FlashStringHelper *data)

void AsyncWebSocket::textAll(AsyncWebSocketMessageBuffer * buffer)
{
textAll(std::move(buffer->_buffer));
delete buffer;
if (buffer) {
textAll(std::move(buffer->_buffer));
delete buffer;
}
}

void AsyncWebSocket::textAll(std::shared_ptr<std::vector<uint8_t>> buffer)
Expand Down Expand Up @@ -1014,8 +1032,10 @@ void AsyncWebSocket::binary(uint32_t id, const __FlashStringHelper *data, size_t

void AsyncWebSocket::binaryAll(AsyncWebSocketMessageBuffer * buffer)
{
binaryAll(std::move(buffer->_buffer));
delete buffer;
if (buffer) {
binaryAll(std::move(buffer->_buffer));
delete buffer;
}
}

void AsyncWebSocket::binaryAll(std::shared_ptr<std::vector<uint8_t>> buffer)
Expand Down Expand Up @@ -1200,12 +1220,26 @@ void AsyncWebSocket::handleRequest(AsyncWebServerRequest *request)

AsyncWebSocketMessageBuffer * AsyncWebSocket::makeBuffer(size_t size)
{
return new AsyncWebSocketMessageBuffer(size);
AsyncWebSocketMessageBuffer * buffer = new AsyncWebSocketMessageBuffer(size);
if (buffer->length() != size)
{
delete buffer;
return nullptr;
} else {
return buffer;
}
}

AsyncWebSocketMessageBuffer * AsyncWebSocket::makeBuffer(uint8_t * data, size_t size)
{
return new AsyncWebSocketMessageBuffer(data, size);
AsyncWebSocketMessageBuffer * buffer = new AsyncWebSocketMessageBuffer(data, size);
if (buffer->length() != size)
{
delete buffer;
return nullptr;
} else {
return buffer;
}
}

/*
Expand Down
Loading