diff --git a/components/tcp_transport/transport_ws.c b/components/tcp_transport/transport_ws.c index 0211347fba03..1e100a626935 100644 --- a/components/tcp_transport/transport_ws.c +++ b/components/tcp_transport/transport_ws.c @@ -133,6 +133,34 @@ static int esp_transport_read_internal(transport_ws_t *ws, char *buffer, int len return to_read; } +static int esp_transport_read_exact_size(transport_ws_t *ws, char *buffer, int requested_len, int timeout_ms) +{ + int total_read = 0; + int len = requested_len; + + while (len > 0) { + int bytes_read = esp_transport_read_internal(ws, buffer, len, timeout_ms); + + if (bytes_read < 0) { + return bytes_read; // Return error from the underlying read + } + + if (bytes_read == 0) { + // If we read 0 bytes, we return an error, since reading exact number of bytes resulted in a timeout operation + ESP_LOGW(TAG, "Requested to read %d, actually read %d bytes", requested_len, total_read); + return -1; + } + + // Update buffer and remaining length + buffer += bytes_read; + len -= bytes_read; + total_read += bytes_read; + + ESP_LOGV(TAG, "Read fragment of %d bytes", bytes_read); + } + return total_read; +} + static char *trimwhitespace(char *str) { char *end; @@ -486,7 +514,7 @@ static int ws_read_header(esp_transport_handle_t t, char *buffer, int len, int t // Receive and process header first (based on header size) int header = 2; int mask_len = 4; - if ((rlen = esp_transport_read_internal(ws, data_ptr, header, timeout_ms)) <= 0) { + if ((rlen = esp_transport_read_exact_size(ws, data_ptr, header, timeout_ms)) <= 0) { ESP_LOGE(TAG, "Error read data"); return rlen; } @@ -500,7 +528,7 @@ static int ws_read_header(esp_transport_handle_t t, char *buffer, int len, int t ESP_LOGD(TAG, "Opcode: %d, mask: %d, len: %d", ws->frame_state.opcode, mask, payload_len); if (payload_len == 126) { // headerLen += 2; - if ((rlen = esp_transport_read_internal(ws, data_ptr, header, timeout_ms)) <= 0) { + if ((rlen = esp_transport_read_exact_size(ws, data_ptr, header, timeout_ms)) <= 0) { ESP_LOGE(TAG, "Error read data"); return rlen; } @@ -508,7 +536,7 @@ static int ws_read_header(esp_transport_handle_t t, char *buffer, int len, int t } else if (payload_len == 127) { // headerLen += 8; header = 8; - if ((rlen = esp_transport_read_internal(ws, data_ptr, header, timeout_ms)) <= 0) { + if ((rlen = esp_transport_read_exact_size(ws, data_ptr, header, timeout_ms)) <= 0) { ESP_LOGE(TAG, "Error read data"); return rlen; } @@ -523,7 +551,7 @@ static int ws_read_header(esp_transport_handle_t t, char *buffer, int len, int t if (mask) { // Read and store mask - if (payload_len != 0 && (rlen = esp_transport_read_internal(ws, buffer, mask_len, timeout_ms)) <= 0) { + if (payload_len != 0 && (rlen = esp_transport_read_exact_size(ws, buffer, mask_len, timeout_ms)) <= 0) { ESP_LOGE(TAG, "Error read data"); return rlen; }