Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

llama : initial Mamba-2 support #9126

Open
wants to merge 24 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
1f0fea7
llama : initial Mamba-2 support
compilade Aug 1, 2024
dceff23
ggml : SIMD ggml_ssm_scan for Mamba-2
compilade Aug 19, 2024
2bfe9de
llama : support running Mamba-Codestral-7B-v0.1
compilade Aug 19, 2024
aff9692
llama : fix Mamba-2 conv state saving
compilade Aug 21, 2024
e04910d
llama : remove unused variable
compilade Aug 22, 2024
fa358e7
llama : add missing break
compilade Aug 22, 2024
38913dc
convert_hf : prefer SentencePiece tokenizer for Mamba-2 when present
compilade Aug 22, 2024
0e601ca
Merge branch 'master' into compilade/mamba2
compilade Sep 18, 2024
273e7a4
llama : avoid redundant state copy for Mamba 1 and 2
compilade Sep 30, 2024
7d6cb36
Merge branch 'master' into compilade/mamba2
compilade Oct 1, 2024
2c77d79
metal : attempt to adapt SSM_SCAN for Mamba-2
compilade Oct 2, 2024
87b97d0
metal : fix SSM_SCAN pipeline scope
compilade Oct 2, 2024
03d0e6e
metal : use log and exp instead of log1pf and expf in SSM_SCAN
compilade Oct 2, 2024
7a351ab
metal : remove unused arguments for SSM_SCAN
compilade Oct 2, 2024
8b15bc6
metal : add back n_seqs to SSM_SCAN args
compilade Oct 2, 2024
5b8ec2b
metal : fix SSM_SCAN state head offset
compilade Oct 2, 2024
62b09b3
metal : fix wrong number of tokens per sequence in SSM_SCAN
compilade Oct 3, 2024
038d958
Merge branch 'master' into compilade/mamba2
compilade Oct 12, 2024
805512a
ggml : remove unused fast broadcast path in GGML_MUL
compilade Oct 12, 2024
7d16e1b
Merge branch 'master' into compilade/mamba2
compilade Nov 1, 2024
3bc7103
ggml : avoid multiply by D in GGML_OP_SSM_SCAN
compilade Nov 4, 2024
8d8f065
Merge branch 'master' into compilade/mamba2
compilade Nov 4, 2024
b4e9c59
convert : fix flake8 lint
compilade Nov 4, 2024
1ee6c48
Merge branch 'master' into compilade/mamba2
compilade Nov 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2788,6 +2788,77 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
return [(new_name, data_torch)]


@Model.register("Mamba2ForCausalLM")
class Mamba2Model(Model):
model_arch = gguf.MODEL_ARCH.MAMBA2

