From 21826514dfac9237a32cad6d1f2312298800ebf9 Mon Sep 17 00:00:00 2001 From: Alan Gray Date: Mon, 27 May 2024 05:48:18 -0700 Subject: [PATCH] Allow multiple copy function pointers for CUDA graph kernel param updates CUDA graphs require parameter updates to kernels associated with GGML_OP_CPY nodes. Previously the implementation only checked for a single CUDA kernel in such nodes, but this caused a bug in cases where 2 such kernels exist. This fixes the issue by using a vector to allow multiple function pointers to be stored and checked against. Fixes #7942 --- ggml-cuda.cu | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index b82167cbf7227..2a90ee55c69a0 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2510,9 +2510,9 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t bool use_cuda_graph = true; bool cuda_graph_update_required = false; - // pointer to CUDA cpy kernel, which is required to identify + // vector of pointers to CUDA cpy kernels, which are required to identify // kernel parameters which need updated in the graph for each token - void * ggml_cuda_cpy_fn_ptr = nullptr; + std::vector ggml_cuda_cpy_fn_ptrs; if (cuda_ctx->cuda_graph->graph == nullptr) { if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) { @@ -2588,9 +2588,10 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t if (node->op == GGML_OP_CPY) { // store the copy op parameter which changes with each token. cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data)); - if (ggml_cuda_cpy_fn_ptr == nullptr) { - // store a pointer to the copy op CUDA kernel to identify it later - ggml_cuda_cpy_fn_ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]); + // store a pointer to each copy op CUDA kernel to identify it later + void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]); + if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) { + ggml_cuda_cpy_fn_ptrs.push_back(ptr); } } @@ -2720,7 +2721,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured int k = 0; for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) { - if (cuda_ctx->cuda_graph->params[i].func == ggml_cuda_cpy_fn_ptr) { + if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) { char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++); cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr; CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));