Skip to content

Commit

Permalink
Update CUDA graph on scale change plus clear nodes/params (ggerganov#…
Browse files Browse the repository at this point in the history
…9550)

* Avoid using saved CUDA graph if scale changes and reset nodes/params on update

Fixes ggerganov#9451

* clear before resize
  • Loading branch information
agray3 authored Sep 21, 2024
1 parent e948a7d commit 41f4778
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 0 deletions.
9 changes: 9 additions & 0 deletions ggml/src/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2478,6 +2478,7 @@ static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_p
for (int i = 0; i < GGML_MAX_SRC; i++) {
graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
}
memcpy(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS);
}

static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
Expand Down Expand Up @@ -2509,6 +2510,12 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
return false;
}
}

if (node->op == GGML_OP_SCALE &&
memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
return false;
}

return true;
}

Expand Down Expand Up @@ -2720,7 +2727,9 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
// First call with null argument gets number of nodes in graph
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
// Subsequent call with non-null argument gets nodes
cuda_ctx->cuda_graph->nodes.clear();
cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
cuda_ctx->cuda_graph->params.clear();
cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
if (cuda_ctx->cuda_graph->num_nodes > 0) {
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
Expand Down
1 change: 1 addition & 0 deletions ggml/src/ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,7 @@ struct ggml_graph_node_properties {
int64_t ne[GGML_MAX_DIMS];
size_t nb[GGML_MAX_DIMS];
void * src_address[GGML_MAX_SRC];
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
};

struct ggml_cuda_graph {
Expand Down

0 comments on commit 41f4778

Please sign in to comment.