Skip to content

Commit

Permalink
rpc : prevent crashes on invalid input
Browse files Browse the repository at this point in the history
Add more checks which prevent RPC server from crashing if invalid input
is received from client
  • Loading branch information
rgerganov committed Aug 15, 2024
1 parent 98a532d commit c5657d5
Showing 1 changed file with 47 additions and 34 deletions.
81 changes: 47 additions & 34 deletions ggml/src/ggml-rpc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,17 +82,18 @@ static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of

// RPC commands
enum rpc_cmd {
ALLOC_BUFFER = 0,
GET_ALIGNMENT,
GET_MAX_SIZE,
BUFFER_GET_BASE,
FREE_BUFFER,
BUFFER_CLEAR,
SET_TENSOR,
GET_TENSOR,
COPY_TENSOR,
GRAPH_COMPUTE,
GET_DEVICE_MEMORY,
RPC_CMD_ALLOC_BUFFER = 0,
RPC_CMD_GET_ALIGNMENT,
RPC_CMD_GET_MAX_SIZE,
RPC_CMD_BUFFER_GET_BASE,
RPC_CMD_FREE_BUFFER,
RPC_CMD_BUFFER_CLEAR,
RPC_CMD_SET_TENSOR,
RPC_CMD_GET_TENSOR,
RPC_CMD_COPY_TENSOR,
RPC_CMD_GRAPH_COMPUTE,
RPC_CMD_GET_DEVICE_MEMORY,
RPC_CMD_COUNT,
};

// RPC data structures
Expand Down Expand Up @@ -330,7 +331,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t
uint64_t remote_ptr = ctx->remote_ptr;
memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
std::vector<uint8_t> output;
bool status = send_rpc_cmd(ctx->sock, FREE_BUFFER, input, output);
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.empty());
delete ctx;
Expand All @@ -346,7 +347,7 @@ GGML_CALL static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t b
uint64_t remote_ptr = ctx->remote_ptr;
memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
std::vector<uint8_t> output;
bool status = send_rpc_cmd(ctx->sock, BUFFER_GET_BASE, input, output);
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.size() == sizeof(uint64_t));
// output serialization format: | base_ptr (8 bytes) |
Expand Down Expand Up @@ -405,7 +406,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t b
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
std::vector<uint8_t> output;
bool status = send_rpc_cmd(ctx->sock, SET_TENSOR, input, output);
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input, output);
GGML_ASSERT(status);
}

Expand All @@ -419,7 +420,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t b
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &size, sizeof(size));
std::vector<uint8_t> output;
bool status = send_rpc_cmd(ctx->sock, GET_TENSOR, input, output);
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.size() == size);
// output serialization format: | data (size bytes) |
Expand All @@ -444,7 +445,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t b
memcpy(input.data(), &rpc_src, sizeof(rpc_src));
memcpy(input.data() + sizeof(rpc_src), &rpc_dst, sizeof(rpc_dst));
std::vector<uint8_t> output;
bool status = send_rpc_cmd(ctx->sock, COPY_TENSOR, input, output);
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, input, output);
GGML_ASSERT(status);
// output serialization format: | result (1 byte) |
GGML_ASSERT(output.size() == 1);
Expand All @@ -459,7 +460,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer
memcpy(input.data(), &ctx->remote_ptr, sizeof(ctx->remote_ptr));
memcpy(input.data() + sizeof(ctx->remote_ptr), &value, sizeof(value));
std::vector<uint8_t> output;
bool status = send_rpc_cmd(ctx->sock, BUFFER_CLEAR, input, output);
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, input, output);
GGML_ASSERT(status);
}

