llama : initial Mamba-2 support #9126

Open · wants to merge 24 commits into master from compilade/mamba2

Changes from 1 commit (24 commits in total)
1f0fea7
llama : initial Mamba-2 support
compilade Aug 1, 2024
dceff23
ggml : SIMD ggml_ssm_scan for Mamba-2
compilade Aug 19, 2024
2bfe9de
llama : support running Mamba-Codestral-7B-v0.1
compilade Aug 19, 2024
aff9692
llama : fix Mamba-2 conv state saving
compilade Aug 21, 2024
e04910d
llama : remove unused variable
compilade Aug 22, 2024
fa358e7
llama : add missing break
compilade Aug 22, 2024
38913dc
convert_hf : prefer SentencePiece tokenizer for Mamba-2 when present
compilade Aug 22, 2024
0e601ca
Merge branch 'master' into compilade/mamba2
compilade Sep 18, 2024
273e7a4
llama : avoid redundant state copy for Mamba 1 and 2
compilade Sep 30, 2024
7d6cb36
Merge branch 'master' into compilade/mamba2
compilade Oct 1, 2024
2c77d79
metal : attempt to adapt SSM_SCAN for Mamba-2
compilade Oct 2, 2024
87b97d0
metal : fix SSM_SCAN pipeline scope
compilade Oct 2, 2024
03d0e6e
metal : use log and exp instead of log1pf and expf in SSM_SCAN
compilade Oct 2, 2024
7a351ab
metal : remove unused arguments for SSM_SCAN
compilade Oct 2, 2024
8b15bc6
metal : add back n_seqs to SSM_SCAN args
compilade Oct 2, 2024
5b8ec2b
metal : fix SSM_SCAN state head offset
compilade Oct 2, 2024
62b09b3
metal : fix wrong number of tokens per sequence in SSM_SCAN
compilade Oct 3, 2024
038d958
Merge branch 'master' into compilade/mamba2
compilade Oct 12, 2024
805512a
ggml : remove unused fast broadcast path in GGML_MUL
compilade Oct 12, 2024
7d16e1b
Merge branch 'master' into compilade/mamba2
compilade Nov 1, 2024
3bc7103
ggml : avoid multiply by D in GGML_OP_SSM_SCAN
compilade Nov 4, 2024
8d8f065
Merge branch 'master' into compilade/mamba2
compilade Nov 4, 2024
b4e9c59
convert : fix flake8 lint
compilade Nov 4, 2024
1ee6c48
Merge branch 'master' into compilade/mamba2
compilade Nov 25, 2024
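
The commits above build up the Mamba-2 state-space scan across the CPU, SIMD and Metal code paths. As orientation for the diffs that follow, the per-row, per-token update performed by the SSM scan can be restated as the CPU-style sketch below; this is a simplified restatement, not the actual ggml_compute_forward_ssm_scan_f32 code, and the function name, argument list and flat array layouts are illustrative assumptions.

#include <math.h>

// One (row, head, token) step of the selective state-space update:
//   h = exp(dt*A) * h + dt*B*x   (state update)
//   y = C . h + D*x              (output with skip connection)
// For Mamba-2, A is a single scalar per head, so exp(dt*A) can be hoisted
// out of the state loop, which is what the grouped Metal kernel below does.
static void ssm_scan_step_ref(
        float       * s,       // [d_state] running state for this row
        float       * y,       // scalar output for this row and token
        float         x,       // input activation
        float         dt_raw,  // time step before softplus
        const float * A,       // [d_state] for Mamba-1; scalar per head for Mamba-2
        const float * B,       // [d_state]
        const float * C,       // [d_state]
        float         D,       // skip-connection weight
        int           d_state) {
    // softplus with the same overflow guard used in the kernels
    const float dt   = dt_raw <= 20.0f ? log1pf(expf(dt_raw)) : dt_raw;
    const float x_dt = x * dt;
    float sumf = 0.0f;
    for (int i = 0; i < d_state; ++i) {
        const float state = s[i] * expf(dt * A[i]) + B[i] * x_dt;
        sumf += state * C[i];
        s[i] = state;
    }
    *y = sumf + D * x;
}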
107 changes: 76 additions & 31 deletions ggml/src/ggml-metal.m
@@ -95,6 +95,7 @@
GGML_METAL_KERNEL_TYPE_NORM,
GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP,
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
@@ -591,6 +592,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
@@ -1629,47 +1631,74 @@ static void ggml_metal_encode_node(
struct ggml_tensor * src3 = node->src[3];
struct ggml_tensor * src4 = node->src[4];
struct ggml_tensor * src5 = node->src[5];
struct ggml_tensor * src6 = node->src[6];
struct ggml_tensor * src7 = node->src[7];

GGML_ASSERT(src3);
GGML_ASSERT(src4);
GGML_ASSERT(src5);
GGML_ASSERT(src6);
GGML_ASSERT(src7);

size_t offs_src3 = 0;
size_t offs_src4 = 0;
size_t offs_src5 = 0;
size_t offs_src6 = 0;
size_t offs_src7 = 0;

id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil;
id<MTLBuffer> id_src6 = src6 ? ggml_metal_get_buffer(src6, &offs_src6) : nil;
id<MTLBuffer> id_src7 = src7 ? ggml_metal_get_buffer(src7, &offs_src7) : nil;

const int64_t ne30 = src3->ne[0]; GGML_UNUSED(ne30);
const int64_t ne30 = src3->ne[0];
const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31);

const uint64_t nb30 = src3->nb[0];
const uint64_t nb31 = src3->nb[1];

const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40);
const int64_t ne41 = src4->ne[1]; GGML_UNUSED(ne41);
const int64_t ne41 = src4->ne[1];
const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42);
const int64_t ne43 = src4->ne[3]; GGML_UNUSED(ne43);

const uint64_t nb40 = src4->nb[0];
const uint64_t nb41 = src4->nb[1];
const uint64_t nb42 = src4->nb[2];
const uint64_t nb43 = src4->nb[3];

const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50);
const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51);
const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52);
const int64_t ne53 = src5->ne[3]; GGML_UNUSED(ne53);

const uint64_t nb50 = src5->nb[0];
const uint64_t nb51 = src5->nb[1];
const uint64_t nb52 = src5->nb[2];
const uint64_t nb53 = src5->nb[3];

const int64_t ne60 = src6->ne[0]; GGML_UNUSED(ne60);

const uint64_t nb60 = src6->nb[0];

const int64_t ne70 = src7->ne[0]; GGML_UNUSED(ne70);

const uint64_t nb70 = src7->nb[0];

const int64_t d_state = ne00;
const int64_t d_inner = ne01;
const int64_t n_head = ne02;
const int64_t n_group = ne41;
const int64_t n_seq_tokens = ne11;
const int64_t n_seqs = ne02;
const int64_t n_seqs = ne13;

id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
if (ne30 == 1) {
// Mamba-2
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline;
} else {
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
}

[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1678,33 +1707,49 @@ static void ggml_metal_encode_node(
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];

[encoder setBytes:&d_state length:sizeof(d_state) atIndex:7];
[encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:8];
[encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9];
[encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:10];

[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
[encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18];
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19];
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20];
[encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21];
[encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22];
[encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23];
[encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24];
[encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25];
[encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26];
[encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27];
[encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28];

[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
[encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
[encoder setBuffer:id_src7 offset:offs_src7 atIndex:7];
[encoder setBuffer:id_dst offset:offs_dst atIndex:8];

[encoder setBytes:&d_state length:sizeof(d_state) atIndex:9];
[encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:10];
[encoder setBytes:&n_head length:sizeof(n_head) atIndex:11];
[encoder setBytes:&n_group length:sizeof(n_group) atIndex:12];
[encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:13];
[encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:14];

[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:15];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:16];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:17];
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:18];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:19];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:20];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:21];
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:22];
[encoder setBytes:&nb20 length:sizeof(nb20) atIndex:23];
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:24];
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:25];
[encoder setBytes:&nb23 length:sizeof(nb23) atIndex:26];
[encoder setBytes:&nb30 length:sizeof(nb30) atIndex:27];
[encoder setBytes:&nb31 length:sizeof(nb31) atIndex:28];
[encoder setBytes:&nb40 length:sizeof(nb40) atIndex:29];
[encoder setBytes:&nb41 length:sizeof(nb41) atIndex:30];
[encoder setBytes:&nb42 length:sizeof(nb42) atIndex:31];
[encoder setBytes:&nb43 length:sizeof(nb43) atIndex:32];
[encoder setBytes:&nb50 length:sizeof(nb50) atIndex:33];
[encoder setBytes:&nb51 length:sizeof(nb51) atIndex:34];
[encoder setBytes:&nb52 length:sizeof(nb52) atIndex:35];
[encoder setBytes:&nb53 length:sizeof(nb53) atIndex:36];
[encoder setBytes:&nb60 length:sizeof(nb60) atIndex:37];
[encoder setBytes:&nb70 length:sizeof(nb70) atIndex:38];

if (ne30 == 1) {
// Mamba-2
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} else {
GGML_ASSERT(d_inner == 1);
[encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
}
} break;
case GGML_OP_MUL_MAT:
{
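
The kernel choice above keys off ne30, the first dimension of A: ne30 == 1 means the per-head scalar A used by Mamba-2, so the grouped kernel is selected and dispatched over (d_inner, n_head, n_seqs); otherwise the Mamba-1 kernel is dispatched over (n_head, n_seqs) with GGML_ASSERT(d_inner == 1). As shown, the hunk appears to leave the outer pipeline declaration in place while each branch declares a new local pipeline that shadows it, which is presumably what the follow-up commit 87b97d0 ("metal : fix SSM_SCAN pipeline scope") addresses. A minimal sketch of the intended selection, assuming that fix simply hoists the declaration (the fixed code is not part of this diff):

id<MTLComputePipelineState> pipeline;

if (ne30 == 1) {
    // Mamba-2: A has a single column per head, use the grouped kernel
    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline;
} else {
    // Mamba-1: per-state-element A, use the original kernel
    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
}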
146 changes: 126 additions & 20 deletions ggml/src/ggml-metal.metal
@@ -795,7 +795,7 @@ kernel void kernel_ssm_conv_f32(
x[0] = sumf;
}

// ref: ggml.c:ggml_compute_forward_ssm_scan_f32
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part
// TODO: optimize
kernel void kernel_ssm_scan_f32(
device const void * src0,
@@ -804,14 +804,19 @@ kernel void kernel_ssm_scan_f32(
device const void * src3,
device const void * src4,
device const void * src5,
device const void * src6,
device const void * src7,
device float * dst,
constant int64_t & d_state,
constant int64_t & d_inner,
constant int64_t & n_head,
constant int64_t & n_group,
constant int64_t & n_seq_tokens,
constant int64_t & n_seqs,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
@@ -824,47 +829,148 @@ kernel void kernel_ssm_scan_f32(
constant uint64_t & nb40,
constant uint64_t & nb41,
constant uint64_t & nb42,
constant uint64_t & nb43,
constant uint64_t & nb50,
constant uint64_t & nb51,
constant uint64_t & nb52,
constant uint64_t & nb53,
constant uint64_t & nb60,
constant uint64_t & nb70,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t ir = tgpig.x;
const int64_t i3 = tgpig.y;
const int64_t i1 = 0;
const int64_t ir = tgpig.x; // current head
const int64_t i3 = tgpig.y; // current seq

const int64_t nc = d_state;
const int64_t nr = d_inner;
const int64_t nh = n_head;
const int64_t ng = n_group;
const int64_t n_t = n_seq_tokens;
const int64_t n_s = n_seqs;

const int64_t s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float);

device const int32_t * ids = (device const int32_t *) src7;

device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + ids[i3]*nb03);
device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb03 + s_off);

for (int64_t i2 = 0; i2 < n_t; ++i2) {
device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02);
device const float * x = (device const float *) ((device const char *) src1 + ir*nb10 + i2*nb11 + i3*nb12);
device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22);
device const float * A = (device const float *) ((device const char *) src3 + ir*nb31);
device const float * B = (device const float *) ((device const char *) src4 + i2*nb41 + i3*nb42);
device const float * C = (device const float *) ((device const char *) src5 + i2*nb51 + i3*nb52);
device float * y = (device float *) ((device char *) dst + ir*nb10 + i2*nb11 + i3*nb12); // TODO: do not use src1 strides
device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb02 + nb13);

if (i2 > 0) {
s0 = s;
device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*nb11 + i2*nb12 + i3*nb13); // {dim, nh, nt, ns}
device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22); // {nh, nt, ns}
device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); // {d_state, nh}
device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns}
device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns}
device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh}
device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns}

const float dt_soft_plus = dt[0] <= 20.0f ? log1pf(expf(dt[0])) : dt[0];
const float x_dt = x[0] * dt_soft_plus;
float sumf = 0.0f;

for (int64_t i0 = 0; i0 < nc; ++i0) {
const int64_t i = i0 + i1*nc;
const float state = (s0[i] * expf(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
sumf += state * C[i0];
s[i] = state;
}

// i1 == 0
float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
float x_dt = x[0] * dt_soft_plus;
y[0] = sumf + x[0] * D[0];

// recurse
s0 = s;
}
}

// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
// TODO: optimize (e.g. by parallelizing over d_state)
kernel void kernel_ssm_scan_f32_group(
device const void * src0,
device const void * src1,
device const void * src2,
device const void * src3,
device const void * src4,
device const void * src5,
device const void * src6,
device const void * src7,
device float * dst,
constant int64_t & d_state,
constant int64_t & d_inner,
constant int64_t & n_head,
constant int64_t & n_group,
constant int64_t & n_seq_tokens,
constant int64_t & n_seqs,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant uint64_t & nb13,
constant uint64_t & nb20,
constant uint64_t & nb21,
constant uint64_t & nb22,
constant uint64_t & nb30,
constant uint64_t & nb31,
constant uint64_t & nb40,
constant uint64_t & nb41,
constant uint64_t & nb42,
constant uint64_t & nb43,
constant uint64_t & nb50,
constant uint64_t & nb51,
constant uint64_t & nb52,
constant uint64_t & nb53,
constant uint64_t & nb60,
constant uint64_t & nb70,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i1 = tgpig.x;
const int64_t ir = tgpig.y; // current head
const int64_t i3 = tgpig.z; // current seq

const int64_t nc = d_state;
const int64_t nr = d_inner;
const int64_t nh = n_head;
const int64_t ng = n_group;
const int64_t n_t = n_seq_tokens;
const int64_t n_s = n_seqs;

const int64_t s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float);

device const int32_t * ids = (device const int32_t *) src7;

device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + ids[i3]*nb03);
device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb03 + s_off);

for (int64_t i2 = 0; i2 < n_t; ++i2) {
device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*nb11 + i2*nb12 + i3*nb13); // {dim, nh, nt, ns}
device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22); // {nh, nt, ns}
device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); // {1, nh}
device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns}
device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns}
device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh}
device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns}

const float dt_soft_plus = dt[0] <= 20.0f ? log1pf(expf(dt[0])) : dt[0];
const float x_dt = x[0] * dt_soft_plus;
const float dA = expf(dt_soft_plus * A[0]);
float sumf = 0.0f;

for (int64_t i0 = 0; i0 < nc; ++i0) {
int64_t i = i0;
float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt);
const int64_t i = i0 + i1*nc;
const float state = (s0[i] * dA) + (B[i0] * x_dt);
sumf += state * C[i0];
s[i] = state;
}

y[0] = sumf;
y[0] = sumf + x[0] * D[0];

// recurse
s0 = s;
}
}

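
A few details of the grouped kernel are worth noting: the initial state is read through the ids (src7) indirection so each sequence can select which saved state to start from, the updated states are written after the y outputs at byte offset s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float), and the group index (ir & (ng - 1)) presumably relies on n_group being a power of two. Condensed into plain C, one thread of kernel_ssm_scan_f32_group computes roughly the following; the helper name and the flat indexing are illustrative assumptions, not ggml code.

#include <math.h>

// What a single thread (one inner index i1, head ir, sequence i3) computes
// over its n_t tokens, with all strides flattened for clarity.
static void ssm_scan_group_one_thread(
        float       * s,        // [d_state] running state written to dst + s_off
        const float * s0,       // [d_state] initial state picked via the ids (src7) indirection
        const float * x,        // [n_t]     inputs for this (i1, head, seq)
        const float * dt,       // [n_t]     per-head time steps
        float         A,        // scalar A for this head (Mamba-2)
        const float * B,        // [n_t * d_state] for this head's group
        const float * C,        // [n_t * d_state] for this head's group
        float         D,        // skip weight for this head
        float       * y,        // [n_t]     outputs
        int           d_state,
        int           n_t) {
    const float * src = s0;                  // the first token reads the selected initial state
    for (int t = 0; t < n_t; ++t) {
        const float dt_sp = dt[t] <= 20.0f ? log1pf(expf(dt[t])) : dt[t]; // guarded softplus
        const float x_dt  = x[t] * dt_sp;
        const float dA    = expf(dt_sp * A); // hoisted: A is one scalar per head
        float sumf = 0.0f;
        for (int i = 0; i < d_state; ++i) {
            const float state = src[i] * dA + B[t*d_state + i] * x_dt;
            sumf += state * C[t*d_state + i];
            s[i] = state;
        }
        y[t] = sumf + D * x[t];
        src = s;                             // later tokens recurse on the updated state
    }
}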