ggml: Mamba CUDA kernel performance improvement #9186 (Closed)
Commits (20):
f809568  Add initial/naive CUDA kernels for the GGML_OP_SSM_CONV and GGML_OP_S… (jploski)
cc365b0  Add GGML_OP_SSM_CONF, GGML_OP_SSM_SCAN to supported ops for CUDA back… (jploski)
25f9e65  Update CUDA ops ssm_conv and ssm_scan to match CPU implementation fro… (jploski)
64fbd32  Add patch to test cases provided by @compilade; test for ssm_conv fails (jploski)
12c913c  Fix backend test for ssm_conv CUDA op not working (jploski)
061e520  Update CUDA ops and tests to match implementation from commit 8fb57ac… (jploski)
fae826f  Fix failed assertions while running Falcon Mamba (jploski)
20d390b  10x performance improve 4 cuda ssm conv & scan (piDack)
8dd323b  Merge branch 'master' of github.com:ggerganov/llama.cpp into mfalcon_… (piDack)
b423a6d  fix ssm_scan numerical error & others update (piDack)
40f4787  Merge branch 'master' of github.com:ggerganov/llama.cpp into mfalcon_… (piDack)
1928967  resolve test-backend-ops conflicts (piDack)
21c16fa  fix trailing whitespace (piDack)
e53b14f  del debug ingo (piDack)
eec0e8c  memory access pattern (piDack)
0e682ce  add restrict (piDack)
5999d6d  fix conflicts (piDack)
316a049  add restrict for dst (piDack)
99f2ac1  Merge branch 'master' of github.com:ggerganov/llama.cpp into mfalcon_… (piDack)
63b6e73  recommit for ci pass (piDack)
New file: ssm_conv.cu (+100 lines)
#include "ssm_conv.cuh"

template <int block_size>
static __global__ void ssm_conv_f32(
        const float * __restrict__ src0, const float * __restrict__ src1,
        const int src0_nb0, const int src0_nb1, const int src0_nb2,
        const int src1_nb1,
        float * __restrict__ dst,
        const int dst_nb0, const int dst_nb1, const int dst_nb2,
        const int nc, const int ncs, const int nr, const int n_t, const int n_s) {

    const int tid = blockIdx.y;
    const int i3  = blockIdx.x;
    const int i2  = threadIdx.x;

    const int ith = tid;
    const int nth = WARP_SIZE;

    // rows per thread
    const int dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int ir0 = dr*ith;
    const int ir1 = min(ir0 + dr, nr);
    const int ir  = ir1 - ir0;

    // {d_conv - 1 + n_t, d_inner, n_seqs}
    // sliding window
    const float * s = (const float *) ((const char *) src0 + ir0*src0_nb1 + i2*src0_nb0 + i3*src0_nb2); // {d_conv, d_inner, n_s}
    const float * c = (const float *) ((const char *) src1 + ir0*src1_nb1); // {d_conv, d_inner}
    float * x = (float *) ((char *) dst + ir0*dst_nb0 + i2*dst_nb1 + i3*dst_nb2); // {d_inner, n_t, n_s}

    // TODO: transpose the output for smaller strides for big batches?
    // d_inner
    for (int i1 = 0; i1 < ir; ++i1) {
        // rowwise dot product
        // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
        float sumf = 0.0f;

        // d_conv
#pragma unroll
        for (int i0 = 0; i0 < nc; ++i0) {
            sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
        }
        x[i1] = sumf;
    }
}

static void ssm_conv_f32_cuda(
        const float * src0, const float * src1,
        const int src0_nb0, const int src0_nb1, const int src0_nb2,
        const int src1_nb1,
        float * dst,
        const int dst_nb0, const int dst_nb1, const int dst_nb2,
        const int nc, const int ncs, const int nr, const int n_t, const int n_s,
        cudaStream_t stream) {

    const dim3 block_dims(n_t, 1, 1);
    //const int nblocks = n_s; // TODO
    const dim3 grid_dims(n_s, WARP_SIZE, 1);

    ssm_conv_f32<WARP_SIZE><<<grid_dims, block_dims, 0, stream>>>(
            src0, src1,
            src0_nb0, src0_nb1, src0_nb2,
            src1_nb1,
            dst,
            dst_nb0, dst_nb1, dst_nb2,
            nc, ncs, nr, n_t, n_s);
}

void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const struct ggml_tensor * src0 = dst->src[0]; // conv_x
    const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight

    const int nc  = src1->ne[0]; // d_conv
    const int ncs = src0->ne[0]; // d_conv - 1 + n_t
    const int nr  = src0->ne[1]; // d_inner
    const int n_t = dst->ne[1];  // tokens per sequence
    const int n_s = dst->ne[2];  // number of sequences in the batch

    GGML_ASSERT( dst->ne[0] == nr);
    GGML_ASSERT(src0->nb[0] == sizeof(float));
    GGML_ASSERT(src1->nb[0] == sizeof(float));
    GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));

    const float * src0_d = (const float *)src0->data;
    const float * src1_d = (const float *)src1->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
    ssm_conv_f32_cuda(src0_d, src1_d,
            src0->nb[0], src0->nb[1], src0->nb[2],
            src1->nb[1],
            dst_d,
            dst->nb[0], dst->nb[1], dst->nb[2],
            nc, ncs, nr, n_t, n_s,
            stream);
}
New file: ssm_conv.cuh (+3 lines)
#include "common.cuh"

void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
New file: ssm_scan.cu (+144 lines)
#include "ssm_scan.cuh"

template <int block_size>
static __global__ void ssm_scan_f32(
        const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2, const float * __restrict__ src3,
        const float * __restrict__ src4, const float * __restrict__ src5,
        const int src0_nb1, const int src0_nb2,
        const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3,
        const int src2_nb0, const int src2_nb1, const int src2_nb2,
        const int src3_nb1,
        const int src4_nb1, const int src4_nb2,
        const int src5_nb1, const int src5_nb2,
        float * __restrict__ dst,
        const int nc, const int nr, const int n_t, const int n_s) {

    // const int row = blockIdx.x*blockDim.y + threadIdx.y;
    const int tid = threadIdx.x;
    const int i3  = threadIdx.y;

    const int ith = tid;
    const int nth = WARP_SIZE;

    // rows per thread
    const int dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int ir0 = dr*ith;
    const int ir1 = min(ir0 + dr, nr);
    const int ir  = ir1 - ir0;
    for (int i2 = 0; i2 < n_t; ++i2) {
        const float * s0 = (const float *) ((const char *) src0 + ir0*src0_nb1 + i3*src0_nb2); // {d_state, d_inner, n_s}
        const float * x  = (const float *) ((const char *) src1 + ir0*src1_nb0 + i2*src1_nb1 + i3*src1_nb2); // {d_inner, n_t, n_s}
        const float * dt = (const float *) ((const char *) src2 + ir0*src2_nb0 + i2*src2_nb1 + i3*src2_nb2); // {d_inner, n_t, n_s}
        const float * A  = (const float *) ((const char *) src3 + ir0*src3_nb1); // {d_state, d_inner}
        const float * B  = (const float *) ((const char *) src4 + i2*src4_nb1 + i3*src4_nb2); // {d_state, n_t, n_s}
        const float * C  = (const float *) ((const char *) src5 + i2*src5_nb1 + i3*src5_nb2); // {d_state, n_t, n_s}
        float * y = (float *) ((char *) dst + ir0*src1_nb0 + i2*src1_nb1 + i3*src1_nb2); // {d_inner, n_t, n_s}
        float * s = (float *) ((char *) dst + ir0*src0_nb1 + i3*src0_nb2 + src1_nb3);    // {d_state, d_inner, n_s}

        // use the output as the source for the next token-wise iterations
        if (i2 > 0) { s0 = s; }

        // d_inner
        for (int i1 = 0; i1 < ir; ++i1) {
            // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
            float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
            float x_dt = x[i1] * dt_soft_plus;
            float sumf = 0.0f;
            // d_state
#pragma unroll
            for (int i0 = 0; i0 < nc; ++i0) {
                int i = i0 + i1*nc;
                // state = prev_state * dA + dB * x
                float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
                // y = rowwise_dotprod(state, C)
                sumf += state * C[i0];
                s[i] = state;
            }
            y[i1] = sumf;
        }
    }
}

static void ssm_scan_f32_cuda(
        const float * src0, const float * src1, const float * src2, const float * src3,
        const float * src4, const float * src5,
        const int src0_nb1, const int src0_nb2,
        const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3,
        const int src2_nb0, const int src2_nb1, const int src2_nb2,
        const int src3_nb1,
        const int src4_nb1, const int src4_nb2,
        const int src5_nb1, const int src5_nb2,
        float * dst,
        const int nc, const int nr, const int n_t, const int n_s,
        cudaStream_t stream) {

    const dim3 block_dims(WARP_SIZE, n_s, 1);
    const int nblocks = 1; // TODO

    ssm_scan_f32<WARP_SIZE><<<nblocks, block_dims, 0, stream>>>(
            src0, src1, src2, src3,
            src4, src5,
            src0_nb1, src0_nb2,
            src1_nb0, src1_nb1, src1_nb2, src1_nb3,
            src2_nb0, src2_nb1, src2_nb2,
            src3_nb1,
            src4_nb1, src4_nb2,
            src5_nb1, src5_nb2,
            dst,
            nc, nr, n_t, n_s);
}

void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const struct ggml_tensor * src0 = dst->src[0]; // s
    const struct ggml_tensor * src1 = dst->src[1]; // x
    const struct ggml_tensor * src2 = dst->src[2]; // dt
    const struct ggml_tensor * src3 = dst->src[3]; // A
    const struct ggml_tensor * src4 = dst->src[4]; // B
    const struct ggml_tensor * src5 = dst->src[5]; // C

    const int64_t nc  = src0->ne[0]; // d_state
    const int64_t nr  = src0->ne[1]; // d_inner
    const int64_t n_t = src1->ne[1]; // number of tokens per sequence
    const int64_t n_s = src0->ne[2]; // number of sequences in the batch

    GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
    GGML_ASSERT(src0->nb[0] == sizeof(float));
    GGML_ASSERT(src1->nb[0] == sizeof(float));
    GGML_ASSERT(src2->nb[0] == sizeof(float));
    GGML_ASSERT(src3->nb[0] == sizeof(float));
    GGML_ASSERT(src4->nb[0] == sizeof(float));
    GGML_ASSERT(src5->nb[0] == sizeof(float));
    // required for the dot product between s and C
    GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
    // required for per-sequence offsets for states
    GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
    // required to get correct offset for state destination (i.e. src1->nb[3])
    GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));

    const float * src0_d = (const float *)src0->data;
    const float * src1_d = (const float *)src1->data;
    const float * src2_d = (const float *)src2->data;
    const float * src3_d = (const float *)src3->data;
    const float * src4_d = (const float *)src4->data;
    const float * src5_d = (const float *)src5->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    ssm_scan_f32_cuda(
            src0_d, src1_d, src2_d, src3_d,
            src4_d, src5_d,
            src0->nb[1], src0->nb[2],
            src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3],
            src2->nb[0], src2->nb[1], src2->nb[2],
            src3->nb[1],
            src4->nb[1], src4->nb[2],
            src5->nb[1], src5->nb[2],
            dst_d,
            nc, nr, n_t, n_s,
            stream);
}
New file: ssm_scan.cuh (+3 lines)
#include "common.cuh"

void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
Review comments
Unless I'm missing something, the memory access pattern here is bad, with each thread accessing completely different data. You will achieve orders of magnitude higher memory bandwidth by accessing the data in a coalesced manner.
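For illustration only, a minimal sketch of the difference (hypothetical kernels, not part of this PR): in the coalesced case consecutive lanes of a warp read consecutive floats, so a warp's 32 loads are served by one or two wide memory transactions; in the strided case neighbouring lanes touch addresses far apart and each load pays for its own transaction.

// coalesced: lane k of a warp reads src[base + k]
static __global__ void copy_coalesced(const float * __restrict__ src, float * __restrict__ dst, const int n) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i < n) {
        dst[i] = src[i];
    }
}

// uncoalesced: neighbouring lanes read addresses `stride` floats apart
static __global__ void copy_strided(const float * __restrict__ src, float * __restrict__ dst, const int n, const int stride) {
    const int t = blockIdx.x*blockDim.x + threadIdx.x;
    for (int j = 0; j < stride; ++j) {
        const int i = t*stride + j;
        if (i < n) {
            dst[i] = src[i];
        }
    }
}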
Ideally, this could use the fact that this is operating on a self-overlapping view, so advancing with i2 shifts the view by one column. In practice, for Mamba (and Mamba-2) models, nc is always 4, which might help with unrolling. To coalesce memory accesses (at least for large prompts), I guess each warp could operate on WARP_SIZE/nc steps at a time over i2, assuming WARP_SIZE is a multiple of 4 (is that always the case?), but this might need special handling of cases where i2 is not evenly divided by that. I don't have much experience with CUDA (yet), so this might be misleading, but hopefully still helps.
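A rough, untested sketch of that idea for the inner dot product of ssm_conv_f32, assuming nc is a compile-time power of two (4 for Mamba), WARP_SIZE is a multiple of nc, and the whole warp is active; the helper name conv_dot_group is illustrative, not code from this PR, and tail handling for n_t not divisible by WARP_SIZE/nc is omitted.

// Each group of nc consecutive lanes cooperates on one (row, token) dot product:
// lane `tap` of a group loads window[tap] and weight[tap], so the group's nc loads
// hit contiguous floats, and one warp covers WARP_SIZE/nc tokens per step over i2.
template <int nc> // d_conv
static __device__ __forceinline__ float conv_dot_group(
        const float * __restrict__ window,  // start of the sliding window for this (row, token)
        const float * __restrict__ weight,  // the nc conv weights of this row
        const int     lane) {               // lane index within the warp
    const int tap = lane % nc;
    float v = window[tap] * weight[tap];
    // sum the nc partial products within each nc-lane group
#pragma unroll
    for (int offset = nc/2; offset > 0; offset >>= 1) {
        v += __shfl_down_sync(0xffffffff, v, offset, nc);
    }
    return v; // the full sum ends up in the first lane of each group
}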
Thanks. The current memory access pattern is better suited to CPUs; I'm thinking about ways to address this issue.
Good idea, I am currently testing along the lines of your suggestion.
I've found a simple implementation for ssm_conv that coalesces memory accesses and gives about a 2x performance improvement, and I've already pushed it to this PR! For ssm_scan, I'm still at a loss for optimization ideas.
Use one warp per iteration of nc, with each thread calculating a partial sum, then combine the partial sums via warp_reduce_sum and have the first thread in the warp write back the result.
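A rough, untested sketch of that suggestion applied to the inner d_state loop of ssm_scan_f32, assuming the kernel is reorganized so that one full warp works on a single row i1 (rather than each thread owning its own block of rows); variable names follow the kernel above.

const int lane = threadIdx.x % WARP_SIZE;

float sumf = 0.0f;
// each lane handles a slice of the nc (d_state) iterations
for (int i0 = lane; i0 < nc; i0 += WARP_SIZE) {
    const int i = i0 + i1*nc;
    // state = prev_state * dA + dB * x
    const float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
    sumf += state * C[i0];
    s[i] = state;
}
// combine the per-lane partial sums across the warp
sumf = warp_reduce_sum(sumf);
if (lane == 0) {
    y[i1] = sumf;
}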
Ha ha, I also once thought about using the warp-level API to calculate the sum. However, after taking a closer look, I realized that these additions accumulate within a single thread's registers, not across the threads of a block, so warp_reduce_sum might not be applicable here. Thanks for your review; if you have any other suggestions or better ideas, please feel free to share them. Your input is greatly appreciated.
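For reference, a minimal sketch of what warp_reduce_sum in common.cuh does (roughly): a shuffle-based butterfly reduction that adds together values held in the registers of all 32 lanes of a warp, leaving the total in every lane.

static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        // add the value held by the lane whose id is this lane's id XOR offset
        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
    }
    return x;
}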