Expand Down Expand Up @@ -488,7 +489,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
memcpy(input.data(), &size, sizeof(size));
std::vector<uint8_t> output;
auto sock = get_socket(buft_ctx->endpoint);
bool status = send_rpc_cmd(sock, ALLOC_BUFFER, input, output);
bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
// output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
Expand All @@ -511,7 +512,7 @@ static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
// input serialization format: | 0 bytes |
std::vector<uint8_t> input;
std::vector<uint8_t> output;
bool status = send_rpc_cmd(sock, GET_ALIGNMENT, input, output);
bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.size() == sizeof(uint64_t));
// output serialization format: | alignment (8 bytes) |
Expand All @@ -529,7 +530,7 @@ static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
// input serialization format: | 0 bytes |
std::vector<uint8_t> input;
std::vector<uint8_t> output;
bool status = send_rpc_cmd(sock, GET_MAX_SIZE, input, output);
bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.size() == sizeof(uint64_t));
// output serialization format: | max_size (8 bytes) |
Expand Down Expand Up @@ -622,7 +623,7 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t
serialize_graph(cgraph, input);
std::vector<uint8_t> output;
auto sock = get_socket(rpc_ctx->endpoint);
bool status = send_rpc_cmd(sock, GRAPH_COMPUTE, input, output);
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.size() == 1);
return (enum ggml_status)output[0];
Expand Down Expand Up @@ -719,7 +720,7 @@ static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * f
// input serialization format: | 0 bytes |
std::vector<uint8_t> input;
std::vector<uint8_t> output;
bool status = send_rpc_cmd(sock, GET_DEVICE_MEMORY, input, output);
bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
// output serialization format: | free (8 bytes) | total (8 bytes) |
Expand Down Expand Up @@ -1098,59 +1099,69 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
if (!recv_data(sockfd, &cmd, 1)) {
break;
}
if (cmd >= RPC_CMD_COUNT) {
// fail fast if the command is invalid
fprintf(stderr, "Unknown command: %d\n", cmd);
break;
}
std::vector<uint8_t> input;
std::vector<uint8_t> output;
uint64_t input_size;
if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
break;
}
input.resize(input_size);
try {
input.resize(input_size);
} catch (const std::bad_alloc & e) {
fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", input_size);
break;
}
if (!recv_data(sockfd, input.data(), input_size)) {
break;
}
bool ok = true;
switch (cmd) {
case ALLOC_BUFFER: {
case RPC_CMD_ALLOC_BUFFER: {
ok = server.alloc_buffer(input, output);
break;
}
case GET_ALIGNMENT: {
case RPC_CMD_GET_ALIGNMENT: {
server.get_alignment(output);
break;
}
case GET_MAX_SIZE: {
case RPC_CMD_GET_MAX_SIZE: {
server.get_max_size(output);
break;
}
case BUFFER_GET_BASE: {
case RPC_CMD_BUFFER_GET_BASE: {
ok = server.buffer_get_base(input, output);
break;
}
case FREE_BUFFER: {
case RPC_CMD_FREE_BUFFER: {
ok = server.free_buffer(input);
break;
}
case BUFFER_CLEAR: {
case RPC_CMD_BUFFER_CLEAR: {
ok = server.buffer_clear(input);
break;
}
case SET_TENSOR: {
case RPC_CMD_SET_TENSOR: {
ok = server.set_tensor(input);
break;
}
case GET_TENSOR: {
case RPC_CMD_GET_TENSOR: {
ok = server.get_tensor(input, output);
break;
}
case COPY_TENSOR: {
case RPC_CMD_COPY_TENSOR: {
ok = server.copy_tensor(input, output);
break;
}
case GRAPH_COMPUTE: {
case RPC_CMD_GRAPH_COMPUTE: {
ok = server.graph_compute(input, output);
break;
}
case GET_DEVICE_MEMORY: {
case RPC_CMD_GET_DEVICE_MEMORY: {
// output serialization format: | free (8 bytes) | total (8 bytes) |
output.resize(2*sizeof(uint64_t), 0);
memcpy(output.data(), &free_mem, sizeof(free_mem));
Expand Down Expand Up @@ -1203,8 +1214,10 @@ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free
return;
}
printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
fflush(stdout);
rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
printf("Client connection closed\n");
fflush(stdout);
}
#ifdef _WIN32
WSACleanup();
Expand Down

0 comments on commit c5657d5

Please sign in to comment.