llama : add llama_model_is_recurrent to simplify figuring that out
This will make it easier to more cleanly support RWKV-v6 and Mamba-2.
compilade committed Aug 21, 2024
1 parent b264edd commit 1be5ea7
Showing 2 changed files with 12 additions and 3 deletions.
3 changes: 3 additions & 0 deletions include/llama.h
@@ -508,6 +508,9 @@ extern "C" {
     // to the decoder to start generating output sequence. For other models, it returns -1.
     LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model);
 
+    // Returns true if the model is recurrent (like Mamba, RWKV, etc.)
+    LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model);
+
     // Returns 0 on success
     LLAMA_API uint32_t llama_model_quantize(
             const char * fname_inp,
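For context, a minimal caller-side sketch (not part of this commit) of how the new function might be used. llama_backend_init, llama_model_default_params, llama_load_model_from_file, and llama_free_model are existing public API calls at this point in the history; the model path is a placeholder.

#include "llama.h"
#include <stdio.h>

int main(void) {
    llama_backend_init();

    struct llama_model_params mparams = llama_model_default_params();
    struct llama_model * model = llama_load_model_from_file("model.gguf", mparams);
    if (model == NULL) {
        return 1;
    }

    // recurrent models (e.g. Mamba) keep a fixed-size state per sequence
    // instead of a KV cache that grows with the context length
    if (llama_model_is_recurrent(model)) {
        printf("model is recurrent\n");
    }

    llama_free_model(model);
    llama_backend_free();
    return 0;
}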
12 changes: 9 additions & 3 deletions src/llama.cpp
@@ -3292,8 +3292,7 @@ static bool llama_kv_cache_init(
 
     cache.has_shift = false;
 
-    // TODO: find a nicer way to add other recurrent model architectures
-    cache.recurrent = model.arch == LLM_ARCH_MAMBA;
+    cache.recurrent = llama_model_is_recurrent(&model);
     cache.v_trans   = !cache.recurrent && !cparams.flash_attn;
 
     cache.head = 0;
@@ -17235,7 +17234,7 @@ struct llama_context * llama_new_context_with_model(
     ggml_type type_v = params.type_v;
 
     // Mamba only needs a constant number of KV cache cells per sequence
-    if (model->arch == LLM_ARCH_MAMBA) {
+    if (llama_model_is_recurrent(model)) {
         // Mamba needs at least as many KV cells as there are sequences kept at any time
         kv_size = std::max((uint32_t) 1, params.n_seq_max);
         // it's probably best to keep as much precision as possible for the states
@@ -17709,6 +17708,13 @@ llama_token llama_model_decoder_start_token(const struct llama_model * model) {
     return model->hparams.dec_start_token_id;
 }
 
+bool llama_model_is_recurrent(const struct llama_model * model) {
+    switch (model->arch) {
+        case LLM_ARCH_MAMBA: return true;
+        default:             return false;
+    }
+}
+
 uint32_t llama_model_quantize(
         const char * fname_inp,
         const char * fname_out,
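Per the commit message, supporting RWKV-v6 and Mamba-2 later reduces to extending this one switch. A hypothetical sketch of that extension (LLM_ARCH_RWKV6 and LLM_ARCH_MAMBA2 are assumed names that do not exist in the tree at this commit):

bool llama_model_is_recurrent(const struct llama_model * model) {
    switch (model->arch) {
        case LLM_ARCH_MAMBA:  return true;
        case LLM_ARCH_RWKV6:  return true; // hypothetical future architecture
        case LLM_ARCH_MAMBA2: return true; // hypothetical future architecture
        default:              return false;
    }
}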
