ggml: Mamba CUDA kernel performance improvements #9186

Closed · wants to merge 20 commits

Changes from all commits (20 commits)
f809568
Add initial/naive CUDA kernels for the GGML_OP_SSM_CONV and GGML_OP_S…
jploski Jun 1, 2024
cc365b0
Add GGML_OP_SSM_CONF, GGML_OP_SSM_SCAN to supported ops for CUDA back…
jploski Jun 1, 2024
25f9e65
Update CUDA ops ssm_conv and ssm_scan to match CPU implementation fro…
jploski Jun 2, 2024
64fbd32
Add patch to test cases provided by @compilade; test for ssm_conv fails
jploski Jun 2, 2024
12c913c
Fix backend test for ssm_conv CUDA op not working
jploski Jun 2, 2024
061e520
Update CUDA ops and tests to match implementation from commit 8fb57ac…
jploski Jun 3, 2024
fae826f
Fix failed assertions while running Falcon Mamba
jploski Aug 25, 2024
20d390b
10x performance improve 4 cuda ssm conv & scan
piDack Aug 26, 2024
8dd323b
Merge branch 'master' of github.com:ggerganov/llama.cpp into mfalcon_…
piDack Aug 27, 2024
b423a6d
fix ssm_scan numerical error & others update
piDack Aug 27, 2024
40f4787
Merge branch 'master' of github.com:ggerganov/llama.cpp into mfalcon_…
piDack Aug 27, 2024
1928967
resolve test-backend-ops conflicts
piDack Aug 27, 2024
21c16fa
fix trailing whitespace
piDack Aug 27, 2024
e53b14f
del debug ingo
piDack Aug 27, 2024
eec0e8c
memory access pattern
piDack Aug 27, 2024
0e682ce
add restrict
piDack Aug 27, 2024
5999d6d
fix conflicts
piDack Aug 28, 2024
316a049
add restrict for dst
piDack Aug 29, 2024
99f2ac1
Merge branch 'master' of github.com:ggerganov/llama.cpp into mfalcon_…
piDack Aug 29, 2024
63b6e73
recommit for ci pass
piDack Aug 29, 2024
10 changes: 9 additions & 1 deletion ggml/src/ggml-cuda.cu
@@ -31,7 +31,8 @@
#include "ggml-cuda/tsembd.cuh"
#include "ggml-cuda/unary.cuh"
#include "ggml-cuda/upscale.cuh"

#include "ggml-cuda/ssm_conv.cuh"
#include "ggml-cuda/ssm_scan.cuh"
#include <algorithm>
#include <array>
#include <atomic>
@@ -2313,6 +2314,11 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_FLASH_ATTN_EXT:
ggml_cuda_flash_attn_ext(ctx, dst);
break;
case GGML_OP_SSM_CONV:
ggml_cuda_op_ssm_conv(ctx, dst);
break;
case GGML_OP_SSM_SCAN:
ggml_cuda_op_ssm_scan(ctx, dst);
break;
case GGML_OP_CROSS_ENTROPY_LOSS:
ggml_cuda_cross_entropy_loss(ctx, dst);
break;
@@ -2894,6 +2900,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_LEAKY_RELU:
case GGML_OP_SSM_CONV:
case GGML_OP_SSM_SCAN:
return true;
case GGML_OP_FLASH_ATTN_EXT:
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
4 changes: 2 additions & 2 deletions ggml/src/ggml-cuda/norm.cu
@@ -153,9 +153,9 @@ static void group_norm_f32_cuda(const float * x, float * dst, const int num_grou
}

static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
- GGML_ASSERT(ncols % WARP_SIZE == 0);
+ GGML_ASSERT(ncols % WARP_SIZE == 0 || ncols < WARP_SIZE);
if (ncols < 1024) {
- const dim3 block_dims(WARP_SIZE, 1, 1);
+ const dim3 block_dims(min(ncols, WARP_SIZE), 1, 1);
rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
} else {
const dim3 block_dims(1024, 1, 1);
100 changes: 100 additions & 0 deletions ggml/src/ggml-cuda/ssm_conv.cu
@@ -0,0 +1,100 @@
#include "ssm_conv.cuh"

template <int block_size>
static __global__ void ssm_conv_f32(
const float * __restrict__ src0, const float * __restrict__ src1,
const int src0_nb0, const int src0_nb1, const int src0_nb2,
const int src1_nb1,
float * __restrict__ dst,
const int dst_nb0, const int dst_nb1, const int dst_nb2,
const int nc, const int ncs, const int nr, const int n_t, const int n_s) {

const int tid = blockIdx.y;
const int i3 = blockIdx.x;
const int i2 = threadIdx.x;

const int ith = tid;
const int nth = WARP_SIZE;

// rows per thread
const int dr = (nr + nth - 1)/nth;

// row range for this thread
const int ir0 = dr*ith;
const int ir1 = min(ir0 + dr, nr);
const int ir = ir1 - ir0;

// {d_conv - 1 + n_t, d_inner, n_seqs}
// sliding window
const float * s = (const float *) ((const char *) src0 + ir0*src0_nb1 + i2*src0_nb0 + i3*src0_nb2); // {d_conv, d_inner, n_s}
const float * c = (const float *) ((const char *) src1 + ir0*src1_nb1); // {d_conv, d_inner}
float * x = (float *) ((char *) dst + ir0*dst_nb0 + i2*dst_nb1 + i3*dst_nb2); // {d_inner, n_t, n_s}

// TODO: transpose the output for smaller strides for big batches?
// d_inner
for (int i1 = 0; i1 < ir; ++i1) {
// rowwise dot product
// NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
float sumf = 0.0f;

// d_conv
#pragma unroll
for (int i0 = 0; i0 < nc; ++i0) {
sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
Collaborator

Unless I'm missing something, the memory access pattern here is bad, with each thread accessing completely different data. You will achieve orders of magnitude higher memory bandwidth by accessing the data in a coalesced manner.

@compilade (Collaborator) Aug 26, 2024

Ideally, this could use the fact that this is operating on a self-overlapping view, so advancing with i2 shifts the view by one column.

In practice, for Mamba (and Mamba-2) models, nc is always 4, which might help with unrolling.

To coalesce memory accesses (at least for large prompts), I guess each warp could operate on WARP_SIZE/nc steps at a time over i2, assuming WARP_SIZE is a multiple of 4 (is that always the case?), but this might need special handling of cases where i2 is not evenly divided by that.

I don't have much experience with CUDA (yet), so this might be misleading, but hopefully still helps.

@piDack (Contributor, Author) Aug 27, 2024

> Unless I'm missing something, the memory access pattern here is bad […]

Thanks. The current memory access pattern is more suitable for CPUs. I'm thinking about ways to address this issue.

@piDack (Contributor, Author)

> Ideally, this could use the fact that this is operating on a self-overlapping view […]

Good idea, I am currently testing according to your method.

@piDack (Contributor, Author) Aug 27, 2024

I've found a simple implementation for ssm_conv that coalesces memory accesses and gives about a 2x performance improvement, and I've already pushed it to this PR. For ssm_scan, I'm at a loss for optimization ideas.

Collaborator

> For ssm_scan, I'm at a loss for optimization ideas.

Use one warp per iteration of nc, with each thread calculating a partial sum, then combine the partial sums via warp_reduce_sum and have the first thread in the warp write back the result.

@piDack (Contributor, Author) Aug 29, 2024

Ha ha, I also thought about using the warp-level API to calculate the sum. However, after taking a closer look, I realized that these additions accumulate in a single thread's registers, not across threads in a block, so warp_reduce_sum might not be applicable here. Thank you for the review; if you have any other suggestions or better ideas, please feel free to share them. Your input is greatly appreciated.
}
x[i1] = sumf;
}
}

static void ssm_conv_f32_cuda(
const float * src0, const float * src1,
const int src0_nb0, const int src0_nb1, const int src0_nb2,
const int src1_nb1,
float * dst,
const int dst_nb0, const int dst_nb1, const int dst_nb2,
const int nc, const int ncs, const int nr, const int n_t, const int n_s,
cudaStream_t stream) {

const dim3 block_dims(n_t, 1, 1);
//const int nblocks = n_s; // TODO
const dim3 grid_dims(n_s, WARP_SIZE, 1);
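// launch mapping (see kernel): blockIdx.x -> sequence (i3), blockIdx.y -> row chunk (tid/ith), threadIdx.x -> token (i2)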

ssm_conv_f32<WARP_SIZE><<<grid_dims, block_dims, 0, stream>>>(
src0, src1,
src0_nb0, src0_nb1, src0_nb2,
src1_nb1,
dst,
dst_nb0, dst_nb1, dst_nb2,
nc, ncs, nr, n_t, n_s);
}

void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0]; // conv_x
const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight

const int nc = src1->ne[0]; // d_conv
const int ncs = src0->ne[0]; // d_conv - 1 + n_t
const int nr = src0->ne[1]; // d_inner
const int n_t = dst->ne[1]; // tokens per sequence
const int n_s = dst->ne[2]; // number of sequences in the batch

GGML_ASSERT( dst->ne[0] == nr);
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));

const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
ssm_conv_f32_cuda(src0_d, src1_d,
src0->nb[0], src0->nb[1], src0->nb[2],
src1->nb[1],
dst_d,
dst->nb[0], dst->nb[1], dst->nb[2],
nc, ncs, nr, n_t, n_s,
stream);
}
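For reference, a scalar sketch of what the kernel above computes per sequence: a sliding-window dot product of the conv weights over a self-overlapping view of conv_x. This is an illustrative re-statement assuming contiguous f32 tensors, not code from the PR:

```cuda
// Hypothetical reference loop (assumptions: contiguous tensors, shapes as extracted by the wrapper):
//   conv_x: {ncs, nr, n_s} with ncs = nc - 1 + n_t   (self-overlapping window)
//   w:      {nc, nr}                                  (conv1d.weight)
//   dst:    {nr, n_t, n_s}
static void ssm_conv_ref(const float * conv_x, const float * w, float * dst,
                         int nc, int ncs, int nr, int n_t, int n_s) {
    for (int i3 = 0; i3 < n_s; ++i3) {                 // sequences
        for (int i2 = 0; i2 < n_t; ++i2) {             // tokens: advancing i2 shifts the window by one column
            for (int i1 = 0; i1 < nr; ++i1) {          // d_inner rows
                float sumf = 0.0f;
                for (int i0 = 0; i0 < nc; ++i0) {      // d_conv taps
                    sumf += conv_x[(i3*nr + i1)*ncs + i2 + i0] * w[i1*nc + i0];
                }
                dst[(i3*n_t + i2)*nr + i1] = sumf;
            }
        }
    }
}
```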
3 changes: 3 additions & 0 deletions ggml/src/ggml-cuda/ssm_conv.cuh
@@ -0,0 +1,3 @@
#include "common.cuh"

void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
144 changes: 144 additions & 0 deletions ggml/src/ggml-cuda/ssm_scan.cu
@@ -0,0 +1,144 @@
#include "ssm_scan.cuh"

template <int block_size>
static __global__ void ssm_scan_f32(
const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2, const float * __restrict__ src3,
const float * __restrict__ src4, const float * __restrict__ src5,
const int src0_nb1, const int src0_nb2,
const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3,
const int src2_nb0, const int src2_nb1, const int src2_nb2,
const int src3_nb1,
const int src4_nb1, const int src4_nb2,
const int src5_nb1, const int src5_nb2,
float * __restrict__ dst,
const int nc, const int nr, const int n_t, const int n_s) {

// const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
const int i3 = threadIdx.y;

const int ith = tid;
const int nth = WARP_SIZE;

// rows per thread
const int dr = (nr + nth - 1)/nth;

// row range for this thread
const int ir0 = dr*ith;
const int ir1 = min(ir0 + dr, nr);
const int ir = ir1 - ir0;
for (int i2 = 0; i2 < n_t; ++i2) {
const float * s0 = (const float *) ((const char *) src0 + ir0*src0_nb1 + i3*src0_nb2); // {d_state, d_inner, n_s}
const float * x = (const float *) ((const char *) src1 + ir0*src1_nb0 + i2*src1_nb1 + i3*src1_nb2); // {d_inner, n_t, n_s}
const float * dt = (const float *) ((const char *) src2 + ir0*src2_nb0 + i2*src2_nb1 + i3*src2_nb2); // {d_inner, n_t, n_s}
const float * A = (const float *) ((const char *) src3 + ir0*src3_nb1); // {d_state, d_inner}
const float * B = (const float *) ((const char *) src4 + i2*src4_nb1 + i3*src4_nb2); // {d_state, n_t, n_s}
const float * C = (const float *) ((const char *) src5 + i2*src5_nb1 + i3*src5_nb2); // {d_state, n_t, n_s}
float * y = (float *) ((char *) dst + ir0*src1_nb0 + i2*src1_nb1 + i3*src1_nb2); // {d_inner, n_t, n_s}
float * s = (float *) ((char *) dst + ir0*src0_nb1 + i3*src0_nb2 + src1_nb3); // {d_state, d_inner, n_s}

// use the output as the source for the next token-wise iterations
if (i2 > 0) { s0 = s; }

// d_inner
for (int i1 = 0; i1 < ir; ++i1) {
// ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
float x_dt = x[i1] * dt_soft_plus;
float sumf = 0.0f;
// d_state
#pragma unroll
for (int i0 = 0; i0 < nc; ++i0) {
int i = i0 + i1*nc;
// state = prev_state * dA + dB * x
float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
// y = rowwise_dotprod(state, C)
sumf += state * C[i0];
s[i] = state;
}
y[i1] = sumf;
}
}
}

static void ssm_scan_f32_cuda(
const float * src0, const float * src1, const float * src2, const float * src3,
const float * src4, const float * src5,
const int src0_nb1, const int src0_nb2,
const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3,
const int src2_nb0, const int src2_nb1, const int src2_nb2,
const int src3_nb1,
const int src4_nb1, const int src4_nb2,
const int src5_nb1, const int src5_nb2,
float * dst,
const int nc, const int nr, const int n_t, const int n_s,
cudaStream_t stream) {

const dim3 block_dims(WARP_SIZE, n_s, 1);
const int nblocks = 1; // TODO
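// launch mapping (see kernel): a single block; threadIdx.x -> row chunk (tid/ith), threadIdx.y -> sequence (i3)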

ssm_scan_f32<WARP_SIZE><<<nblocks, block_dims, 0, stream>>>(
src0, src1, src2, src3,
src4, src5,
src0_nb1, src0_nb2,
src1_nb0, src1_nb1, src1_nb2, src1_nb3,
src2_nb0, src2_nb1, src2_nb2,
src3_nb1,
src4_nb1, src4_nb2,
src5_nb1, src5_nb2,
dst,
nc, nr, n_t, n_s);
}

void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0]; // s
const struct ggml_tensor * src1 = dst->src[1]; // x
const struct ggml_tensor * src2 = dst->src[2]; // dt
const struct ggml_tensor * src3 = dst->src[3]; // A
const struct ggml_tensor * src4 = dst->src[4]; // B
const struct ggml_tensor * src5 = dst->src[5]; // C

const int64_t nc = src0->ne[0]; // d_state
const int64_t nr = src0->ne[1]; // d_inner
const int64_t n_t = src1->ne[1]; // number of tokens per sequence
const int64_t n_s = src0->ne[2]; // number of sequences in the batch

GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src2->nb[0] == sizeof(float));
GGML_ASSERT(src3->nb[0] == sizeof(float));
GGML_ASSERT(src4->nb[0] == sizeof(float));
GGML_ASSERT(src5->nb[0] == sizeof(float));
// required for the dot product between s and C
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
// required for per-sequence offsets for states
GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
// required to get correct offset for state destination (i.e. src1->nb[3])
GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));

const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data;
const float * src2_d = (const float *)src2->data;
const float * src3_d = (const float *)src3->data;
const float * src4_d = (const float *)src4->data;
const float * src5_d = (const float *)src5->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);

ssm_scan_f32_cuda(
src0_d, src1_d, src2_d, src3_d,
src4_d, src5_d,
src0->nb[1], src0->nb[2],
src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3],
src2->nb[0], src2->nb[1], src2->nb[2],
src3->nb[1],
src4->nb[1], src4->nb[2],
src5->nb[1], src5->nb[2],
dst_d,
nc, nr, n_t, n_s,
stream);
}
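In summary, for each inner row the scan kernel above implements the selective-state-update recurrence (softplus is applied only for dt ≤ 20, otherwise dt is used directly to avoid overflow):

$$
\Delta = \operatorname{softplus}(dt), \qquad
s_{i_0} \leftarrow s_{i_0}\, e^{\Delta A_{i_0}} + B_{i_0}\,(x\,\Delta), \qquad
y = \sum_{i_0=0}^{n_c-1} s_{i_0}\, C_{i_0}
$$

with the updated state written back into dst so it seeds the next token-wise iteration.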
3 changes: 3 additions & 0 deletions ggml/src/ggml-cuda/ssm_scan.cuh
@@ -0,0 +1,3 @@
#include "common.cuh"

void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
6 changes: 3 additions & 3 deletions src/llama.cpp
@@ -9120,9 +9120,9 @@ static struct ggml_tensor * llm_build_mamba(

// Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers
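// note: dt, B and C are non-contiguous views at this point; ggml_cont below makes them contiguous, presumably so the CUDA RMS norm kernel, which expects contiguous rows, can handle them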
if (ssm_dt_b_c_rms) {
- dt = ggml_rms_norm(ctx, dt, norm_rms_eps);
- B = ggml_rms_norm(ctx, B, norm_rms_eps);
- C = ggml_rms_norm(ctx, C, norm_rms_eps);
+ dt = ggml_rms_norm(ctx, ggml_cont(ctx, dt), norm_rms_eps);
+ B = ggml_rms_norm(ctx, ggml_cont(ctx, B), norm_rms_eps);
+ C = ggml_rms_norm(ctx, ggml_cont(ctx, C), norm_rms_eps);
}

// {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs}