From a3d48e448a9933c14d047cab36b86d0356440e80 Mon Sep 17 00:00:00 2001 From: Alan Gray Date: Thu, 1 Aug 2024 03:35:05 -0700 Subject: [PATCH 1/2] Simplify and improve CUDA graphs through use of indirect copy pointers Previously there was complexity in the CUDA graphs implementation due frequently changing parameters to copy kernels associated with K and V cache pointers. This patch simplifies by using indirection to avoid such parameters frequently changing, avoiding the need for frequent graph updates. --- ggml/include/ggml-backend.h | 3 + ggml/src/ggml-cuda.cu | 83 ++++----------------- ggml/src/ggml-cuda/common.cuh | 4 - ggml/src/ggml-cuda/cpy.cu | 133 ++++++++++++++++++---------------- src/llama.cpp | 30 ++++++++ 5 files changed, 118 insertions(+), 135 deletions(-) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index 5f3f1e286990e..40a4196a51526 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -232,6 +232,9 @@ extern "C" { GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr); GGML_API void ggml_backend_view_init(struct ggml_tensor * tensor); + // Copy K and V cache pointers to backend + GGML_API void ggml_backend_copy_k_cache_ptrs(const char ** host_cache_ptrs, size_t size); + GGML_API void ggml_backend_copy_v_cache_ptrs(const char ** host_cache_ptrs, size_t size); #ifdef __cplusplus } diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 682c30d45bcf4..58c4eb3229145 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -2479,9 +2479,6 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t bool use_cuda_graph = true; bool cuda_graph_update_required = false; - // vector of pointers to CUDA cpy kernels, which are required to identify - // kernel parameters which need updated in the graph for each token - std::vector ggml_cuda_cpy_fn_ptrs; if (cuda_ctx->cuda_graph->graph == nullptr) { if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) { @@ -2527,7 +2524,6 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t } // Loop over nodes in GGML graph to obtain info needed for CUDA graph - cuda_ctx->cuda_graph->updated_kernel_arg.clear(); for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; @@ -2554,16 +2550,6 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t #endif } - if (node->op == GGML_OP_CPY) { - // store the copy op parameter which changes with each token. - cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data)); - // store a pointer to each copy op CUDA kernel to identify it later - void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]); - if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) { - ggml_cuda_cpy_fn_ptrs.push_back(ptr); - } - } - if (!use_cuda_graph) { break; } @@ -2653,64 +2639,23 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); } - // Perform update to graph (if required for this token), and change copy parameter (required for every token) - if (cuda_graph_update_required) { - // Extract nodes from graph - // First call with null argument gets number of nodes in graph - CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes)); - // Subsequent call with non-null argument gets nodes - cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes); - cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes); - if (cuda_ctx->cuda_graph->num_nodes > 0) { - CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes)); - - // Loop over nodes, and extract kernel parameters from each node - for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) { - cudaGraphNodeType node_type; - CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type)); - if (node_type == cudaGraphNodeTypeKernel) { - cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime - if (stat == cudaErrorInvalidDeviceFunction) { - // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node. - // We don't need to update blas nodes, so clear error and move on. - cudaGetLastError(); - } else { - GGML_ASSERT(stat == cudaSuccess); - } - } - } - } - } - - // One of the arguments to the copy kernel is updated for each token, hence we need to - // replace that argument with the updated value in the CUDA graph - if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured - int k = 0; - for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) { - if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) { - char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++); - cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr; - CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i])); - } - } - } - - // Update graph executable - cudaGraphExecUpdateResultInfo result_info; - cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info); - if (stat == cudaErrorGraphExecUpdateFailure) { + // Update graph executable + cudaGraphExecUpdateResultInfo result_info; + cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info); + if (stat == cudaErrorGraphExecUpdateFailure) { #ifndef NDEBUG - GGML_CUDA_LOG_ERROR("%s: CUDA graph update failed\n", __func__); + GGML_CUDA_LOG_ERROR("%s: CUDA graph update failed\n", __func__); #endif - // The pre-existing graph exec cannot be updated due to violated constraints - // so instead clear error and re-instantiate - cudaGetLastError(); - CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance)); - cuda_ctx->cuda_graph->instance = nullptr; - CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); - } else { - GGML_ASSERT(stat == cudaSuccess); + // The pre-existing graph exec cannot be updated due to violated constraints + // so instead clear error and re-instantiate + cudaGetLastError(); + CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance)); + cuda_ctx->cuda_graph->instance = nullptr; + CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); + } else { + GGML_ASSERT(stat == cudaSuccess); + } } // Launch graph CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream())); diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index eb39b6d23a6b3..41f49bb202348 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -583,15 +583,11 @@ struct ggml_cuda_graph { } cudaGraph_t graph = nullptr; cudaGraphExec_t instance = nullptr; - size_t num_nodes = 0; - std::vector nodes; - std::vector params; bool disable_due_to_gpu_arch = false; bool disable_due_to_too_many_updates = false; bool disable_due_to_failed_graph_capture = false; int number_consecutive_updates = 0; std::vector ggml_graph_properties; - std::vector updated_kernel_arg; #endif }; diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index aad34bfe5b32b..c8090f5025da2 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -31,16 +31,18 @@ static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) { } template -static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne, +static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13) { + const int nb12, const int nb13, char ** cdst_indirect, int layer_index) { const int64_t i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= ne) { return; } + char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[layer_index]: cdst_direct; + // determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor // then combine those indices with the corresponding byte offsets to get the total offsets const int64_t i03 = i/(ne00 * ne01 * ne02); @@ -263,16 +265,18 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) { } template -static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne, +static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13) { + const int nb12, const int nb13, char ** cdst_indirect, int layer_index) { const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk; if (i >= ne) { return; } + char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[layer_index]: cdst_direct; + const int i03 = i/(ne00 * ne01 * ne02); const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; @@ -288,110 +292,128 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne, cpy_blck(cx + x_offset, cdst + dst_offset); } +static char ** k_cache_ptrs; +static char ** v_cache_ptrs; + +static void ggml_backend_copy_cache_ptrs(char **& backend_cache_ptrs, const char ** host_cache_ptrs, size_t size) { + if(backend_cache_ptrs == nullptr) { + cudaMalloc(&backend_cache_ptrs, size*sizeof(char *)); + } + cudaMemcpy(backend_cache_ptrs, host_cache_ptrs, size*sizeof(char *), cudaMemcpyHostToDevice); +} + +void ggml_backend_copy_k_cache_ptrs(const char ** host_cache_ptrs, size_t size) { + ggml_backend_copy_cache_ptrs(k_cache_ptrs, host_cache_ptrs, size); +} + +void ggml_backend_copy_v_cache_ptrs(const char ** host_cache_ptrs, size_t size) { + ggml_backend_copy_cache_ptrs(v_cache_ptrs, host_cache_ptrs, size); +} + static void ggml_cpy_f16_f32_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int layer_index) { const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; cpy_f32_f16<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, layer_index); } static void ggml_cpy_f32_f32_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int layer_index) { const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; cpy_f32_f16<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, layer_index); } static void ggml_cpy_f32_f16_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int layer_index) { const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; cpy_f32_f16<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, layer_index); } static void ggml_cpy_f32_q8_0_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int layer_index) { GGML_ASSERT(ne % QK8_0 == 0); const int num_blocks = ne / QK8_0; cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, layer_index); } static void ggml_cpy_f32_q4_0_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int layer_index) { GGML_ASSERT(ne % QK4_0 == 0); const int num_blocks = ne / QK4_0; cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, layer_index); } static void ggml_cpy_f32_q4_1_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int layer_index) { GGML_ASSERT(ne % QK4_1 == 0); const int num_blocks = ne / QK4_1; cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, layer_index); } static void ggml_cpy_f32_q5_0_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int layer_index) { GGML_ASSERT(ne % QK5_0 == 0); const int num_blocks = ne / QK5_0; cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, layer_index); } static void ggml_cpy_f32_q5_1_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int layer_index) { GGML_ASSERT(ne % QK5_1 == 0); const int num_blocks = ne / QK5_1; cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, layer_index); } static void ggml_cpy_f32_iq4_nl_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int layer_index) { GGML_ASSERT(ne % QK4_NL == 0); const int num_blocks = ne / QK4_NL; cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, layer_index); } static void ggml_cpy_f16_f16_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int layer_index) { const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; cpy_f32_f16<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, layer_index); } void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) { @@ -428,26 +450,41 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg char * src0_ddc = (char *) src0->data; char * src1_ddc = (char *) src1->data; + // If this copy is associated with K or V caches, then use indirection of destination pointers to avoid + // the kernel params changing for each token and hence the need for frequent compute graph updates. + char ** dest_indirect = nullptr; + int layer_index=-1; + const char* k_prefix = "k_cache_view-"; + if (strncmp(src1->name, k_prefix, strlen(k_prefix)) == 0) { + dest_indirect = k_cache_ptrs; + layer_index = atoi(src1->name + strlen(k_prefix)); + } + const char* v_prefix = "v_cache_view-"; + if (strncmp(src1->name, v_prefix, strlen(v_prefix)) == 0) { + dest_indirect = v_cache_ptrs; + layer_index = atoi(src1->name + strlen(v_prefix)); + } + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { - ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_indirect, layer_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { - ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_indirect, layer_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) { - ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_indirect, layer_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) { - ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_indirect, layer_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) { - ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_indirect, layer_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) { - ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_indirect, layer_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) { - ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_indirect, layer_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) { - ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_indirect, layer_index); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { - ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_indirect, layer_index); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { - ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_indirect, layer_index); } else { fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type), ggml_type_name(src1->type)); @@ -459,31 +496,3 @@ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; ggml_cuda_cpy(ctx, src0, dst); } - -void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) { - if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { - return (void*) cpy_f32_f16; - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { - return (void*) cpy_f32_f16; - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) { - return (void*) cpy_f32_q; - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) { - return (void*) cpy_f32_q; - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) { - return (void*) cpy_f32_q; - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) { - return (void*) cpy_f32_q; - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) { - return (void*) cpy_f32_q; - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) { - return (void*) cpy_f32_q; - } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { - return (void*) cpy_f32_f16; - } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { - return (void*) cpy_f32_f16; - } else { - fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__, - ggml_type_name(src0->type), ggml_type_name(src1->type)); - GGML_ABORT("fatal error"); - } -} diff --git a/src/llama.cpp b/src/llama.cpp index 7f2f0003142a3..9dace659e1d35 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -14734,6 +14734,36 @@ static int llama_decode_internal( ggml_backend_sched_alloc_graph(lctx.sched, gf); +#ifdef GGML_USE_CUDA + // Copy K and V cache pointers to backend + + // Stage pointers for each layer in host vectors + std::vector k_cache_ptrs; + std::vector v_cache_ptrs; + const int64_t n_layer = model.hparams.n_layer; + const int64_t kv_head = kv_self.head; + for (int il = 0; il < n_layer; ++il) { + // K cache pointer for this layer + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + ggml_tensor * tmp_tensor = kv_self.k_l[il]; + size_t tmp_offset = (ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa))*kv_head; + k_cache_ptrs.push_back(static_cast(tmp_tensor->data) + tmp_offset); + // V cache pointer for this layer + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + tmp_tensor = kv_self.v_l[il]; + if (cparams.flash_attn) { + tmp_offset = (kv_head)*ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa); + } else { + tmp_offset = (kv_head)*ggml_element_size(kv_self.v_l[il]); + } + v_cache_ptrs.push_back(static_cast(tmp_tensor->data) + tmp_offset); + } + + // copy host vector data to backend + ggml_backend_copy_k_cache_ptrs(k_cache_ptrs.data(), k_cache_ptrs.size()); + ggml_backend_copy_v_cache_ptrs(v_cache_ptrs.data(), v_cache_ptrs.size()); +#endif + llama_set_inputs(lctx, u_batch); llama_graph_compute(lctx, gf, n_threads); From 38f4863a2453d2e24cd2cda797e18af8f0bb3781 Mon Sep 17 00:00:00 2001 From: Alan Gray Date: Wed, 14 Aug 2024 01:38:00 -0700 Subject: [PATCH 2/2] Abstract into GGML --- ggml/include/ggml-backend.h | 3 +-- ggml/src/ggml-cuda/cpy.cu | 26 ++++++++++++++++++++------ src/llama.cpp | 28 +--------------------------- 3 files changed, 22 insertions(+), 35 deletions(-) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index 40a4196a51526..95c4f584ec4a3 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -233,8 +233,7 @@ extern "C" { GGML_API void ggml_backend_view_init(struct ggml_tensor * tensor); // Copy K and V cache pointers to backend - GGML_API void ggml_backend_copy_k_cache_ptrs(const char ** host_cache_ptrs, size_t size); - GGML_API void ggml_backend_copy_v_cache_ptrs(const char ** host_cache_ptrs, size_t size); + GGML_API void ggml_backend_copy_kv_cache_ptrs(const int64_t n_layer, const int64_t kv_self_head, struct ggml_tensor ** kv_kl, struct ggml_tensor ** kv_vl, const int64_t n_embd_k_gqa, const int64_t n_embd_v_gqa, const bool flash_attn); #ifdef __cplusplus } diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index c8090f5025da2..613a98bf901e0 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -302,12 +302,26 @@ static void ggml_backend_copy_cache_ptrs(char **& backend_cache_ptrs, const char cudaMemcpy(backend_cache_ptrs, host_cache_ptrs, size*sizeof(char *), cudaMemcpyHostToDevice); } -void ggml_backend_copy_k_cache_ptrs(const char ** host_cache_ptrs, size_t size) { - ggml_backend_copy_cache_ptrs(k_cache_ptrs, host_cache_ptrs, size); -} - -void ggml_backend_copy_v_cache_ptrs(const char ** host_cache_ptrs, size_t size) { - ggml_backend_copy_cache_ptrs(v_cache_ptrs, host_cache_ptrs, size); +void ggml_backend_copy_kv_cache_ptrs(const int64_t n_layer, const int64_t kv_head, struct ggml_tensor ** kv_kl, struct ggml_tensor ** kv_vl, const int64_t n_embd_k_gqa,const int64_t n_embd_v_gqa, const bool flash_attn) { + + std::vector host_k_cache_ptrs; + std::vector host_v_cache_ptrs; + for (int il = 0; il < n_layer; ++il) { + // K cache pointer for this layer + ggml_tensor * tmp_tensor = kv_kl[il]; + size_t tmp_offset = (ggml_row_size(kv_kl[il]->type, n_embd_k_gqa))*kv_head; + host_k_cache_ptrs.push_back(static_cast(tmp_tensor->data) + tmp_offset); + // V cache pointer for this layer + tmp_tensor = kv_vl[il]; + if (flash_attn) { + tmp_offset = (kv_head)*ggml_row_size(kv_vl[il]->type, n_embd_v_gqa); + } else { + tmp_offset = (kv_head)*ggml_element_size(kv_vl[il]); + } + host_v_cache_ptrs.push_back(static_cast(tmp_tensor->data) + tmp_offset); + } + ggml_backend_copy_cache_ptrs(k_cache_ptrs, host_k_cache_ptrs.data(), host_k_cache_ptrs.size()); + ggml_backend_copy_cache_ptrs(v_cache_ptrs, host_v_cache_ptrs.data(), host_v_cache_ptrs.size()); } static void ggml_cpy_f16_f32_cuda( diff --git a/src/llama.cpp b/src/llama.cpp index 9dace659e1d35..d4875a84f320a 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -14735,33 +14735,7 @@ static int llama_decode_internal( ggml_backend_sched_alloc_graph(lctx.sched, gf); #ifdef GGML_USE_CUDA - // Copy K and V cache pointers to backend - - // Stage pointers for each layer in host vectors - std::vector k_cache_ptrs; - std::vector v_cache_ptrs; - const int64_t n_layer = model.hparams.n_layer; - const int64_t kv_head = kv_self.head; - for (int il = 0; il < n_layer; ++il) { - // K cache pointer for this layer - const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); - ggml_tensor * tmp_tensor = kv_self.k_l[il]; - size_t tmp_offset = (ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa))*kv_head; - k_cache_ptrs.push_back(static_cast(tmp_tensor->data) + tmp_offset); - // V cache pointer for this layer - const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); - tmp_tensor = kv_self.v_l[il]; - if (cparams.flash_attn) { - tmp_offset = (kv_head)*ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa); - } else { - tmp_offset = (kv_head)*ggml_element_size(kv_self.v_l[il]); - } - v_cache_ptrs.push_back(static_cast(tmp_tensor->data) + tmp_offset); - } - - // copy host vector data to backend - ggml_backend_copy_k_cache_ptrs(k_cache_ptrs.data(), k_cache_ptrs.size()); - ggml_backend_copy_v_cache_ptrs(v_cache_ptrs.data(), v_cache_ptrs.size()); + ggml_backend_copy_kv_cache_ptrs(model.hparams.n_layer, kv_self.head, kv_self.k_l.data(), kv_self.v_l.data(), hparams.n_embd_k_gqa(), hparams.n_embd_v_gqa(), cparams.flash_attn); #endif llama_set_inputs(lctx, u_batch);