diff --git a/ggml/src/ggml-cuda/concat.cu b/ggml/src/ggml-cuda/concat.cu index dac10ec36b0bd..2f42b8a9538e2 100644 --- a/ggml/src/ggml-cuda/concat.cu +++ b/ggml/src/ggml-cuda/concat.cu @@ -94,7 +94,9 @@ static void concat_f32_cuda(const float * x, const float * y, float * dst, int n } // non-contiguous kernel (slow) -static __global__ void concat_f32_non_cont( +template +static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) + concat_f32_non_cont( const char * src0, const char * src1, char * dst, @@ -121,22 +123,28 @@ static __global__ void concat_f32_non_cont( uint64_t nb0, uint64_t nb1, uint64_t nb2, - uint64_t nb3, - int32_t dim) { + uint64_t nb3){ + static_assert(dim >= 0 && dim <= 3); + const int64_t i3 = blockIdx.z; const int64_t i2 = blockIdx.y; const int64_t i1 = blockIdx.x; - int64_t o[4] = {0, 0, 0, 0}; - o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03)); - const float * x; - for (int i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) { + for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) { if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { x = (const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00); } else { - x = (const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10); + if constexpr (dim == 0) { + x = (const float *) (src1 + i3 * nb13 + i2 * nb12 + i1 * nb11 + (i0 - ne00) * nb10); + } else if constexpr (dim == 1) { + x = (const float *) (src1 + i3 * nb13 + i2 * nb12 + (i1 - ne01) * nb11 + i0 * nb10); + } else if constexpr (dim == 2) { + x = (const float *) (src1 + i3 * nb13 + (i2 - ne02) * nb12 + i1 * nb11 + i0 * nb10); + } else if constexpr (dim == 3) { + x = (const float *) (src1 + (i3 - ne03) * nb13 + i2 * nb12 + i1 * nb11 + i0 * nb10); + } } float * y = (float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); @@ -182,15 +190,32 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { } } else { dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]); - concat_f32_non_cont<<>>( - (const char *)src0->data, - (const char *)src1->data, - ( char *)dst->data, + auto launch_kernel = [&](auto dim) { + concat_f32_non_cont<<>>( + (const char *) src0->data, (const char *) src1->data, (char *) dst->data, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3], - dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], - dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], dim); + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3]); + }; + switch (dim) { + case 0: + launch_kernel(std::integral_constant{}); + break; + case 1: + launch_kernel(std::integral_constant{}); + break; + case 2: + launch_kernel(std::integral_constant{}); + break; + case 3: + launch_kernel(std::integral_constant{}); + break; + default: + GGML_ABORT("Invalid dim: %d", dim); + break; + } } }