def set_vocab(self):
vocab_size = self.hparams["vocab_size"]
# Round vocab size to next multiple of 16
pad_vocab = self.hparams.get("pad_vocab_size_multiple", 16)
# pad using ceiling division
# ref: https://stackoverflow.com/a/17511341/22827863
vocab_size = -(vocab_size // -pad_vocab) * pad_vocab
self.hparams["vocab_size"] = vocab_size

if (self.dir_model / "tokenizer.model").is_file():
self._set_vocab_sentencepiece()
elif (self.dir_model / "tokenizer.model.v3").is_file():
# mamba-codestral
raise NotImplementedError(f"Please rename {self.dir_model / 'tokenizer.model.v3'} to {self.dir_model / 'tokenizer.model'}")
elif (self.dir_model / "tokenizer.json").is_file():
self._set_vocab_gpt2()
else:
# Use the GPT-NeoX tokenizer when no tokenizer files are present
self._set_vocab_builtin("gpt-neox", vocab_size)

def set_gguf_parameters(self):
d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128
head_dim = self.find_hparam(["head_dim"], optional=True) or 64
n_group = self.find_hparam(["n_groups"], optional=True) or 1

rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5

# Fail early for models which don't have a block expansion factor of 2
# TODO: does this really matter?
assert d_inner == 2 * d_model
assert d_inner % head_dim == 0

self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
self.gguf_writer.add_embedding_length(d_model)
self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_ssm_conv_kernel(d_conv)
self.gguf_writer.add_ssm_inner_size(d_inner)
self.gguf_writer.add_ssm_state_size(d_state)
self.gguf_writer.add_ssm_time_step_rank(d_inner // head_dim)
self.gguf_writer.add_ssm_group_count(n_group)
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
self.gguf_writer.add_file_type(self.ftype)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused

if name.startswith("model.backbone") or name.startswith("model.lm_head"):
# map Mamba-Codestral-7B-v0.1 tensor names to the names used by Mamba-2
name = name.removeprefix("model.")

if name.endswith(".dt_bias"):
name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"

new_name = self.map_tensor_name(name)

if name.endswith(".A_log"):
logger.debug("A_log --> A ==> " + new_name)
data_torch = -torch.exp(data_torch)

yield (new_name, data_torch)


@Model.register("CohereForCausalLM")
class CommandR2Model(Model):
model_arch = gguf.MODEL_ARCH.COMMAND_R
Expand Down
3 changes: 2 additions & 1 deletion ggml/include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -1787,7 +1787,8 @@ extern "C" {
struct ggml_tensor * dt,
struct ggml_tensor * A,
struct ggml_tensor * B,
struct ggml_tensor * C);
struct ggml_tensor * C,
struct ggml_tensor * D);

// partition into non-overlapping windows with padding if needed
// example:
Expand Down
256 changes: 189 additions & 67 deletions ggml/src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -7270,32 +7270,48 @@ struct ggml_tensor * ggml_ssm_scan(
struct ggml_tensor * dt,
struct ggml_tensor * A,
struct ggml_tensor * B,
struct ggml_tensor * C) {
struct ggml_tensor * C,
struct ggml_tensor * D) {
GGML_ASSERT(ggml_is_contiguous(s));
GGML_ASSERT(ggml_is_contiguous(x));
GGML_ASSERT(ggml_is_contiguous(dt));
GGML_ASSERT(ggml_is_contiguous(A));
GGML_ASSERT(ggml_is_matrix(A));
GGML_ASSERT(ggml_is_3d(B));
GGML_ASSERT(ggml_is_3d(s));
GGML_ASSERT(x->nb[0] == ggml_type_size(x->type));
GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
GGML_ASSERT(ggml_are_same_shape(x, dt));
GGML_ASSERT(x->nb[1] == x->ne[0]*x->nb[0]);
GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]);
GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]);
GGML_ASSERT(ggml_are_same_shape(B, C));

{
const int64_t d_state = s->ne[0];
const int64_t d_inner = s->ne[1];
const int64_t n_seq_tokens = x->ne[1];
const int64_t n_seqs = x->ne[2];

GGML_ASSERT(s->ne[2] == n_seqs);
GGML_ASSERT(x->ne[0] == d_inner);
GGML_ASSERT(A->ne[0] == d_state);
GGML_ASSERT(A->ne[1] == d_inner);
const int64_t head_dim = x->ne[0];
const int64_t n_head = x->ne[1];
const int64_t n_seq_tokens = x->ne[2];
const int64_t n_seqs = x->ne[3];

GGML_ASSERT(dt->ne[0] == n_head);
GGML_ASSERT(dt->ne[1] == n_seq_tokens);
GGML_ASSERT(dt->ne[2] == n_seqs);
GGML_ASSERT(ggml_is_3d(dt));
GGML_ASSERT(s->ne[1] == head_dim);
GGML_ASSERT(s->ne[2] == n_head);
GGML_ASSERT(s->ne[3] == n_seqs);
GGML_ASSERT(B->ne[0] == d_state);
GGML_ASSERT(B->ne[1] == n_seq_tokens);
GGML_ASSERT(B->ne[2] == n_seqs);
GGML_ASSERT(B->ne[2] == n_seq_tokens);
GGML_ASSERT(B->ne[3] == n_seqs);
GGML_ASSERT(D->ne[0] == n_head);
GGML_ASSERT(ggml_is_vector(D));

if (ggml_is_vector(A)) {
// Mamba-2
GGML_ASSERT(A->ne[0] == n_head);
} else {
// Mamba-1
GGML_ASSERT(A->ne[0] == d_state);
GGML_ASSERT(A->ne[1] == n_head);
GGML_ASSERT(ggml_is_matrix(A));
}
}

bool is_node = false;
Expand All @@ -7316,6 +7332,7 @@ struct ggml_tensor * ggml_ssm_scan(
result->src[3] = A;
result->src[4] = B;
result->src[5] = C;
result->src[6] = D;

return result;
}
Expand Down Expand Up @@ -10190,7 +10207,37 @@ static void ggml_compute_forward_mul_f32(
GGML_ASSERT( nb0 == sizeof(float));
GGML_ASSERT(nb00 == sizeof(float));

if (nb10 == sizeof(float)) {
if (ne00 > 1 && ne10 == 1) {
// fast broadcast path
for (int64_t ir = ith; ir < nr; ir += nth) {
// src0 and dst are same shape => same indices
const int64_t i03 = ir/(ne02*ne01);
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);

const int64_t i13 = i03 % ne13;
const int64_t i12 = i02 % ne12;
const int64_t i11 = i01 % ne11;

float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);

const float scale = src1_ptr[0];

if (scale == 0.0f) {
// NOTE: this also sets NANs to zero, which is not compliant with IEEE754,
// but it is useful when resetting the state of recurrent models.
memset((char *) dst->data + ir*nb1, 0, ne0 * sizeof(float));
} else {
if (dst->data != src0->data) {
// src0 is same shape as dst => same indices
memcpy((char *) dst->data + ir*nb1, (char *) src0->data + ir*nb01, ne0 * sizeof(float));
}
if (scale != 1.0f) {
ggml_vec_scale_f32(ne0, (float *) ((char *) dst->data + ir*nb1), scale);
}
}
}
} else if (nb10 == sizeof(float)) {
for (int64_t ir = ith; ir < nr; ir += nth) {
// src0 and dst are same shape => same indices
const int64_t i03 = ir/(ne02*ne01);
Expand Down Expand Up @@ -15840,20 +15887,25 @@ static void ggml_compute_forward_ssm_conv(
static void ggml_compute_forward_ssm_scan_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0]; // s
const struct ggml_tensor * src1 = dst->src[1]; // x
const struct ggml_tensor * src2 = dst->src[2]; // dt
const struct ggml_tensor * src3 = dst->src[3]; // A
const struct ggml_tensor * src4 = dst->src[4]; // B
const struct ggml_tensor * src5 = dst->src[5]; // C
const struct ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs}
const struct ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs}
const struct ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs}
const struct ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {n_head}
const struct ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs}
const struct ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs}
const struct ggml_tensor * src6 = dst->src[6]; // D {n_head}

const int ith = params->ith;
const int nth = params->nth;

const int64_t nc = src0->ne[0]; // d_state
const int64_t nr = src0->ne[1]; // d_inner
const int64_t n_t = src1->ne[1]; // number of tokens per sequence
const int64_t n_s = src0->ne[2]; // number of sequences in the batch
const int64_t nc = src0->ne[0]; // d_state
const int64_t nr = src0->ne[1]; // dim
const int64_t nh = src1->ne[1]; // n_head
const int64_t ng = src4->ne[1];
const int64_t nt = src1->ne[2]; // number of tokens per sequence
const int64_t ns = src0->ne[3]; // number of sequences in the batch

const int64_t s_off = ggml_element_size(src1) * ggml_nelements(src1);

GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
GGML_ASSERT(src0->nb[0] == sizeof(float));
Expand All @@ -15862,51 +15914,121 @@ static void ggml_compute_forward_ssm_scan_f32(
GGML_ASSERT(src3->nb[0] == sizeof(float));
GGML_ASSERT(src4->nb[0] == sizeof(float));
GGML_ASSERT(src5->nb[0] == sizeof(float));
// required for the dot product between s and C
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
// required for per-sequence offsets for states
GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
// required to get correct offset for state destination (i.e. src1->nb[3])
GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
GGML_ASSERT(src6->nb[0] == sizeof(float));
// allows optimizing the modulo since n_group should be a power of 2
GGML_ASSERT((ng & -ng) == ng);

// heads per thread
const int dh = (nh + nth - 1)/nth;

// head range for this thread
const int ih0 = dh*ith;
const int ih1 = MIN(ih0 + dh, nh);

for (int i3 = 0; i3 < ns; ++i3) {
for (int i2 = 0; i2 < nt; ++i2) {
const float * s0 = (const float *) ((const char *) src0->data + i3*(src0->nb[3])); // {d_state, dim, nh, ns}
const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {nh}
const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
const float * D = (const float *) ((const char *) src6->data); // {nh}
float * y = (float *) ((char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
float * s = (float *) ((char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}

// use the output as the source when it's not the first token-wise iteration
if (i2 > 0) { s0 = s; }

// rows per thread
const int dr = (nr + nth - 1)/nth;
if (ggml_is_vector(src3)) {
// Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop

// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
const int ir = ir1 - ir0;
// n_head
for (int h = ih0; h < ih1; ++h) {
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
const float dA = expf(dt_soft_plus * A[h]);

for (int i3 = 0; i3 < n_s; ++i3) {
for (int i2 = 0; i2 < n_t; ++i2) {
const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}

// use the output as the source for the next token-wise iterations
if (i2 > 0) { s0 = s; }
// dim
for (int i1 = 0; i1 < nr; ++i1) {
const int ii = i1 + h*nr;
const float x_dt = x[ii] * dt_soft_plus;
float sumf = 0.0f;
#if defined(GGML_SIMD)
const int np = (nc & ~(GGML_F32_STEP - 1));

// d_inner
for (int i1 = 0; i1 < ir; ++i1) {
// ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
float x_dt = x[i1] * dt_soft_plus;
float sumf = 0.0f;
// d_state
for (int i0 = 0; i0 < nc; ++i0) {
int i = i0 + i1*nc;
// state = prev_state * dA + dB * x
float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
// y = rowwise_dotprod(state, C)
sumf += state * C[i0];
s[i] = state;
GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };

GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);

GGML_F32_VEC ax[GGML_F32_ARR];
GGML_F32_VEC ay[GGML_F32_ARR];
GGML_F32_VEC az[GGML_F32_ARR];

for (int i = 0; i < np; i += GGML_F32_STEP) {
for (int j = 0; j < GGML_F32_ARR; j++) {
ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);

ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);

ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]);

sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);

GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]);
}
}

// reduce sum0..sum3 to sum0
GGML_F32_VEC_REDUCE(sumf, sum);
#else
const int np = 0;
#endif
// d_state
for (int i0 = np; i0 < nc; ++i0) {
const int i = i0 + ii*nc;
const int ig = i0 + (h & (ng - 1))*nc;
// state = prev_state * dA + dB * x
const float state = (s0[i] * dA) + (B[ig] * x_dt);
// y = rowwise_dotprod(state, C)
sumf += state * C[ig];
s[i] = state;
}
y[ii] = sumf + x[ii] * D[h];
}
}
} else {
// Mamba-1 has an element-wise decay factor for the states

// n_head
for (int h = ih0; h < ih1; ++h) {
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];

// dim
for (int i1 = 0; i1 < nr; ++i1) {
const int ii = i1 + h*nr;
const float x_dt = x[ii] * dt_soft_plus;
float sumf = 0.0f;
// NOTE: can't really use GGML_SIMD here because d_state is usually 16
// and also because expf is used within the loop.
// d_state
for (int i0 = 0; i0 < nc; ++i0) {
const int i = i0 + ii*nc;
const int ig = i0 + (h & (ng - 1))*nc;
// state = prev_state * dA + dB * x
const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
// y = rowwise_dotprod(state, C)
sumf += state * C[ig];
s[i] = state;
}
y[ii] = sumf + x[ii] * D[h];
}
}
y[i1] = sumf;
}
}
}
Expand Down
Loading
Loading