llama : support Jamba hybrid Transformer-Mamba models #7531

Draft

compilade wants to merge 41 commits into master from compilade/refactor-kv-cache

Changes from all commits (41 commits)
271104c  wip: llama : separate recurrent states from the KV cache (compilade, Apr 3, 2024)
8db1e4d  llama : use std::find for seq_nodes in llama_rs_cache (compilade, Apr 4, 2024)
0028010  llama : state checkpoints for recurrent models (compilade, Apr 8, 2024)
0c8b3b2  llama : correctly handle more edge cases for the rs cache (compilade, Apr 9, 2024)
d66849f  Merge branch 'master' into compilade/refactor-kv-cache (compilade, Apr 10, 2024)
a09db95  llama : rename many llama_kv_cache_* functions (compilade, Apr 29, 2024)
c460ff1  Merge branch 'master' into compilade/refactor-kv-cache (compilade, Apr 29, 2024)
b6fafd1  llama : remove useless return value for some llama_cache_* functions (compilade, Apr 29, 2024)
b7ec12e  Merge branch 'master' into compilade/refactor-kv-cache (compilade, May 12, 2024)
3b57b55  Merge branch 'master' into compilade/refactor-kv-cache (compilade, May 22, 2024)
7e13f19  llama : rethink recurrent state cell counts (compilade, May 24, 2024)
cbc743e  llama : support Jamba (compilade, May 24, 2024)
0fd13e9  Merge branch 'master' into compilade/refactor-kv-cache (compilade, May 24, 2024)
61a88a1  llama : fix BERT inference without KV cache (compilade, May 25, 2024)
ea2e63e  convert-hf : check for unprocessed Jamba experts (compilade, May 25, 2024)
fc59407  convert-hf : support Mini-Jamba conversion (compilade, May 25, 2024)
181dadf  llama : fix Jamba quantization sanity checks (compilade, May 28, 2024)
3a414b0  llama : sequence-length-aware batch splitting (compilade, May 28, 2024)
4e4c41e  Merge branch 'master' into compilade/refactor-kv-cache (compilade, May 28, 2024)
3587a94  llama : use equal-sequence-length sub-batches for recurrent models (compilade, Jun 1, 2024)
5d3c7b9  Merge branch 'master' into compilade/refactor-kv-cache (compilade, Jun 1, 2024)
72eea49  llama : fix batch split output count for embeddings (compilade, Jun 1, 2024)
18d1c14  llama : minimize swaps when reordering logits (compilade, Jun 1, 2024)
61200ef  llama : fix edge case finding batch seq_id of split recurrent cell (compilade, Jun 1, 2024)
eb589d5  llama : avoid copies for simple batch splits (compilade, Jun 2, 2024)
8fb57ac  llama : use im2col and mul_mat to perform convolution for Mamba (compilade, Jun 3, 2024)
17f6c1e  llama : fix .base() compilation error on Windows (compilade, Jun 3, 2024)
fee3c1d  llama : allow doing the equivalent of SSM_CONV with SUM_ROWS and MUL (compilade, Jun 3, 2024)
6840ac0  Merge branch 'master' into compilade/refactor-kv-cache (compilade, Jun 8, 2024)
372482d  llama : rename llama_cache to llama_past (compilade, Jun 8, 2024)
43d8d4b  examples : replace llama_kv_cache_seq_* with llama_past_seq_* (compilade, Jun 10, 2024)
ff794f5  Merge branch 'master' into compilade/refactor-kv-cache (compilade, Jun 12, 2024)
33425a7  mamba : fix non-contiguous usage of ggml_silu (compilade, Jun 12, 2024)
10c3c41  Merge branch 'master' into compilade/refactor-kv-cache (compilade, Jun 30, 2024)
9b38f8b  Merge branch 'master' into compilade/refactor-kv-cache (compilade, Jul 4, 2024)
bc320ef  Merge branch 'master' into compilade/refactor-kv-cache (compilade, Sep 1, 2024)
fcb889c  llama : session saving and reloading for hybrid models (compilade, Sep 2, 2024)
a03e32a  Merge branch 'master' into compilade/refactor-kv-cache (compilade, Sep 2, 2024)
9d3f44d  convert_hf : fix Jamba conversion (compilade, Sep 2, 2024)
5f62db7  llama : fix mixed signedness comparison (compilade, Sep 2, 2024)
375de5b  llama : use unused n_embd_k_gqa in k_shift (compilade, Sep 2, 2024)
common/common.cpp (2 changes: 1 addition & 1 deletion)
@@ -2541,7 +2541,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
if (llama_model_has_decoder(model)) {
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
}
llama_kv_cache_clear(lctx);
llama_past_clear(lctx);
llama_synchronize(lctx);
llama_reset_timings(lctx);
}
convert_hf_to_gguf.py (114 changes: 114 additions & 0 deletions)
@@ -2874,6 +2874,120 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
return [(new_name, data_torch)]


@Model.register("JambaForCausalLM")
class JambaModel(Model):
model_arch = gguf.MODEL_ARCH.JAMBA

def get_vocab_base_pre(self, tokenizer) -> str:
del tokenizer # unused

return "gpt-2"

def set_vocab(self):
if (self.dir_model / "tokenizer.model").is_file():
# Using Jamba's tokenizer.json causes errors on model load
# (something about "byte not found in vocab"),
# but there's a working tokenizer.model
self._set_vocab_sentencepiece()
else:
# Some Jamba models only have a tokenizer.json, which works.
self._set_vocab_gpt2()

def set_gguf_parameters(self):
d_model = self.find_hparam(["hidden_size", "mamba_d_model"])
d_conv = self.find_hparam(["mamba_d_conv"], optional=True) or 4
d_inner = self.hparams["mamba_expand"] * d_model
d_state = self.find_hparam(["mamba_d_state"], optional=True) or 16
# ceiling division
# ref: https://stackoverflow.com/a/17511341/22827863
# ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
dt_rank = self.find_hparam(["mamba_dt_rank"], optional=True) or -(d_model // -16)
rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-6
n_kv_head = self.hparams["num_key_value_heads"]
attn_offset = self.hparams["attn_layer_offset"]
attn_period = self.hparams["attn_layer_period"]
n_kv_vec = [0 for _ in range(attn_offset)] + [
n_kv_head if (i - attn_offset) % attn_period == 0 else 0 for i in range(attn_offset, self.block_count)
]

self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_context_length(self.find_hparam(["max_position_embeddings", "n_ctx"]))
self.gguf_writer.add_embedding_length(d_model)
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
self.gguf_writer.add_head_count_kv(n_kv_vec)
self.gguf_writer.add_ssm_conv_kernel(d_conv)
self.gguf_writer.add_ssm_inner_size(d_inner)
self.gguf_writer.add_ssm_state_size(d_state)
self.gguf_writer.add_ssm_time_step_rank(dt_rank)
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
self.gguf_writer.add_expert_count(self.hparams["num_experts"])
self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"])
self.gguf_writer.add_file_type(self.ftype)

_experts: list[dict[str, Tensor]] | None = None

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:

# Mini-Jamba
name = name.replace(".moe.", ".feed_forward.")
if bid is not None:
moe_offset = self.hparams["expert_layer_offset"]
moe_period = self.hparams["expert_layer_period"]

if not (bid >= moe_offset and (bid - moe_offset) % moe_period == 0):
name = name.replace(".experts.0.", ".")

# process the experts separately
if ".feed_forward.experts." in name:
n_experts = self.hparams["num_experts"]

assert bid is not None

if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]

self._experts[bid][name] = data_torch

if len(self._experts[bid]) >= n_experts * 3:

# merge the experts into a single 3d tensor
for wid in ["down_proj", "gate_proj", "up_proj"]:
datas: list[Tensor] = []

for xid in range(n_experts):
ename = f"model.layers.{bid}.feed_forward.experts.{xid}.{wid}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]

data_torch = torch.stack(datas, dim=0)

# using the same merged name as qwen2moe
merged_name = f"model.layers.{bid}.mlp.experts.{wid}.weight"

new_name = self.map_tensor_name(merged_name)

yield new_name, data_torch
return

new_name = self.map_tensor_name(name)

if name.endswith(".A_log"):
logger.debug("A_log --> A ==> " + new_name)
data_torch = -torch.exp(data_torch)

yield new_name, data_torch

def prepare_tensors(self):
super().prepare_tensors()

if self._experts is not None:
# flatten `list[dict[str, Tensor]]` into `list[str]`
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")


@Model.register("CohereForCausalLM")
class CommandR2Model(Model):
model_arch = gguf.MODEL_ARCH.COMMAND_R
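Note on the converter logic above: the per-layer KV head vector and the dt_rank default are what mark which blocks of the hybrid stack are attention layers (the only ones that get KV cache entries; the rest hold recurrent Mamba state). Below is a minimal standalone sketch of that logic, not part of the diff; the hyperparameter values used (32 blocks, 8 KV heads, attn_layer_offset 4, attn_layer_period 8, hidden size 4096) are illustrative assumptions, not values read from a specific Jamba config.

# Standalone sketch with assumed example values; not part of convert_hf_to_gguf.py.
# In a hybrid Transformer-Mamba model only the attention blocks carry KV heads,
# so head_count_kv is written as a per-layer vector, with 0 for the Mamba blocks.

def kv_heads_per_layer(n_blocks: int, n_kv_head: int, attn_offset: int, attn_period: int) -> list[int]:
    """Return n_kv_head for attention blocks and 0 for all other (Mamba) blocks."""
    return [
        n_kv_head if i >= attn_offset and (i - attn_offset) % attn_period == 0 else 0
        for i in range(n_blocks)
    ]

def default_dt_rank(d_model: int) -> int:
    """Ceiling division of d_model by 16, the same -(d_model // -16) trick as above."""
    return -(d_model // -16)

if __name__ == "__main__":
    # assumed values, for illustration only
    print(kv_heads_per_layer(n_blocks=32, n_kv_head=8, attn_offset=4, attn_period=8))
    # [0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 8, 0, ...]  (attention at blocks 4, 12, 20, 28)
    print(default_dt_rank(4096))  # 256

With those assumed values, every eighth block starting at index 4 is an attention block, which is the shape of the list passed to add_head_count_kv in the diff.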
examples/batched-bench/batched-bench.cpp (4 changes: 2 additions & 2 deletions)
@@ -153,7 +153,7 @@ int main(int argc, char ** argv) {

const auto t_pp_start = ggml_time_us();

llama_kv_cache_clear(ctx);
llama_past_clear(ctx);

if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
LOG_TEE("%s: llama_decode() failed\n", __func__);
@@ -162,7 +162,7 @@

if (is_pp_shared) {
for (int32_t i = 1; i < pl; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
llama_past_seq_cp(ctx, 0, i, -1, -1);
}
}

examples/batched.swift/Sources/main.swift (2 changes: 1 addition & 1 deletion)
@@ -98,7 +98,7 @@ if llama_decode(context, batch) != 0 {
}

for i in 1 ..< n_parallel {
llama_kv_cache_seq_cp(context, 0, Int32(i), 0, batch.n_tokens)
llama_past_seq_cp(context, 0, Int32(i), -1, -1)
}

if n_parallel > 1 {
examples/batched/batched.cpp (2 changes: 1 addition & 1 deletion)
@@ -132,7 +132,7 @@ int main(int argc, char ** argv) {
//// assign the system KV cache to all parallel sequences
//// this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
//for (int32_t i = 1; i < n_parallel; ++i) {
// llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
// llama_past_seq_cp(ctx, 0, i, -1, -1);
//}

if (n_parallel > 1) {
examples/cvector-generator/cvector-generator.cpp (2 changes: 1 addition & 1 deletion)
@@ -338,7 +338,7 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {
}

static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) {
llama_kv_cache_clear(ctx);
llama_past_clear(ctx);
if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), 0, 0))) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
examples/embedding/embedding.cpp (2 changes: 1 addition & 1 deletion)
@@ -35,7 +35,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
const struct llama_model * model = llama_get_model(ctx);

// clear previous kv_cache values (irrelevant for embeddings)
llama_kv_cache_clear(ctx);
llama_past_clear(ctx);

// run model
fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
examples/gritlm/gritlm.cpp (4 changes: 2 additions & 2 deletions)
@@ -43,7 +43,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
}

// clear previous kv_cache values (irrelevant for embeddings)
llama_kv_cache_clear(ctx);
llama_past_clear(ctx);
llama_set_embeddings(ctx, true);
llama_set_causal_attn(ctx, false);

@@ -98,7 +98,7 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
const llama_model * mdl = llama_get_model(ctx);
llama_token eos_token = llama_token_eos(mdl);

llama_kv_cache_clear(ctx);
llama_past_clear(ctx);
llama_set_embeddings(ctx, false);
llama_set_causal_attn(ctx, true);

examples/imatrix/imatrix.cpp (2 changes: 1 addition & 1 deletion)
@@ -499,7 +499,7 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
const auto t_start = std::chrono::high_resolution_clock::now();

// clear the KV cache
llama_kv_cache_clear(ctx);
llama_past_clear(ctx);

for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
examples/infill/infill.cpp (4 changes: 2 additions & 2 deletions)
@@ -385,8 +385,8 @@ int main(int argc, char ** argv) {
LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
n_past, n_left, n_ctx, params.n_keep, n_discard);

llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
llama_past_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
llama_past_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);

n_past -= n_discard;

examples/llama-bench/llama-bench.cpp (4 changes: 2 additions & 2 deletions)
@@ -1515,7 +1515,7 @@ int main(int argc, char ** argv) {

test t(inst, lmodel, ctx);

llama_kv_cache_clear(ctx);
llama_past_clear(ctx);

// cool off before the test
if (params.delay) {
@@ -1549,7 +1549,7 @@ }
}

for (int i = 0; i < params.reps; i++) {
llama_kv_cache_clear(ctx);
llama_past_clear(ctx);

uint64_t t_start = get_time_ns();

examples/llama.android/llama/src/main/cpp/llama-android.cpp (8 changes: 4 additions & 4 deletions)
@@ -194,7 +194,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
}

batch->logits[batch->n_tokens - 1] = true;
llama_kv_cache_clear(context);
llama_past_clear(context);

const auto t_pp_start = ggml_time_us();
if (llama_decode(context, *batch) != 0) {
@@ -206,7 +206,7 @@

LOGi("Benchmark text generation (tg)");

llama_kv_cache_clear(context);
llama_past_clear(context);
const auto t_tg_start = ggml_time_us();
for (i = 0; i < tg; i++) {

@@ -223,7 +223,7 @@

const auto t_tg_end = ggml_time_us();

llama_kv_cache_clear(context);
llama_past_clear(context);

const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
@@ -439,5 +439,5 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) {
llama_kv_cache_clear(reinterpret_cast<llama_context *>(context));
llama_past_clear(reinterpret_cast<llama_context *>(context));
}
examples/llama.swiftui/llama.cpp.swift/LibLlama.swift (8 changes: 4 additions & 4 deletions)
@@ -216,7 +216,7 @@ actor LlamaContext {
}
batch.logits[Int(batch.n_tokens) - 1] = 1 // true

llama_kv_cache_clear(context)
llama_past_clear(context)

let t_pp_start = ggml_time_us()

@@ -229,7 +229,7 @@

// bench text generation

llama_kv_cache_clear(context)
llama_past_clear(context)

let t_tg_start = ggml_time_us()

@@ -248,7 +248,7 @@

let t_tg_end = ggml_time_us()

llama_kv_cache_clear(context)
llama_past_clear(context)

let t_pp = Double(t_pp_end - t_pp_start) / 1000000.0
let t_tg = Double(t_tg_end - t_tg_start) / 1000000.0
@@ -298,7 +298,7 @@
func clear() {
tokens_list.removeAll()
temporary_invalid_cchars.removeAll()
llama_kv_cache_clear(context)
llama_past_clear(context)
}

private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
examples/lookahead/lookahead.cpp (13 changes: 7 additions & 6 deletions)
@@ -96,7 +96,7 @@ int main(int argc, char ** argv) {
llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0));

for (int s = 1; s < W + G + 1; ++s) {
llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
llama_past_seq_cp(ctx, 0, s, -1, -1);
}

const auto t_enc_end = ggml_time_us();
@@ -438,17 +438,18 @@

// KV cache management
// if no verification token matched, we simply remove all cells from this batch -> no fragmentation
llama_kv_cache_seq_rm(ctx, -1, n_past, -1);
// FIXME: recurrent and hybrid models
llama_past_seq_rm(ctx, -1, n_past, -1);

if (seq_id_best != 0) {
// if a verification token matched, we keep the best sequence and remove the rest
// this leads to some KV cache fragmentation
llama_kv_cache_seq_keep(ctx, seq_id_best);
llama_kv_cache_seq_cp (ctx, seq_id_best, 0, -1, -1);
llama_kv_cache_seq_rm (ctx, seq_id_best, -1, -1);
llama_past_seq_keep(ctx, seq_id_best);
llama_past_seq_cp (ctx, seq_id_best, 0, -1, -1);
llama_past_seq_rm (ctx, seq_id_best, -1, -1);

for (int s = 1; s < W + G + 1; ++s) {
llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
llama_past_seq_cp(ctx, 0, s, -1, -1);
}
}
}
examples/lookup/lookup.cpp (3 changes: 2 additions & 1 deletion)
@@ -194,7 +194,8 @@ int main(int argc, char ** argv){

// KV cache management
// clean the cache of draft tokens that weren't accepted
llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
// FIXME: recurrent and hybrid models
llama_past_seq_rm(ctx, 0, n_past, -1);

llama_batch_clear(batch_tgt);
llama_batch_add(batch_tgt, draft[0], n_past, { 0 }, true);