From 0dd48a69528ab696182be924502b03e54509afa8 Mon Sep 17 00:00:00 2001
From: lihan <1091770049@qq.com>
Date: Thu, 5 Dec 2024 10:05:32 +0800
Subject: [PATCH] faster ssm conv implementation

---
 ggml/src/ggml-cuda/ssm_conv.cu | 143 ++++++++++++++++++++++++---------
 1 file changed, 106 insertions(+), 37 deletions(-)

diff --git a/ggml/src/ggml-cuda/ssm_conv.cu b/ggml/src/ggml-cuda/ssm_conv.cu
index ca0089cd38bd3..205344d3faaac 100644
--- a/ggml/src/ggml-cuda/ssm_conv.cu
+++ b/ggml/src/ggml-cuda/ssm_conv.cu
@@ -1,45 +1,97 @@
 #include "ssm_conv.cuh"
 
-template <int block_size>
+template <size_t split_d_inner, size_t d_conv>
 static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1,
                                     const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1,
                                     float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2,
                                     const int nc, const int ncs, const int nr, const int n_t, const int n_s) {
-    const int tid = blockIdx.y;
-    const int i3  = blockIdx.x;
-    const int i2  = threadIdx.x;
-
-    const int ith = tid;
-    const int nth = WARP_SIZE;
-
-    // rows per thread
-    const int dr = (nr + nth - 1) / nth;
-
-    // row range for this thread
-    const int ir0 = dr * ith;
-    const int ir1 = min(ir0 + dr, nr);
-    const int ir  = ir1 - ir0;
-
-    // {d_conv - 1 + n_t, d_inner, n_seqs}
-    // sliding window
-    const float * s = (const float *) ((const char *) src0 + ir0 * src0_nb1 + i2 * src0_nb0 +
-                                       i3 * src0_nb2);  // {d_conv, d_inner, n_s}
-    const float * c = (const float *) ((const char *) src1 + ir0 * src1_nb1);  // {d_conv, d_inner}
-    float *       x = (float *) ((char *) dst + ir0 * dst_nb0 + i2 * dst_nb1 + i3 * dst_nb2);  // {d_inner, n_t, n_s}
-
-    // TODO: transpose the output for smaller strides for big batches?
-    // d_inner
-    for (int i1 = 0; i1 < ir; ++i1) {
-        // rowwise dot product
-        // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
+    const int tid  = threadIdx.x;
+    const int bidx = blockIdx.x;
+    const int bidy = blockIdx.y;
+
+    const float * x_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1);
+    const float * w_block = (const float *) ((char *) src1 + bidy * split_d_inner * src1_nb1);
+    float *       y_block = (float *) ((char *) dst + bidx * dst_nb2 + bidy * split_d_inner * dst_nb0);
+
+    const int stride_x = src0_nb1 / sizeof(float);
+    const int stride_w = src1_nb1 / sizeof(float);
+    const int stride_y = dst_nb1 / sizeof(float);
+
+    float x[d_conv] = { 0.0f };
+    float w[d_conv] = { 0.0f };
+
+#pragma unroll
+    for (int j = 0; j < d_conv; j++) {
+        w[j] = w_block[tid * stride_w + j];
+    }
+
+    for (int i = 0; i < n_t; i++) {
         float sumf = 0.0f;
 
-// d_conv
+        if (i == 0) {
+            for (int j = 0; j < d_conv; j++) {
+                x[j] = x_block[tid * stride_x + j];
+            }
+        } else {
+            x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
+        }
+
 #pragma unroll
-        for (int i0 = 0; i0 < nc; ++i0) {
-            sumf += s[i0 + i1 * ncs] * c[i0 + i1 * nc];
+        for (int j = 0; j < d_conv; j++) {
+            sumf += x[(i + j) % d_conv] * w[j];
+        }
+        y_block[i * stride_y + tid] = sumf;
+    }
+}
+
+template <size_t split_d_inner, size_t d_conv, size_t split_n_t>
+static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1,
+                                               const int src0_nb0, const int src0_nb1, const int src0_nb2,
+                                               const int src1_nb1, float * __restrict__ dst, const int dst_nb0,
+                                               const int dst_nb1, const int dst_nb2, const int nc, const int ncs,
+                                               const int nr, const int n_t, const int n_s) {
+    const int tid  = threadIdx.x;
+    const int bidx = blockIdx.x;
+    const int bidy = blockIdx.y;
+    const int bidz = blockIdx.z;
+
+    const float * x_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1 +
+                                             bidz * split_n_t * src0_nb0);
+    const float * w_block = (const float *) ((char *) src1 + bidy * split_d_inner * src1_nb1);
+    float *       y_block =
+        (float *) ((char *) dst + bidx * dst_nb2 + bidz * split_n_t * dst_nb1 + bidy * split_d_inner * dst_nb0);
+
+    const int stride_x = src0_nb1 / sizeof(float);
+    const int stride_w = src1_nb1 / sizeof(float);
+    const int stride_y = dst_nb1 / sizeof(float);
+
+    float x[d_conv] = { 0.0f };
+    float w[d_conv] = { 0.0f };
+
+#pragma unroll
+    for (int j = 0; j < d_conv; j++) {
+        w[j] = w_block[tid * stride_w + j];
+    }
+
+#pragma unroll
+    for (int i = 0; i < split_n_t; i++) {
+        if (bidz * split_n_t + i < n_t) {
+            float sumf = 0.0f;
+
+            if (i == 0) {
+                for (int j = 0; j < d_conv; j++) {
+                    x[j] = x_block[tid * stride_x + j];
+                }
+            } else {
+                x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
+            }
+
+#pragma unroll
+            for (int j = 0; j < d_conv; j++) {
+                sumf += x[(i + j) % d_conv] * w[j];
+            }
+            y_block[i * stride_y + tid] = sumf;
         }
-        x[i1] = sumf;
     }
 }
 
@@ -47,12 +99,29 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
                               const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1,
                               const int dst_nb2, const int nc, const int ncs, const int nr, const int n_t, const int n_s,
                               cudaStream_t stream) {
-    const dim3 block_dims(n_t, 1, 1);
-    // const int nblocks = n_s; // TODO
-    const dim3 grid_dims(n_s, WARP_SIZE, 1);
+    const int threads = 128;
+    GGML_ASSERT(nr % threads == 0);
 
-    ssm_conv_f32<WARP_SIZE><<<grid_dims, block_dims, 0, stream>>>(
-        src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, nc, ncs, nr, n_t, n_s);
+    if (n_t <= 32) {
+        const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
+        if (nc == 4) {
+            ssm_conv_f32<threads, 4><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
+                                                                     dst, dst_nb0, dst_nb1, dst_nb2, nc, ncs, nr, n_t,
+                                                                     n_s);
+        } else {
+            GGML_ABORT("Only support kernel size = 4 now.");
+        }
+    } else {
+        if (nc == 4) {
+            const int split_n_t = 32;
+            dim3      blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
+            ssm_conv_long_token_f32<threads, 4, split_n_t>
+                <<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0,
+                                                 dst_nb1, dst_nb2, nc, ncs, nr, n_t, n_s);
+        } else {
+            GGML_ABORT("Only support kernel size = 4 right now.");
+        }
+    }
 }
 
 void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
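
Reviewer note (not part of the patch): the gain comes from keeping the d_conv-wide sliding window in per-thread registers and overwriting exactly one element per token, i.e. the x[(i - 1) % d_conv] store and the x[(i + j) % d_conv] * w[j] reads above, instead of re-reading the whole window from global memory for every token. The host-side C++ sketch below models that ring-buffer indexing and checks it against the straightforward windowed dot product; all names in it (conv_naive, ring, etc.) are invented for illustration and do not appear in ggml.

// Host-side sketch of the ring-buffer sliding window used by the new kernels.
// Illustrative only; names are made up for this example.
#include <cstdio>

constexpr int d_conv = 4; // conv kernel width (the nc == 4 case handled by the patch)
constexpr int n_t    = 6; // number of tokens

// Reference: plain windowed dot product, as in the old kernel's inner loop.
static float conv_naive(const float * x, const float * w, int i) {
    float sum = 0.0f;
    for (int j = 0; j < d_conv; j++) {
        sum += x[i + j] * w[j];
    }
    return sum;
}

int main() {
    // one row of src0: d_conv - 1 leading states followed by n_t new tokens
    float x[d_conv - 1 + n_t];
    float w[d_conv];
    for (int k = 0; k < d_conv - 1 + n_t; k++) {
        x[k] = 0.5f * k;
    }
    for (int j = 0; j < d_conv; j++) {
        w[j] = 1.0f / (j + 1);
    }

    // Ring buffer: load the first window once, then overwrite exactly one
    // element per token (the x[(i - 1) % d_conv] = ... store in the kernels).
    float ring[d_conv];
    for (int j = 0; j < d_conv; j++) {
        ring[j] = x[j];
    }

    for (int i = 0; i < n_t; i++) {
        if (i > 0) {
            ring[(i - 1) % d_conv] = x[i + d_conv - 1];
        }
        float sum = 0.0f;
        for (int j = 0; j < d_conv; j++) {
            sum += ring[(i + j) % d_conv] * w[j]; // rotated read order
        }
        printf("token %d: ring = %f, naive = %f\n", i, sum, conv_naive(x, w, i));
    }
    return 0;
}

Both columns print identical values for every token, which is the invariant the rotated indexing in ssm_conv_f32 and ssm_conv_long_token_f32 relies on.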