From 1e378947ba8081cf48178a3fbaa74ba07c92f78b Mon Sep 17 00:00:00 2001 From: Radoslav Gerganov Date: Thu, 15 Aug 2024 14:05:19 +0300 Subject: [PATCH] rpc : prevent crashes on invalid input Add more checks which prevent RPC server from crashing if invalid input is received from client --- ggml/src/ggml-rpc.cpp | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-rpc.cpp b/ggml/src/ggml-rpc.cpp index 7757615f5a24bd..557152b1807c3d 100644 --- a/ggml/src/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc.cpp @@ -1098,13 +1098,23 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre if (!recv_data(sockfd, &cmd, 1)) { break; } + if (cmd > GET_DEVICE_MEMORY) { + // fail fast if the command is invalid + fprintf(stderr, "Unknown command: %d\n", cmd); + break; + } std::vector input; std::vector output; uint64_t input_size; if (!recv_data(sockfd, &input_size, sizeof(input_size))) { break; } - input.resize(input_size); + try { + input.resize(input_size); + } catch (const std::bad_alloc & e) { + fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", input_size); + break; + } if (!recv_data(sockfd, input.data(), input_size)) { break; } @@ -1203,8 +1213,10 @@ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free return; } printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem); + fflush(stdout); rpc_serve_client(backend, client_socket->fd, free_mem, total_mem); printf("Client connection closed\n"); + fflush(stdout); } #ifdef _WIN32 WSACleanup();