Faster ssm scan #10558
Open
A3shTnT wants to merge 6 commits into ggerganov:master from A3shTnT:faster_ssm_scan
Changes from 3 commits
@@ -31,6 +31,8 @@
 #include "ggml-cuda/rope.cuh"
 #include "ggml-cuda/scale.cuh"
 #include "ggml-cuda/softmax.cuh"
+#include "ggml-cuda/ssm_conv.cuh"
+#include "ggml-cuda/ssm_scan.cuh"
 #include "ggml-cuda/sum.cuh"
 #include "ggml-cuda/sumrows.cuh"
 #include "ggml-cuda/tsembd.cuh"

@@ -2155,6 +2157,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SUM_ROWS:
             ggml_cuda_op_sum_rows(ctx, dst);
             break;
+        case GGML_OP_SSM_CONV:
+            ggml_cuda_op_ssm_conv(ctx, dst);
+            break;
+        case GGML_OP_SSM_SCAN:
+            ggml_cuda_op_ssm_scan(ctx, dst);
+            break;
         case GGML_OP_ARGSORT:
             ggml_cuda_op_argsort(ctx, dst);
             break;

@@ -2989,7 +2997,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_SIN:
         case GGML_OP_COS:
         case GGML_OP_CLAMP:
-            return true;
+        case GGML_OP_SSM_SCAN:
+        case GGML_OP_SSM_CONV:
+            return true;
         case GGML_OP_CONT:
             return op->src[0]->type != GGML_TYPE_BF16;
         case GGML_OP_DIAG_MASK_INF:

@@ -0,0 +1,82 @@
#include "ssm_conv.cuh" | ||
|
||
template <int block_size> | ||
static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1, | ||
const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, | ||
float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2, | ||
const int nc, const int ncs, const int nr, const int n_t, const int n_s) { | ||
const int tid = blockIdx.y; | ||
const int i3 = blockIdx.x; | ||
const int i2 = threadIdx.x; | ||
|
||
const int ith = tid; | ||
const int nth = WARP_SIZE; | ||
|
||
// rows per thread | ||
const int dr = (nr + nth - 1) / nth; | ||
|
||
// row range for this thread | ||
const int ir0 = dr * ith; | ||
const int ir1 = min(ir0 + dr, nr); | ||
const int ir = ir1 - ir0; | ||
|
||
// {d_conv - 1 + n_t, d_inner, n_seqs} | ||
// sliding window | ||
const float * s = (const float *) ((const char *) src0 + ir0 * src0_nb1 + i2 * src0_nb0 + | ||
i3 * src0_nb2); // {d_conv, d_inner, n_s} | ||
const float * c = (const float *) ((const char *) src1 + ir0 * src1_nb1); // {d_conv, d_inner} | ||
float * x = (float *) ((char *) dst + ir0 * dst_nb0 + i2 * dst_nb1 + i3 * dst_nb2); // {d_inner, n_t, n_s} | ||
|
||
// TODO: transpose the output for smaller strides for big batches? | ||
// d_inner | ||
for (int i1 = 0; i1 < ir; ++i1) { | ||
// rowwise dot product | ||
// NOTE: not using ggml_vec_dot_f32, because its sum is in double precision | ||
float sumf = 0.0f; | ||
|
||
// d_conv | ||
#pragma unroll | ||
for (int i0 = 0; i0 < nc; ++i0) { | ||
sumf += s[i0 + i1 * ncs] * c[i0 + i1 * nc]; | ||
} | ||
x[i1] = sumf; | ||
} | ||
} | ||
|
||
static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1, | ||
const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1, | ||
const int dst_nb2, const int nc, const int ncs, const int nr, const int n_t, | ||
const int n_s, cudaStream_t stream) { | ||
const dim3 block_dims(n_t, 1, 1); | ||
// const int nblocks = n_s; // TODO | ||
const dim3 grid_dims(n_s, WARP_SIZE, 1); | ||
|
||
ssm_conv_f32<WARP_SIZE><<<grid_dims, block_dims, 0, stream>>>( | ||
src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, nc, ncs, nr, n_t, n_s); | ||
} | ||
|
||
void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { | ||
const struct ggml_tensor * src0 = dst->src[0]; // conv_x | ||
const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight | ||
|
||
const int nc = src1->ne[0]; // d_conv | ||
const int ncs = src0->ne[0]; // d_conv - 1 + n_t | ||
const int nr = src0->ne[1]; // d_inner | ||
const int n_t = dst->ne[1]; // tokens per sequence | ||
const int n_s = dst->ne[2]; // number of sequences in the batch | ||
|
||
GGML_ASSERT(dst->ne[0] == nr); | ||
GGML_ASSERT(src0->nb[0] == sizeof(float)); | ||
GGML_ASSERT(src1->nb[0] == sizeof(float)); | ||
GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float)); | ||
|
||
const float * src0_d = (const float *) src0->data; | ||
const float * src1_d = (const float *) src1->data; | ||
float * dst_d = (float *) dst->data; | ||
cudaStream_t stream = ctx.stream(); | ||
|
||
GGML_ASSERT(src0->type == GGML_TYPE_F32); | ||
GGML_ASSERT(dst->type == GGML_TYPE_F32); | ||
ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, dst->nb[0], dst->nb[1], | ||
dst->nb[2], nc, ncs, nr, n_t, n_s, stream); | ||
} |
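For orientation, the kernel above computes, for each sequence, token and inner channel, a dot product between a length-d_conv sliding window of conv_x and that channel's conv1d weights. Below is a minimal host-side sketch of the same computation, assuming contiguous float layouts; the name ssm_conv_ref and the plain loop nest are illustrative, not code from this PR:

// Hypothetical CPU reference for GGML_OP_SSM_CONV, mirroring the loop nest of the kernel above:
// dst[i1, i2, i3] = sum_{i0 < nc} conv_x[i2 + i0, i1, i3] * weight[i0, i1]
// (indices are in floats here for simplicity; the real tensors use byte strides)
static void ssm_conv_ref(const float * conv_x,  // {ncs = nc - 1 + n_t, nr, n_s}
                         const float * weight,  // {nc, nr}
                         float * y,             // {nr, n_t, n_s}
                         int nc, int ncs, int nr, int n_t, int n_s) {
    for (int i3 = 0; i3 < n_s; ++i3) {             // sequence
        for (int i2 = 0; i2 < n_t; ++i2) {         // token == start of the sliding window
            for (int i1 = 0; i1 < nr; ++i1) {      // inner channel
                float sumf = 0.0f;
                for (int i0 = 0; i0 < nc; ++i0) {  // d_conv
                    sumf += conv_x[(i3 * nr + i1) * ncs + i2 + i0] * weight[i1 * nc + i0];
                }
                y[(i3 * n_t + i2) * nr + i1] = sumf;
            }
        }
    }
}

In the CUDA kernel this nest is parallelized with blockIdx.x over sequences, threadIdx.x over tokens, and blockIdx.y splitting the d_inner rows into WARP_SIZE chunks.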
@@ -0,0 +1,3 @@
#include "common.cuh" | ||
|
||
void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst); |
@@ -0,0 +1,156 @@
#include "ssm_scan.cuh" | ||
|
||
// #include <cuda_runtime.h> | ||
// static __device__ void global_to_shared(const float *src, float *dst) { | ||
// asm volatile("cp.async."); | ||
// } | ||
|
||
template <size_t splitD, size_t N> | ||
__global__ void __launch_bounds__(splitD, 2) | ||
ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2, | ||
const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5, | ||
const int src0_nb1, const int src0_nb2, const int src1_nb0, const int src1_nb1, const int src1_nb2, | ||
const int src1_nb3, const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1, | ||
const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2, | ||
float * __restrict__ dst, const int D, const int L, const int B) { | ||
const int bidx = blockIdx.x; // split along B | ||
const int bidy = blockIdx.y; // split along D | ||
const int tid = threadIdx.x; | ||
const int wid = tid / 32; | ||
const int wtid = tid % 32; | ||
|
||
extern __shared__ float smem[]; | ||
const int stride_sA = N + 1; | ||
const int stride_ss0 = N + 1; | ||
float * smem_A = smem; | ||
float * smem_s0 = smem_A + splitD * stride_sA; | ||
|
||
const float * s0_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * splitD * src0_nb1); | ||
const float * x_block = (const float *) ((char *) src1 + (bidx * src1_nb2) + bidy * splitD * sizeof(float)); | ||
const float * dt_block = (const float *) ((char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float)); | ||
const float * A_block = (const float *) ((char *) src3 + bidy * splitD * src3_nb1); | ||
const float * B_block = (const float *) ((char *) src4 + (bidx * src4_nb2)); | ||
const float * C_block = (const float *) ((char *) src5 + (bidx * src5_nb2)); | ||
float * y_block = (float *) ((char *) dst + (bidx * src1_nb2) + bidy * splitD * sizeof(float)); | ||
float * s_block = (float *) ((char *) dst + src1_nb3 + bidx * src0_nb2 + bidy * splitD * src0_nb1); | ||
|
||
const int stride_s0 = src0_nb1 / sizeof(float); | ||
const int stride_x = src1_nb1 / sizeof(float); | ||
const int stride_dt = src2_nb1 / sizeof(float); | ||
const int stride_A = src3_nb1 / sizeof(float); | ||
const int stride_B = src4_nb1 / sizeof(float); | ||
const int stride_C = src5_nb1 / sizeof(float); | ||
const int stride_s = stride_s0; | ||
const int stride_y = stride_x; | ||
|
||
// can N not be 16? for example 32? | ||
if (N == 16) { | ||
#pragma unroll | ||
for (int i = 0; i < splitD / 4; i += 2) { | ||
float value = A_block[(wid * warpSize + i) * stride_A + wtid]; | ||
// todo: bank conflict | ||
// I am always confused with how to use the swizzling method to solve | ||
// bank conflit. Hoping somebody can tell me. | ||
smem_A[(wid * warpSize + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value; | ||
} | ||
#pragma unroll | ||
for (int i = 0; i < splitD / 4; i += 2) { | ||
float value = s0_block[(wid * warpSize + i) * stride_s0 + wtid]; | ||
smem_s0[(wid * warpSize + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value; | ||
} | ||
} | ||
|
||
__syncthreads(); | ||
|
||
for (int i = 0; i < L; i++) { | ||
float dt_soft_plus = dt_block[i * stride_dt + wid * warpSize + wtid]; | ||
if (dt_soft_plus <= 20.0f) { | ||
dt_soft_plus = log1pf(exp(dt_soft_plus)); | ||
} | ||
float x_dt = x_block[i * stride_x + wid * warpSize + wtid] * dt_soft_plus; | ||
float sumf = 0.0f; | ||
#pragma unroll | ||
for (int j = 0; j < N; j++) { | ||
float state = (smem_s0[(wid * warpSize + wtid) * stride_ss0 + j] * | ||
expf(dt_soft_plus * smem_A[(wid * warpSize + wtid) * stride_sA + j])) + | ||
(B_block[i * stride_B + j] * x_dt); | ||
sumf += state * C_block[i * stride_C + j]; | ||
if (i == L - 1) { | ||
s_block[(wid * warpSize + wtid) * stride_s + j] = state; | ||
} else { | ||
smem_s0[(wid * warpSize + wtid) * stride_ss0 + j] = state; | ||
} | ||
} | ||
__syncthreads(); | ||
y_block[i * stride_y + wid * warpSize + wtid] = sumf; | ||
} | ||
} | ||
|
||
static void ssm_scan_f32_cuda(const float * src0, const float * src1, const float * src2, const float * src3, | ||
const float * src4, const float * src5, const int src0_nb1, const int src0_nb2, | ||
const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3, | ||
const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1, | ||
const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2, | ||
float * dst, const int N, const int D, const int L, const int B, cudaStream_t stream) { | ||
const int threads = 128; | ||
// todo: consider D cannot be divided,does this situation exist? | ||
GGML_ASSERT(D % threads == 0); | ||
const dim3 blocks(B, (D + threads - 1) / threads, 1); | ||
const int smem_size = (threads * (N + 1) * 2) * sizeof(float); | ||
if (N == 16) { | ||
ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>( | ||
src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb0, src1_nb1, src1_nb2, src1_nb3, src2_nb0, | ||
src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, D, L, B); | ||
} else { | ||
GGML_ABORT("doesn't support N!=16."); | ||
} | ||
} | ||
|
||
void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { | ||
const struct ggml_tensor * src0 = dst->src[0]; // s | ||
const struct ggml_tensor * src1 = dst->src[1]; // x | ||
const struct ggml_tensor * src2 = dst->src[2]; // dt | ||
const struct ggml_tensor * src3 = dst->src[3]; // A | ||
const struct ggml_tensor * src4 = dst->src[4]; // B | ||
const struct ggml_tensor * src5 = dst->src[5]; // C | ||
|
||
// const int64_t d_state = src0->ne[0]; | ||
// const int64_t d_inner = src0->ne[1]; | ||
// const int64_t l = src1->ne[1]; | ||
// const int64_t b = src0->ne[2]; | ||
|
||
const int64_t nc = src0->ne[0]; // d_state | ||
const int64_t nr = src0->ne[1]; // d_inner | ||
const int64_t n_t = src1->ne[1]; // number of tokens per sequence | ||
const int64_t n_s = src0->ne[2]; // number of sequences in the batch | ||
|
||
GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); | ||
GGML_ASSERT(src0->nb[0] == sizeof(float)); | ||
GGML_ASSERT(src1->nb[0] == sizeof(float)); | ||
GGML_ASSERT(src2->nb[0] == sizeof(float)); | ||
GGML_ASSERT(src3->nb[0] == sizeof(float)); | ||
GGML_ASSERT(src4->nb[0] == sizeof(float)); | ||
GGML_ASSERT(src5->nb[0] == sizeof(float)); | ||
// required for the dot product between s and C | ||
GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float)); | ||
// required for per-sequence offsets for states | ||
GGML_ASSERT(src0->nb[2] == src0->ne[0] * src0->ne[1] * sizeof(float)); | ||
// required to get correct offset for state destination (i.e. src1->nb[3]) | ||
GGML_ASSERT(src1->nb[3] == src1->ne[0] * src1->ne[1] * src1->ne[2] * sizeof(float)); | ||
|
||
const float * src0_d = (const float *) src0->data; | ||
const float * src1_d = (const float *) src1->data; | ||
const float * src2_d = (const float *) src2->data; | ||
const float * src3_d = (const float *) src3->data; | ||
const float * src4_d = (const float *) src4->data; | ||
const float * src5_d = (const float *) src5->data; | ||
float * dst_d = (float *) dst->data; | ||
cudaStream_t stream = ctx.stream(); | ||
|
||
GGML_ASSERT(src0->type == GGML_TYPE_F32); | ||
GGML_ASSERT(dst->type == GGML_TYPE_F32); | ||
|
||
ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src0->nb[1], src0->nb[2], src1->nb[0], | ||
src1->nb[1], src1->nb[2], src1->nb[3], src2->nb[0], src2->nb[1], src2->nb[2], src3->nb[1], | ||
src4->nb[1], src4->nb[2], src5->nb[1], src5->nb[2], dst_d, nc, nr, n_t, n_s, stream); | ||
} |
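For orientation, the scan kernel implements the selective state-space recurrence: per channel d and state index n, state = s[d,n] * exp(softplus(dt[d]) * A[d,n]) + B[n] * x[d] * softplus(dt[d]), with y[d] = sum over n of state * C[n], carrying the state across tokens and writing the final state out after the last token. A minimal single-sequence host-side sketch of the same recurrence follows; the name ssm_scan_ref and the contiguous float indexing are illustrative assumptions, not code from this PR:

#include <math.h>

// Hypothetical CPU reference for GGML_OP_SSM_SCAN over one sequence of L tokens.
// s: {N, D} state (updated in place), A: {N, D}, x/dt: {D} per token, B/C: {N} per token, y: {D} per token.
static void ssm_scan_ref(float * s, const float * x, const float * dt, const float * A,
                         const float * B, const float * C, float * y, int D, int N, int L) {
    for (int t = 0; t < L; ++t) {
        for (int d = 0; d < D; ++d) {
            float dt_sp = dt[t * D + d];
            if (dt_sp <= 20.0f) {
                dt_sp = log1pf(expf(dt_sp));  // softplus; skipped for large inputs where it is ~identity
            }
            const float x_dt = x[t * D + d] * dt_sp;
            float sumf = 0.0f;
            for (int n = 0; n < N; ++n) {
                const float state = s[d * N + n] * expf(dt_sp * A[d * N + n]) + B[t * N + n] * x_dt;
                sumf += state * C[t * N + n];
                s[d * N + n] = state;  // carried into the next token; the last value is the state output
            }
            y[t * D + d] = sumf;
        }
    }
}

The CUDA kernel keeps A and the running state for a splitD = 128 channel slice in shared memory, padding each row by one float to reduce bank conflicts, and appends the final state to dst after the src1-sized y output, which is why the wrapper asserts ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst).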
@@ -0,0 +1,3 @@
#include "common.cuh" | ||
|
||
void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst); |