From 271104c65c9b99d5b5aca4489d7bec103cd60db9 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 3 Apr 2024 11:07:16 -0400 Subject: [PATCH 01/28] wip: llama : separate recurrent states from the KV cache This will be necessary to support Jamba (and other recurrent models mixed with Attention). Doesn't compile yet, and finding a slot isn't yet done correctly for recurrent states. --- llama.cpp | 1386 +++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 960 insertions(+), 426 deletions(-) diff --git a/llama.cpp b/llama.cpp index 267ac4cc022a1..9ca8ca0f41320 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1793,14 +1793,14 @@ struct llama_hparams { return n_embd_head_v * n_head_kv; } - uint32_t n_embd_k_s() const { // dimension of the rolling state embeddings + uint32_t n_embd_r() const { // dimension of the rolling state embeddings // corresponds to Mamba's conv_states size // TODO: maybe support other convolution strides than 1 // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner; } - uint32_t n_embd_v_s() const { // dimension of the recurrent state embeddings + uint32_t n_embd_s() const { // dimension of the recurrent state embeddings // corresponds to Mamba's ssm_states size return ssm_d_state * ssm_d_inner; } @@ -1904,7 +1904,6 @@ struct llama_layer { struct llama_kv_cell { llama_pos pos = -1; llama_pos delta = 0; - int32_t src = 0; // used by recurrent state models to copy states std::set seq_id; @@ -1925,9 +1924,6 @@ struct llama_kv_cell { struct llama_kv_cache { bool has_shift = false; bool do_defrag = false; - bool do_copy = false; - // with recurrent state models, a cell can hold the state for more than one past token - bool recurrent = false; // Note: The value of head isn't only used to optimize searching // for a free KV slot. llama_decode_internal also uses it, so it @@ -1947,9 +1943,365 @@ struct llama_kv_cache { std::vector k_l; // per layer std::vector v_l; + size_t total_size() const { + size_t size = 0; + for (struct ggml_tensor * k : k_l) { + size += ggml_nrows(k) * ggml_row_size(k->type, k->ne[0]); + } + for (struct ggml_tensor * v : v_l) { + size += ggml_nrows(v) * ggml_row_size(v->type, v->ne[0]); + } + return size; + } +}; + +// for recurrent models, use a tree of sequences to simplify +// quickly finding the tail cell of each sequence +// TODO: drop the _rs_ infix +struct llama_rs_seq_node { + llama_seq_id seq_id = -1; + int32_t next_cell = -1; + + // needed for automatic typecasting with .find() + llama_rs_seq_node(const llama_seq_id s = -1, int32_t i = -1) : seq_id(s), next_cell(i) {} + + bool operator<(const llama_rs_seq_node & other) const { + return seq_id < other.seq_id; + } + + bool is_tail() const { + return next_cell < 0; + } +}; + +struct llama_rs_cell { + llama_pos pos = -1; + int32_t src = -1; // copy source id (cleared next when -1) + + // Link to previous cell in this sequence. + // Sequences can only diverge, never converge, + // so this works when there are multiple seq_ids per cell too. 
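    // For illustration only (a hypothetical layout, not taken from this patch):
    // two sequences that fork from a shared prefix could be linked as
    //
    //   cell 0: seq_nodes = { {seq 0, next_cell 1}, {seq 1, next_cell 2} }, prev = -1, tail_rc = 0
    //   cell 1: seq_nodes = { {seq 0, next_cell -1} },                      prev =  0, tail_rc = 1
    //   cell 2: seq_nodes = { {seq 1, next_cell -1} },                      prev =  0, tail_rc = 1
    //
    // with seq_tails[0].tail == 1 and seq_tails[1].tail == 2, and (since seq 0
    // was inserted first into cell 0) seq_tails[0].n_cells == 2 and
    // seq_tails[1].n_cells == 1.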
+ int32_t prev = -1; + + // ref count of tails (should match the number of next_cell == -1 in seq_nodes) + uint32_t tail_rc = 0; + + // seq_ids by insertion order, to simplify updating n_cells compared to a set + std::vector seq_nodes; + + llama_rs_seq_node * get_node(const llama_seq_id & id) { + for (size_t i = 0; i < seq_nodes.size(); ++i) { + if (seq_nodes[i].seq_id == id) { + return &seq_nodes[i]; + } + } + return nullptr; + } + + void insert_node(const llama_rs_seq_node & node) { + llama_rs_seq_node * node_dest = get_node(node.seq_id); + if (node_dest == nullptr) { + seq_nodes.push_back(node); + } else { + *node_dest = node; + } + } + + bool remove_node(llama_rs_seq_node * node_ptr) { + if (node_ptr != nullptr && seq_nodes.data() <= node_ptr) { + size_t offset = node_ptr - seq_nodes.data(); + if (offset % sizeof(llama_rs_seq_node) == 0) { + offset /= sizeof(llama_rs_seq_node); + if (offset < seq_nodes.size()) { + for (size_t i = offset + 1; i < seq_nodes.size(); ++i) { + seq_nodes[i - 1] = seq_nodes[i]; + } + seq_nodes.resize(seq_nodes.size() - 1); + return true; + } + } + } + return false; + } + + bool has_seq_id(const llama_seq_id & id) const { + for (size_t i = 0; i < seq_nodes.size(); ++i) { + if (seq_nodes[i].seq_id == id) { + return true; + } + } + return false; + } + + bool is_empty() const { + return seq_nodes.empty(); + } +}; + + +struct llama_rs_seq_meta { + // cell id of the latest state of this seq_id + int32_t tail = -1; + // number of cells for which this seq_id is the first + // (useful to know if cells in this sequence should be pruned) + int32_t n_cells = 0; + // whether the tail is a cell part of multiple sequences + bool shared = false; +}; + +// ring-buffer of cached recurrent state data +struct llama_rs_cache { + bool do_copy = false; + + uint32_t head = 0; // first state used for the last slot + uint32_t size = 0; + uint32_t used = 0; + + // computed when finding a slot + uint32_t n = 0; // range of states used for the last slot + + // useful to know the minimum reserved cell count per seq_id + // only counts sequences with n_cells > 0 + uint32_t n_seqs = 0; + + // with state models, a cell can hold the state for more than one past token + // TODO: it's probably not possible to always use contiguous cells + std::vector cells; + + // find tail cells faster + std::vector seq_tails; // map seq_ids to cell ids + + // per layer + // NOTE: the naming of r and s is arbitrary + std::vector r_l; // rolling/shift states + std::vector s_l; // ssm (recurrent) states + + // returns whether or not a cell was freed + bool clear_cell(uint32_t i) { + if (i < size) { + llama_rs_cell & rs_cell = cells[i]; + if (!rs_cell.is_empty()) { + // update sequence tree links + bool first = true; + for (const llama_rs_seq_node & node : rs_cell.seq_nodes) { + if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) { + // NOTE: if all next cells are the same cell, this should still work + cells[node.next_cell].prev = rs_cell.prev; + } + if ((uint32_t) node.seq_id < seq_tails.size()) { + auto & seq = seq_tails[node.seq_id]; + // update tail + if (node.is_tail()) { + seq.tail = rs_cell.prev; + if (seq.tail >= 0 && (uint32_t) seq.tail < size) { + llama_rs_cell & new_tail = cells[seq.tail]; + new_tail.insert_node(node.seq_id); // ensures next_cell == -1 + new_tail.tail_rc += 1; + seq.shared = new_tail.seq_nodes.size() > 1; + } else { + seq.shared = false; + } + } + // cell counts + if (first) { + seq.n_cells -= 1; + if (seq.n_cells == 0) { + GGML_ASSERT(seq.tail < 0); + n_seqs -= 1; + } + 
first = false; + } + } else { + GGML_ASSERT(false && "invalid seq_id"); + } + } + rs_cell.pos = -1; + rs_cell.src = -1; + rs_cell.prev = -1; + rs_cell.tail_rc = 0; + rs_cell.seq_nodes.clear(); + used -= 1; + return true; + } + } + return false; + } + + // TODO: maybe use a simpler data structure than a tree + // returns whether or not a cell was freed + bool remove_seq_from_cell(uint32_t i_cell, const llama_seq_id & id) { + if (i_cell < size && (size_t) id < size) { + llama_rs_cell & rs_cell = cells[i_cell]; + auto * node_ptr = rs_cell.get_node(id); // search once + if (node_ptr != nullptr) { + if (rs_cell.seq_nodes.size() == 1) { + return clear_cell(i_cell); + } else { + // update tree + llama_rs_seq_node node = *node_ptr; + if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) { + cells[node.next_cell].prev = rs_cell.prev; + } + if ((uint32_t) node.seq_id < seq_tails.size()) { + auto & seq = seq_tails[node.seq_id]; + bool other_no_longer_shared = rs_cell.seq_nodes.size() == 2; + if (node.is_tail()) { + seq.tail = rs_cell.prev; + if (seq.tail >= 0 && (uint32_t) seq.tail < size) { + llama_rs_cell & new_tail = cells[seq.tail]; + new_tail.insert_node(node.seq_id); // ensures next_cell == -1 + new_tail.tail_rc += 1; + seq.shared = cells[seq.tail].seq_nodes.size() > 1; + } else { + seq.shared = false; + } + GGML_ASSERT(rs_cell.tail_rc > 0); + rs_cell.tail_rc -= 1; + } + if (node_ptr == rs_cell.seq_nodes.data()) { + // this seq_id was the first in the list + seq.n_cells -= 1; + if (seq.n_cells == 0) { + n_seqs -= 1; + } + // the next node is the new first one, so update its n_cells + // (will never be out-of-bounds because the size is > 1) + llama_rs_seq_node next_node = node_ptr[1]; + if ((uint32_t) next_node.seq_id < seq_tails.size()) { + auto & next_seq = seq_tails[next_node.seq_id]; + next_seq.n_cells += 1; + if (next_seq.n_cells == 1) { + n_seqs += 1; + } + if (other_no_longer_shared) { + next_seq.shared = false; + } + } else { + GGML_ASSERT(false && "invalid seq_id"); + } + } else if (other_no_longer_shared) { + llama_rs_seq_node first_node = rs_cell.seq_nodes[0]; + if ((uint32_t) first_node.seq_id < seq_tails.size()) { + seq_tails[first_node.seq_id].shared = false; + } else { + GGML_ASSERT(false && "invalid seq_id"); + } + } + } else { + GGML_ASSERT(false && "invalid seq_id"); + } + const bool removed = rs_cell.remove_node(node_ptr); + GGML_ASSERT(removed); + } + } + } + return false; + } + + bool insert_seq_tail_to_cell(uint32_t i_cell, const llama_seq_id & id) { + if (i_cell < size && (size_t) id < seq_tails.size()) { + llama_rs_cell & rs_cell = cells[i_cell]; + auto & seq = seq_tails[id]; + int32_t prev = rs_cell.prev; + if ((uint32_t) seq.tail == i_cell) { + // the cell is already the tail of this seq_id + return false; + } + if (rs_cell.is_empty()) { + prev = seq.tail; + } + // ensure the new tail won't mess up the tree + GGML_ASSERT(seq.tail == -1 || seq.tail == prev); + if (prev >= 0 && (uint32_t) prev < size) { + // the targeted cell has a previous cell + llama_rs_cell & prev_cell = cells[prev]; + llama_rs_seq_node * prev_node = prev_cell.get_node(id); + GGML_ASSERT(prev_node != nullptr); // TODO: recursive insert instead of failing + GGML_ASSERT(prev_node->next_cell == -1); // or else a chain is broken + if (rs_cell.pos < 0) { + GGML_ASSERT(rs_cell.is_empty()); + rs_cell.pos = prev_cell.pos + 1; + rs_cell.src = prev_cell.src; + } + prev_cell.tail_rc -= 1; + prev_node->next_cell = i_cell; + } + if (rs_cell.is_empty()) { + // only add after potential failures above + if 
(seq.n_cells == 0) { + n_seqs += 1; + } + seq.n_cells += 1; + // set pos if still unset + if (rs_cell.pos < 0) { + rs_cell.pos = 0; + rs_cell.src = -1; + } + } + // the target cell was not already a tail of this seq_id + rs_cell.insert_node(id); // next_cell == -1 by default + rs_cell.tail_rc += 1; + seq.tail = i_cell; + seq.shared = rs_cell.seq_nodes.size() > 1; + return true; + } + return false; + } + + // each seq_id should have access to at least this many cells + // (to use when pruning (to avoid over-pruning)) + // (but this over-prunes when the system prompt doesn't take lots of cells) + // Hmm. The system prompt does not need checkpoints... + size_t min_cells_per_seq() const { + return size / (n_seqs > 0 ? n_seqs : 1); + } + + // each seq_id can have at most this many cells + // (ignoring seqs which behave as a shared prompt) + // TODO: avoid recalculating system seq_ids + // (to use when pruning (to avoid over-pruning)) + // NOTE: this also limits the shared prompt to at most half the cells + // (but the shared prompt technically needs only one cell...) + // (IDEA: keep only one cell when `llama_kv_cache_seq_cp` is called on a sequence) + size_t max_cells_per_seq() const { + int32_t n_system_seqs = 0; + int32_t n_system_cells = 0; + for (size_t i = 0; i < seq_tails.size(); ++i) { + auto & seq = seq_tails[i]; + if (seq.tail >= 0 && (size_t) seq.tail < size) { + if (seq.shared && seq.n_cells > 0) { + n_system_seqs += 1; + n_system_cells += seq.n_cells; + } + } + } + int32_t n_other_seqs = n_seqs - n_system_seqs; + return (size - n_system_cells) / (n_other_seqs > 0 ? n_other_seqs : 1); + } + + size_t total_size() const { + size_t size = 0; + for (struct ggml_tensor * r : r_l) { + size += ggml_nrows(r) * ggml_row_size(r->type, r->ne[0]); + } + for (struct ggml_tensor * s : s_l) { + size += ggml_nrows(s) * ggml_row_size(s->type, s->ne[0]); + } + return size; + } +}; + +struct llama_cache { + // key + value cache for self attention + llama_kv_cache kv; + + // recurrent state cache for state space models + llama_rs_cache rs; + std::vector ctxs; std::vector bufs; + // NOTE: padding may make this bigger than kv.total_size() + rs.total_size() size_t total_size() const { size_t size = 0; for (ggml_backend_buffer_t buf : bufs) { @@ -1958,7 +2310,7 @@ struct llama_kv_cache { return size; } - ~llama_kv_cache() { + ~llama_cache() { for (struct ggml_context * ctx : ctxs) { ggml_free(ctx); } @@ -2146,8 +2498,8 @@ struct llama_context { const llama_model & model; - // key + value cache for the self attention - struct llama_kv_cache kv_self; + // key + value cache for self-attention, and/or recurrent state cache + struct llama_cache cache; std::mt19937 rng; @@ -2205,9 +2557,9 @@ struct llama_context { struct ggml_tensor * inp_K_shift; // I32 [kv_size] struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch] struct ggml_tensor * inp_cls; // I32 [n_batch] - struct ggml_tensor * inp_s_copy; // I32 [kv_size] - struct ggml_tensor * inp_s_mask; // F32 [1, n_kv] - struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch] + struct ggml_tensor * inp_s_copy; // I32 [n_rs] + struct ggml_tensor * inp_s_mask; // F32 [1, n_rs] + struct ggml_tensor * inp_s_seq; // I32 [n_rs, n_batch] // control vectors struct llama_control_vector cvec; @@ -2221,47 +2573,45 @@ struct llama_context { // kv cache helpers // -static bool llama_kv_cache_init( - struct llama_kv_cache & cache, +static bool llama_cache_init( + struct llama_cache & cache, const llama_model & model, ggml_type type_k, ggml_type type_v, - uint32_t kv_size, + 
uint32_t n_ctx, + uint32_t n_seq_max, bool offload) { const struct llama_hparams & hparams = model.hparams; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s(); - const int64_t n_layer = hparams.n_layer; - cache.has_shift = false; + // TODO: per layer n_embd_* + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); + const uint32_t n_embd_r = hparams.n_embd_r(); + const uint32_t n_embd_s = hparams.n_embd_s(); + const bool has_kv = hparams.n_head != 0 && hparams.causal_attn; + const bool has_r = n_embd_r != 0; + const bool has_s = n_embd_s != 0; + const bool has_rs = has_r || has_s; + const uint32_t kv_size = has_kv ? n_ctx : 0; + const uint32_t rs_size = has_rs ? n_seq_max : 0; + // TODO: per cache type layer count + const int64_t n_layer = hparams.n_layer; - // TODO: find a nicer way to add other recurrent model architectures - cache.recurrent = model.arch == LLM_ARCH_MAMBA; + cache.kv.size = kv_size; - // TODO: support mixed reccurent Transformer architectues - // NOTE: (!a || b) is a logical implication (a -> b) - GGML_ASSERT(!cache.recurrent || n_embd_k_gqa == hparams.n_embd_k_s()); - GGML_ASSERT(!cache.recurrent || n_embd_v_gqa == hparams.n_embd_v_s()); - GGML_ASSERT( cache.recurrent || n_embd_k_gqa == hparams.n_embd_k_gqa()); - GGML_ASSERT( cache.recurrent || n_embd_v_gqa == hparams.n_embd_v_gqa()); + cache.kv.type_k = type_k; + cache.kv.type_v = type_v; - cache.head = 0; - cache.size = kv_size; - cache.used = 0; + cache.kv.cells.clear(); + cache.kv.cells.resize(kv_size); - cache.type_k = type_k; - cache.type_v = type_v; + cache.rs.size = rs_size; - cache.cells.clear(); - cache.cells.resize(kv_size); - - if (cache.recurrent) { - // init state copy sources - for (uint32_t i = 0; i < cache.size; ++i) { - cache.cells[i].src = i; - } - } + cache.rs.cells.clear(); + cache.rs.cells.resize(rs_size); + cache.rs.seq_tails.clear(); + cache.rs.seq_tails.resize(rs_size); #ifdef GGML_USE_CLBLAST offload = false; @@ -2282,7 +2632,7 @@ static bool llama_kv_cache_init( for (auto & it : buft_layer_count) { int n_layers = it.second; struct ggml_init_params params = { - /*.mem_size =*/ 2u*n_layers*ggml_tensor_overhead(), + /*.mem_size =*/ (2*has_kv + has_r+has_s)*n_layers*ggml_tensor_overhead(), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; @@ -2295,17 +2645,37 @@ static bool llama_kv_cache_init( cache.ctxs.push_back(ctx); } - cache.k_l.reserve(n_layer); - cache.v_l.reserve(n_layer); + if (has_kv) { + cache.kv.k_l.reserve(n_layer); + cache.kv.v_l.reserve(n_layer); + } + if (has_r) { + cache.rs.r_l.reserve(n_layer); + } + if (has_s) { + cache.rs.s_l.reserve(n_layer); + } for (int i = 0; i < (int) n_layer; i++) { struct ggml_context * ctx = offload ? 
ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front(); - ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); - ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); - ggml_format_name(k, "cache_k_l%d", i); - ggml_format_name(v, "cache_v_l%d", i); - cache.k_l.push_back(k); - cache.v_l.push_back(v); + if (has_kv) { + ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); + ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); + ggml_format_name(k, "cache_k_l%d", i); + ggml_format_name(v, "cache_v_l%d", i); + cache.kv.k_l.push_back(k); + cache.kv.v_l.push_back(v); + } + if (has_r) { + ggml_tensor * r = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd_r*rs_size); + ggml_format_name(r, "cache_r_l%d", i); + cache.rs.r_l.push_back(r); + } + if (has_s) { + ggml_tensor * s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd_s*rs_size); + ggml_format_name(s, "cache_s_l%d", i); + cache.rs.s_l.push_back(s); + } } // allocate tensors and initialize the buffers to avoid NaNs in the padding @@ -2330,23 +2700,30 @@ static bool llama_kv_cache_init( // Note: On success, it's important that cache.head points // to the first cell of the slot. static bool llama_kv_cache_find_slot( - struct llama_kv_cache & cache, - const struct llama_batch & batch) { - const uint32_t n_ctx = cache.size; + struct llama_cache & cache, + const struct llama_batch & batch) { + const uint32_t kv_size = cache.kv.size; + const uint32_t rs_size = cache.rs.size; const uint32_t n_tokens = batch.n_tokens; - if (cache.recurrent) { + if (rs_size > 0) { // For recurrent state architectures (like Mamba), - // each KV cache cell can store the state for a whole sequence. + // each cache cell can store the state for a whole sequence. + // TODO: real ring-buffer of states + // TODO: state chekpoints (multiple cells per sequence) + // TODO: find a way to always make the rs slot contiguous + + // Okay, need to find a slot. Everything should fit assuming the biggest seq_id < rs_size + - llama_seq_id min = cache.size - 1; + llama_seq_id min = cache.rs.size - 1; llama_seq_id max = 0; for (uint32_t i = 0; i < n_tokens; ++i) { for (int32_t j = 0; j < batch.n_seq_id[i]; ++j) { llama_seq_id seq_id = batch.seq_id[i][j]; // make sure it's a valid seq_id - if ((uint32_t) seq_id < cache.size) { + if ((uint32_t) seq_id < rs_size) { if (seq_id > max) { max = seq_id; } @@ -2354,83 +2731,93 @@ static bool llama_kv_cache_find_slot( min = seq_id; } // Assuming the tokens are in-order - if (batch.pos[i] != cache.cells[seq_id].pos + 1) { + if (batch.pos[i] != cache.rs.cells[seq_id].pos + 1) { // What should happen when the pos backtracks or skips a value? // Clearing the state mid-batch would require special-casing which isn't done. LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n", - __func__, batch.pos[i], cache.cells[seq_id].pos, seq_id); + __func__, batch.pos[i], cache.rs.cells[seq_id].pos, seq_id); } - if (cache.cells[seq_id].pos < 0 && 0 <= batch.pos[i]) { - cache.used += 1; + if (cache.rs.cells[seq_id].pos < 0 && 0 <= batch.pos[i]) { + cache.rs.used += 1; } - cache.cells[seq_id].pos = batch.pos[i]; - // NOTE: seq_ids are not inserted here; they are handled when the input tensors are set + cache.rs.cells[seq_id].pos = batch.pos[i]; + cache.rs.cells[seq_id].seq_nodes.insert(seq_id); } else { // too big seq_id // TODO: would it be possible to resize the KV cache size instead? 
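                    // (For reference: llama_cache_init sizes the recurrent cache with exactly
                    //  n_seq_max cells, and this WIP slot search uses a seq_id directly as its
                    //  cell index, so any seq_id >= n_seq_max cannot be stored here; e.g. with
                    //  --parallel 4 only seq_ids 0..3 are usable for a recurrent model.)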
- LLAMA_LOG_ERROR("%s: seq_id=%d >= kv_size=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size); + LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.rs.size); return false; } } } // allow getting the range of used cells, from head to head + n - cache.head = min; - cache.n = max - min + 1; + cache.rs.head = min; + cache.rs.n = max - min + 1; // sanity check - return max >= min; - } - // otherwise, one cell per token. - - if (n_tokens > n_ctx) { - LLAMA_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, n_ctx); - return false; + if (max < min) { + return false; + } } - uint32_t n_tested = 0; + if (kv_size > 0) { + // one KV cell per token + if (n_tokens > kv_size) { + LLAMA_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, kv_size); + return false; + } - while (true) { - if (cache.head + n_tokens > n_ctx) { - n_tested += n_ctx - cache.head; - cache.head = 0; - continue; + // if we have enough unused cells before the current head -> + // better to start searching from the beginning of the cache, hoping to fill it + if (cache.kv.head > cache.kv.used + 2*n_tokens) { + cache.kv.head = 0; } - bool found = true; - for (uint32_t i = 0; i < n_tokens; i++) { - if (cache.cells[cache.head + i].pos >= 0) { - found = false; - cache.head += i + 1; - n_tested += i + 1; - break; + uint32_t n_tested = 0; + + while (true) { + if (cache.kv.head + n_tokens > kv_size) { + n_tested += kv_size - cache.kv.head; + cache.kv.head = 0; + continue; } - } - if (found) { - break; - } + bool found = true; + for (uint32_t i = 0; i < n_tokens; i++) { + if (cache.kv.cells[cache.kv.head + i].pos >= 0) { + found = false; + cache.kv.head += i + 1; + n_tested += i + 1; + break; + } + } - if (n_tested >= n_ctx) { - //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); - return false; + if (found) { + break; + } + + if (n_tested >= kv_size) { + //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); + return false; + } } - } - for (uint32_t i = 0; i < n_tokens; i++) { - cache.cells[cache.head + i].pos = batch.pos[i]; + for (uint32_t i = 0; i < n_tokens; i++) { + cache.kv.cells[cache.kv.head + i].pos = batch.pos[i]; - for (int32_t j = 0; j < batch.n_seq_id[i]; j++) { - cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i][j]); + for (int32_t j = 0; j < batch.n_seq_id[i]; j++) { + cache.kv.cells[cache.kv.head + i].seq_id.insert(batch.seq_id[i][j]); + } } - } - cache.used += n_tokens; + cache.kv.used += n_tokens; + } return true; } -// find how many cells are currently in use +// find how many KV cells are currently in use static uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) { for (uint32_t i = cache.size; i > 0; --i) { const llama_kv_cell & cell = cache.cells[i - 1]; @@ -2443,214 +2830,381 @@ static uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) { return 0; } -static void llama_kv_cache_clear(struct llama_kv_cache & cache) { - for (int32_t i = 0; i < (int32_t) cache.size; ++i) { - cache.cells[i].pos = -1; - cache.cells[i].seq_id.clear(); +static uint32_t llama_rs_cache_cell_max(const struct llama_rs_cache & cache) { + for (uint32_t i = cache.size; i > 0; --i) { + const llama_rs_cell & cell = cache.cells[i - 1]; + + if (cell.pos >= 0 && !cell.is_empty()) { + return i; + } + } + + return 0; +} + +static void llama_cache_clear(struct llama_cache & cache) { + if (cache.kv.size > 0) { + for (uint32_t i = 0; i < cache.kv.size; 
++i) { + llama_kv_cell & kv_cell = cache.kv.cells[i]; + kv_cell.pos = -1; + kv_cell.delta = 0; + kv_cell.seq_id.clear(); + } + cache.kv.has_shift = false; + cache.kv.do_defrag = false; + cache.kv.head = 0; + cache.kv.used = 0; + } + if (cache.rs.size > 0) { + for (uint32_t i = 0; i < cache.rs.size; ++i) { + llama_rs_cell & rs_cell = cache.rs.cells[i]; + rs_cell.pos = -1; + rs_cell.src = -1; + rs_cell.seq_nodes.clear(); + } + cache.rs.do_copy = false; + cache.rs.head = 0; + cache.rs.used = 0; + cache.rs.n_seqs = 0; + cache.rs.seq_tails.clear(); + cache.rs.seq_tails.resize(cache.rs.size); } - cache.head = 0; - cache.used = 0; } -static bool llama_kv_cache_seq_rm( - struct llama_kv_cache & cache, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1) { - uint32_t new_head = cache.size; +static llama_pos llama_cache_seq_rm( + struct llama_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1) { if (p0 < 0) p0 = 0; if (p1 < 0) p1 = std::numeric_limits::max(); - // models like Mamba can't have a state partially erased - if (cache.recurrent) { - if (seq_id >= (int64_t) cache.size) { + llama_pos n_past = p0; + + if (cache.rs.size > 0) { + if (seq_id >= (int64_t) cache.rs.size) { // could be fatal - return false; - } - if (0 <= seq_id) { - // partial intersection is invalid - if ((0 < p0 && p0 <= cache.cells[seq_id].pos) || (0 < p1 && p1 <= cache.cells[seq_id].pos)) { - return false; - } - } else { - // seq_id is negative, then the range should include everything or nothing - if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits::max())) { - return false; + return n_past; + } + uint32_t new_head = cache.rs.size; + // adjust p0 and p1 according to the states found + llama_pos new_p0 = 0; + llama_pos new_p1 = std::numeric_limits::max(); + + for (uint32_t i = 0; i < cache.rs.size; ++i) { + llama_rs_cell & rs_cell = cache.rs.cells[i]; + if (seq_id < 0 || rs_cell.has_seq_id(seq_id)) { + if (rs_cell.pos < p0) { + // move forward the new p0 further + if (rs_cell.pos >= new_p0) { + new_p0 = rs_cell.pos + 1; + } + } else if (rs_cell.pos >= p1) { + // move back the new p1 further + if (rs_cell.pos < new_p1) { + new_p1 = rs_cell.pos; + } + if (rs_cell.pos >= n_past) { + n_past = rs_cell.pos + 1; + } + } else { // (rs_cell.pos >= p0 && rs_cell.pos < p1) + if (seq_id < 0) { + cache.rs.clear_cell(i); + } else { // (rs_cell.has_seq_id(seq_id)) + cache.rs.remove_seq_from_cell(i, seq_id); + } + if (rs_cell.is_empty() && new_head == cache.rs.size) { + new_head = i; + } + } } } + p0 = new_p0; + p1 = new_p1; + // correctly set n_past when there's nothing after p1 + if (n_past < p0) { n_past = p0; } + + // If we freed up a slot, set head to it so searching can start there. 
+ if (new_head != cache.rs.size && new_head < cache.rs.head) { + cache.rs.head = new_head; + } } - for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { - if (seq_id < 0) { - cache.cells[i].seq_id.clear(); - } else if (cache.cells[i].has_seq_id(seq_id)) { - cache.cells[i].seq_id.erase(seq_id); - } else { - continue; - } - if (cache.cells[i].is_empty()) { - // keep count of the number of used cells - if (cache.cells[i].pos >= 0) cache.used--; + if (cache.kv.size > 0) { + uint32_t new_head = cache.kv.size; + + for (uint32_t i = 0; i < cache.kv.size; ++i) { + llama_kv_cell & kv_cell = cache.kv.cells[i]; - cache.cells[i].pos = -1; - if (new_head == cache.size) new_head = i; + if (seq_id < 0 || kv_cell.has_seq_id(seq_id)) { + if (kv_cell.pos >= p0 && kv_cell.pos < p1) { + if (seq_id < 0) { + kv_cell.seq_id.clear(); + } else { // (kv_cell.has_seq_id(seq_id)) + kv_cell.seq_id.erase(seq_id); + } + if (kv_cell.is_empty()) { + // keep count of the number of used cells + if (kv_cell.pos >= 0) { cache.kv.used--; } + + kv_cell.pos = -1; + if (new_head == cache.kv.size) { new_head = i; } + } + } else { + if (kv_cell.pos >= n_past) { + n_past = kv_cell.pos + 1; + } + } } } - } - // If we freed up a slot, set head to it so searching can start there. - if (new_head != cache.size && new_head < cache.head) cache.head = new_head; + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cache.kv.size && new_head < cache.kv.head) { + cache.kv.head = new_head; + } + } - return true; + return n_past; } -static void llama_kv_cache_seq_cp( - struct llama_kv_cache & cache, - llama_seq_id seq_id_src, - llama_seq_id seq_id_dst, - llama_pos p0, - llama_pos p1) { +static llama_pos llama_cache_seq_cp( + struct llama_cache & cache, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1) { if (p0 < 0) p0 = 0; if (p1 < 0) p1 = std::numeric_limits::max(); - if (cache.recurrent) { - if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) { - seq_id_src = cache.cells[seq_id_src].src; - GGML_ASSERT((uint32_t) seq_id_src < cache.size); - // intent to "copy from" - // supports copy chains thanks to taking the source of the source - cache.cells[seq_id_dst].src = seq_id_src; - - // preserve the "keep or clear" status of the copied sequence - if (cache.cells[seq_id_src].has_seq_id(seq_id_src)) { - cache.cells[seq_id_dst].seq_id.insert(seq_id_dst); - } else { - cache.cells[seq_id_dst].seq_id.erase(seq_id_dst); + // TODO: in practice this seems to be only used on whole sequences; + // should partial sequence copy be removed? 
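    // In practice a whole-sequence copy is used to share an already-evaluated
    // prompt with another sequence, e.g. (using the existing public API):
    //
    //   llama_kv_cache_seq_cp(ctx, 0, 1, -1, -1);  // seq 1 now shares seq 0's state
    //
    // For the recurrent cache below, only the tail cell of the source is linked
    // into the destination; older cells of the source are pruned, since a copy
    // is assumed to mean no rollback will be attempted afterwards.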
+ + llama_pos n_past = 0; + + if (cache.rs.size > 0) { + // have to start from beginning for recurrent models + p0 = 0; + if ((uint32_t) seq_id_dst < cache.rs.size && (uint32_t) seq_id_src < cache.rs.size) { + auto seq_src = cache.rs.seq_tails[seq_id_src]; + int32_t src_tail = seq_src.tail; + // find the last tail of src in the pos range + while (src_tail >= 0 && (uint32_t) src_tail < cache.rs.size) { + llama_rs_cell & tail_cell = cache.rs.cells[src_tail]; + if (tail_cell.pos < p1) { + break; + } + src_tail = tail_cell.prev; } - cache.do_copy = true; + uint32_t new_head = cache.rs.size; - cache.cells[seq_id_dst].pos = cache.cells[seq_id_src].pos; + for (uint32_t i = 0; i < cache.rs.size; ++i) { + llama_rs_cell & rs_cell = cache.rs.cells[i]; + if (rs_cell.pos >= p0 && rs_cell.pos < p1 && rs_cell.has_seq_id(seq_id_src)) { + if (i == (uint32_t) src_tail) { + // need to be inserted in order, but there's only one + cache.rs.insert_seq_tail_to_cell(i, seq_id_dst); + } else { + // keep only the tail cell of the source + // assuming a copy means no rollback will be attempted afterwards + cache.rs.remove_seq_from_cell(i, seq_id_src); + if (new_head == cache.rs.size) { + new_head = i; + } + } + } + } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cache.rs.size && new_head < cache.rs.head) { + cache.rs.head = new_head; + } } - return; + p1 = n_past; } - // otherwise, this is the KV cache of a Transformer-like model - - cache.head = 0; - for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { - cache.cells[i].seq_id.insert(seq_id_dst); + if (cache.kv.size > 0) { + for (uint32_t i = 0; i < cache.kv.size; ++i) { + llama_kv_cell & kv_cell = cache.kv.cells[i]; + if (kv_cell.pos >= p0 && kv_cell.pos < p1 && kv_cell.has_seq_id(seq_id_src)) { + kv_cell.seq_id.insert(seq_id_dst); + if (kv_cell.pos >= n_past) { + n_past = kv_cell.pos + 1; + } + } } } + + return n_past; } -static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) { - uint32_t new_head = cache.size; +static void llama_cache_seq_keep(struct llama_cache & cache, llama_seq_id seq_id) { + if (cache.kv.size > 0) { + uint32_t new_head = cache.kv.size; + + for (uint32_t i = 0; i < cache.kv.size; ++i) { + llama_kv_cell & kv_cell = cache.kv.cells[i]; + if (!kv_cell.has_seq_id(seq_id)) { + if (kv_cell.pos >= 0) cache.kv.used--; + kv_cell.pos = -1; + kv_cell.seq_id.clear(); + if (new_head == cache.kv.size) new_head = i; + } else { + kv_cell.seq_id.clear(); + kv_cell.seq_id.insert(seq_id); + } + } - for (uint32_t i = 0; i < cache.size; ++i) { - if (!cache.cells[i].has_seq_id(seq_id)) { - if (cache.cells[i].pos >= 0) cache.used--; - cache.cells[i].pos = -1; - cache.cells[i].seq_id.clear(); - if (new_head == cache.size) new_head = i; - } else { - cache.cells[i].seq_id.clear(); - cache.cells[i].seq_id.insert(seq_id); + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cache.kv.size && new_head < cache.kv.head) { + cache.kv.head = new_head; } } - - // If we freed up a slot, set head to it so searching can start there. 
- if (new_head != cache.size && new_head < cache.head) cache.head = new_head; } -static void llama_kv_cache_seq_add( - struct llama_kv_cache & cache, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - llama_pos delta) { - uint32_t new_head = cache.size; +static llama_pos llama_cache_seq_add( + struct llama_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + llama_pos delta) { if (p0 < 0) p0 = 0; if (p1 < 0) p1 = std::numeric_limits::max(); - if (cache.recurrent) { + llama_pos n_past = p0; + + if (cache.rs.size > 0) { // for Mamba-like models, only the pos needs to be shifted - if (0 <= seq_id && seq_id < (int64_t) cache.size) { - llama_kv_cell & cell = cache.cells[seq_id]; - if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { - cell.pos += delta; + auto & seq = cache.rs.seq_tails[seq_id]; + // follow the sequence from its tail + int32_t cell_id = seq.tail; + while (cell_id >= 0) { + GGML_ASSERT((uint32_t) cell_id < cache.rs.size); + llama_rs_cell & rs_cell = cache.rs.cells[cell_id]; + int32_t i = cell_id; + cell_id = rs_cell.prev; + if (rs_cell.pos >= p0 && rs_cell.pos < p1) { + rs_cell.pos += delta; + if (rs_cell.pos < 0) { + // NOTE: this affects the other sequences which share the cell + cache.rs.clear_cell(i); + // TODO: update cache.rs.head + } + } + if (n_past <= rs_cell.pos) { + n_past = rs_cell.pos + 1; } } - return; } - for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { - cache.has_shift = true; - cache.cells[i].pos += delta; - cache.cells[i].delta += delta; + if (cache.kv.size > 0) { + uint32_t new_head = cache.kv.size; + + for (uint32_t i = 0; i < cache.kv.size; ++i) { + llama_kv_cell & kv_cell = cache.kv.cells[i]; + if (kv_cell.has_seq_id(seq_id)) { + if (kv_cell.pos >= p0 && kv_cell.pos < p1) { + cache.kv.has_shift = true; + kv_cell.pos += delta; + kv_cell.delta += delta; - if (cache.cells[i].pos < 0) { - if (!cache.cells[i].is_empty()) { - cache.used--; + if (kv_cell.pos < 0) { + if (!kv_cell.is_empty()) { + cache.kv.used--; + } + kv_cell.pos = -1; + kv_cell.seq_id.clear(); + if (new_head == cache.kv.size) { + new_head = i; + } + } } - cache.cells[i].pos = -1; - cache.cells[i].seq_id.clear(); - if (new_head == cache.size) { - new_head = i; + if (n_past <= kv_cell.pos) { + n_past = kv_cell.pos + 1; } } } + + // If we freed up a slot, set head to it so searching can start there. + // Otherwise we just start the next search from the beginning. + cache.kv.head = new_head != cache.kv.size ? new_head : 0; } - // If we freed up a slot, set head to it so searching can start there. - // Otherwise we just start the next search from the beginning. - cache.head = new_head != cache.size ? 
new_head : 0; + return n_past; } -static void llama_kv_cache_seq_div( - struct llama_kv_cache & cache, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - int d) { +static llama_pos llama_cache_seq_div( + struct llama_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + int d) { if (p0 < 0) p0 = 0; if (p1 < 0) p1 = std::numeric_limits::max(); - if (cache.recurrent) { + llama_pos n_past = p0; + + if (cache.rs.size > 0) { // for Mamba-like models, only the pos needs to be changed - if (0 <= seq_id && seq_id < (int64_t) cache.size) { - llama_kv_cell & cell = cache.cells[seq_id]; - if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { - cell.pos /= d; + auto & seq = cache.rs.seq_tails[seq_id]; + int32_t cell_id = seq.tail; + while (cell_id >= 0) { + GGML_ASSERT((uint32_t) cell_id < cache.rs.size); + llama_rs_cell & rs_cell = cache.rs.cells[cell_id]; + if (rs_cell.pos >= p0 && rs_cell.pos < p1) { + rs_cell.pos /= d; + } + cell_id = rs_cell.prev; + if (n_past <= rs_cell.pos) { + n_past = rs_cell.pos + 1; } } - return; } - for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { - cache.has_shift = true; + if (cache.kv.size > 0) { + for (uint32_t i = 0; i < cache.kv.size; ++i) { + llama_kv_cell & kv_cell = cache.kv.cells[i]; + if (kv_cell.has_seq_id(seq_id)) { + if (kv_cell.pos >= p0 && kv_cell.pos < p1) { + cache.kv.has_shift = true; - { - llama_pos p_old = cache.cells[i].pos; - cache.cells[i].pos /= d; - cache.cells[i].delta += cache.cells[i].pos - p_old; + { + llama_pos p_old = kv_cell.pos; + kv_cell.pos /= d; + kv_cell.delta += kv_cell.pos - p_old; + } + } + if (n_past <= kv_cell.pos) { + n_past = kv_cell.pos + 1; + } } } } + + return n_past; } -static llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama_seq_id seq_id) { +static llama_pos llama_cache_seq_pos_max(struct llama_cache & cache, llama_seq_id seq_id) { llama_pos result = 0; - for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.cells[i].has_seq_id(seq_id)) { - result = std::max(result, cache.cells[i].pos); + if (cache.rs.size > 0) { + int32_t cell_id = cache.rs.seq_tails[seq_id].tail; + if (cell_id >= 0) { + llama_rs_cell & rs_cell = cache.rs.cells[cell_id]; + result = rs_cell.pos; + } + // exit early + return result; + } + + if (cache.kv.size > 0) { + for (uint32_t i = 0; i < cache.kv.size; ++i) { + llama_kv_cell & kv_cell = cache.kv.cells[i]; + if (kv_cell.has_seq_id(seq_id)) { + result = std::max(result, kv_cell.pos); + } } } @@ -6009,6 +6563,7 @@ struct llm_build_context { const llama_cparams & cparams; const llama_batch & batch; const llama_kv_cache & kv_self; + const llama_rs_cache & rs_self; const int64_t n_embd; const int64_t n_layer; @@ -6034,8 +6589,10 @@ struct llm_build_context { const int32_t n_tokens; const int32_t n_kv; // size of KV cache to consider (n_kv <= kv_self.size) + const int32_t n_rs; const int32_t n_outputs; const int32_t kv_head; // index of where we store new KV data in the cache + const int32_t rs_head; const int32_t n_orig_ctx; const enum llama_pooling_type pooling_type; @@ -6058,7 +6615,8 @@ struct llm_build_context { hparams (model.hparams), cparams (lctx.cparams), batch (batch), - kv_self (lctx.kv_self), + kv_self (lctx.cache.kv), + rs_self (lctx.cache.rs), n_embd (hparams.n_embd), n_layer (hparams.n_layer), n_rot (hparams.n_rot), @@ -6081,8 +6639,10 @@ struct llm_build_context { norm_rms_eps (hparams.f_norm_rms_eps), n_tokens (batch.n_tokens), n_kv 
(worst_case ? kv_self.size : kv_self.n), - n_outputs (worst_case ? n_tokens : lctx.n_outputs), - kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head), + n_rs (worst_case ? rs_self.size : rs_self.n), + n_outputs (worst_case ? n_tokens : lctx.n_outputs), + kv_head (worst_case ? kv_self.size - n_tokens : kv_self.head), + rs_head (worst_case ? 0 : rs_self.head), n_orig_ctx (cparams.n_yarn_orig_ctx), pooling_type (cparams.pooling_type), rope_type (hparams.rope_type), @@ -6148,29 +6708,6 @@ struct llm_build_context { return gf; } - struct ggml_cgraph * build_s_copy() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); - - GGML_ASSERT(kv_self.recurrent); - - struct ggml_tensor * state_copy = build_inp_s_copy(); - - for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size); - struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size); - - conv_states = ggml_get_rows(ctx0, conv_states, state_copy); - ssm_states = ggml_get_rows(ctx0, ssm_states, state_copy); - - // TODO: name the intermediate tensors with cb() - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_states, kv_self.k_l[il])); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_states, kv_self.v_l[il])); - } - - return gf; - } - struct ggml_cgraph * build_defrag(const std::vector & ids) { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); @@ -6267,21 +6804,21 @@ struct llm_build_context { } struct ggml_tensor * build_inp_s_copy() { - lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, kv_self.size); + lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs); cb(lctx.inp_s_copy, "inp_s_copy", -1); ggml_set_input(lctx.inp_s_copy); return lctx.inp_s_copy; } struct ggml_tensor * build_inp_s_mask() { - lctx.inp_s_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv); + lctx.inp_s_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_rs); cb(lctx.inp_s_mask, "inp_s_mask", -1); ggml_set_input(lctx.inp_s_mask); return lctx.inp_s_mask; } struct ggml_tensor * build_inp_s_seq() { - lctx.inp_s_seq = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens); + lctx.inp_s_seq = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_rs, n_tokens); cb(lctx.inp_s_seq, "inp_s_seq", -1); ggml_set_input(lctx.inp_s_seq); return lctx.inp_s_seq; @@ -9269,26 +9806,31 @@ struct llm_build_context { // {n_embd, n_tokens} inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + struct ggml_tensor * state_copy = build_inp_s_copy(); struct ggml_tensor * state_mask = build_inp_s_mask(); struct ggml_tensor * state_seq = build_inp_s_seq(); for (int il = 0; il < n_layer; ++il) { - // (ab)using the KV cache to store the states - struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size); - struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size); + struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, rs_self.r_l[il], hparams.n_embd_r(), rs_self.size); + struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, rs_self.s_l[il], hparams.n_embd_s(), rs_self.size); + + // copy states + { + // TODO: use some sort of read-only head and n to pass smaller tensors to ggml_get_rows + // NOTE: assuming the copy destinations are ALL contained in the current batch + // this shrinks the tensors's ne[1] to n_rs + conv_states = 
ggml_get_rows(ctx0, conv_states, state_copy);
+                ssm_states  = ggml_get_rows(ctx0, ssm_states, state_copy);
+            }
             // clear states of sequences which are starting at the beginning of this batch
             {
-                conv_states = ggml_mul(ctx0,
-                    ggml_view_2d(ctx0, conv_states, conv_states->ne[0], n_kv, conv_states->nb[1], kv_head*conv_states->nb[1]),
-                    state_mask);
-                ssm_states  = ggml_mul(ctx0,
-                    ggml_view_2d(ctx0, ssm_states, ssm_states->ne[0], n_kv, ssm_states->nb[1], kv_head*ssm_states->nb[1]),
-                    state_mask);
+                conv_states = ggml_mul(ctx0, conv_states, state_mask);
+                ssm_states  = ggml_mul(ctx0, ssm_states, state_mask);
             }
-            conv_states = ggml_reshape_3d(ctx0, conv_states, d_conv - 1, d_inner, n_kv);
-            ssm_states  = ggml_reshape_3d(ctx0, ssm_states, d_state, d_inner, n_kv);
+            conv_states = ggml_reshape_3d(ctx0, conv_states, d_conv - 1, d_inner, n_rs);
+            ssm_states  = ggml_reshape_3d(ctx0, ssm_states, d_state, d_inner, n_rs);
             // norm
             cur = llm_build_norm(ctx0, inpL, hparams,
@@ -9321,8 +9863,8 @@
             // store last (d_conv - 1) columns of the conv_state part of x_conv back into the KV cache
             ggml_build_forward_expand(gf,
                 ggml_cpy(ctx0,
-                    ggml_view_2d(ctx0, x_conv, d_conv - 1, d_inner*n_kv, d_conv*ggml_element_size(x_conv), (1+d_inner*n_tokens)*ggml_element_size(x_conv)),
-                    ggml_view_1d(ctx0, kv_self.k_l[il], (d_conv - 1)*(d_inner)*(n_kv), kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv))));
+                    ggml_view_2d(ctx0, x_conv, d_conv - 1, d_inner*n_rs, d_conv*ggml_element_size(x_conv), (1+d_inner*n_tokens)*ggml_element_size(x_conv)),
+                    ggml_view_1d(ctx0, rs_self.r_l[il], (d_conv - 1)*(d_inner)*(n_rs), rs_head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv))));
             // extract x from x_conv
             x = ggml_view_2d(ctx0, x_conv, d_inner, n_tokens, d_inner*ggml_element_size(x_conv), 0);
@@ -9348,15 +9890,15 @@
             // Custom operator to optimize the parallel associative scan
             // as described in the Annex D of the Mamba paper.
-            // => {d_inner, n_tokens} and {d_state, d_inner, n_kv} combined,
+            // => {d_inner, n_tokens} and {d_state, d_inner, n_rs} combined,
             // because only a single tensor can be returned.
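            // (Concretely: the first d_inner*n_tokens elements of the result are the
            //  per-token outputs y, and the remaining d_state*d_inner*n_rs elements are
            //  the updated states; the two views below split it along those offsets.)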
struct ggml_tensor * y_ssm_states = ggml_ssm_scan(ctx0, ssm_states, x, dt, model.layers[il].ssm_a, B, C, state_seq); // store last states (the second part of y_ssm_states) ggml_build_forward_expand(gf, ggml_cpy(ctx0, - ggml_view_1d(ctx0, y_ssm_states, d_state*d_inner*n_kv, d_inner*n_tokens*ggml_element_size(y_ssm_states)), - ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner*n_kv, kv_head*d_state*d_inner*ggml_element_size(ssm_states)))); + ggml_view_1d(ctx0, y_ssm_states, d_state*d_inner*n_rs, d_inner*n_tokens*ggml_element_size(y_ssm_states)), + ggml_view_1d(ctx0, rs_self.s_l[il], d_state*d_inner*n_rs, rs_head*d_state*d_inner*ggml_element_size(ssm_states)))); struct ggml_tensor * y = ggml_view_2d(ctx0, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0); @@ -9558,23 +10100,6 @@ static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) { return result; } -static struct ggml_cgraph * llama_build_graph_s_copy(llama_context & lctx) { - llama_batch dummy; - dummy.n_tokens = 0; - - llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { }; - - struct llm_build_context llm(lctx, dummy, cb, false); - - llm.init(); - - struct ggml_cgraph * result = llm.build_s_copy(); - - llm.free(); - - return result; -} - static struct ggml_cgraph * llama_build_graph( llama_context & lctx, const llama_batch & batch, @@ -9729,26 +10254,14 @@ static struct ggml_cgraph * llama_build_graph( } static void llama_set_k_shift(llama_context & lctx) { - const int64_t kv_size = lctx.kv_self.size; + const int64_t kv_size = lctx.cache.kv.size; assert(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer)); int32_t * data = (int32_t *) lctx.inp_K_shift->data; for (int i = 0; i < kv_size; ++i) { - data[i] = lctx.kv_self.cells[i].delta; - } -} - -static void llama_set_s_copy(llama_context & lctx) { - const int64_t kv_size = lctx.kv_self.size; - - assert(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer)); - - int32_t * data = (int32_t *) lctx.inp_s_copy->data; - - for (int i = 0; i < kv_size; ++i) { - data[i] = lctx.kv_self.cells[i].src; + data[i] = lctx.cache.kv.cells[i].delta; } } @@ -9759,7 +10272,8 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { const auto & hparams = lctx.model.hparams; const auto & cparams = lctx.cparams; - const auto & kv_self = lctx.kv_self; + const auto & kv_self = lctx.cache.kv; + const auto & rs_self = lctx.cache.rs; if (batch.token) { const int64_t n_tokens = batch.n_tokens; @@ -9835,7 +10349,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { for (int i = 0; i < n_kv; ++i) { float f; - if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) { + if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { f = -INFINITY; } else { f = 0.0f; @@ -9886,7 +10400,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { float * data = (float *) lctx.inp_KQ_pos->data; for (int i = 0; i < n_kv; ++i) { - data[i] = float(lctx.kv_self.cells[i].pos); + data[i] = float(kv_self.cells[i].pos); } } @@ -9943,29 +10457,54 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } - if (kv_self.recurrent) { - const int64_t n_kv = kv_self.n; + if (rs_self.size > 0) { + const int64_t n_rs = rs_self.n; if (lctx.inp_s_mask) { GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer)); float * data = (float *) lctx.inp_s_mask->data; - // states which are not affected by the current batch are left 
untouched - for (int i = 0; i < n_kv; ++i) { - llama_seq_id seq_id = i + lctx.kv_self.head; - llama_kv_cell & kv_cell = lctx.kv_self.cells[seq_id]; - bool has_self_seq = kv_cell.has_seq_id(seq_id); + // clear unused states + for (int i = 0; i < n_rs; ++i) { + uint32_t cell_id = i + rs_self.head; + llama_rs_cell & rs_cell = lctx.cache.rs.cells[cell_id]; - data[i] = (float) has_self_seq; + data[i] = (float) rs_cell.src >= 0; - // ensure current sequences will be kept - if (!has_self_seq && kv_cell.pos >= 0) { - kv_cell.seq_id.insert(seq_id); + // only clear once + if (rs_cell.src < 0) { + rs_cell.src = cell_id; } } } + + // checkpoints require copies between cells + if (lctx.inp_s_copy) { + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer)); + int32_t * data = (int32_t *) lctx.inp_s_copy->data; + + // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n + for (uint32_t i = 0; i < n_rs; ++i) { + const uint32_t cell_id = i + rs_self.head; + llama_rs_cell & rs_cell = lctx.cache.rs.cells[cell_id]; + + // prevent out-of-bound sources + if (rs_cell.src < 0 || (uint32_t) rs_cell.src >= rs_self.size) { + rs_cell.src = cell_id; + } + + data[i] = rs_cell.src; + + // ensure copy only happens once + if (rs_cell.src != (int32_t) cell_id) { + rs_cell.src = cell_id; + } + } + } + // For Mamba (and other recurrent architectures), // update the correct state(s)/sequence(s) for each token of the batch. + // Each row contains relative cell ids of the sequences for the associated token. // Like with the KQ_mask, if a token in the batch has multiple sequences, // they are assumed to be equivalent (not here, but in ggml_ssm_scan and ggml_ssm_conv). if (lctx.inp_s_seq) { @@ -9978,12 +10517,20 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { const int32_t n_seq = batch.n_seq_id[j]; GGML_ASSERT(0 < n_seq); // a token should be part of at least 1 sequence - for (int i = 0; i < n_kv; ++i) { + for (int i = 0; i < n_rs; ++i) { if (i < n_seq) { - // for this type of model, the head is the minimum seq_id of the batch - data[j*n_kv + i] = batch.seq_id[j][i] - kv_self.head; + llama_seq_id seq_id = batch.seq_id[j][i]; + GGML_ASSERT((uint32_t) seq_id < rs_self.seq_tails.size()); + const auto & seq = rs_self.seq_tails[seq_id]; + // all sequences of this batch should already be initialized + GGML_ASSERT(seq.tail >= 0); + // ensure the relative cell id will be positive but not too big + GGML_ASSERT((uint32_t) seq.tail >= rs_self.head); + GGML_ASSERT((uint32_t) seq.tail < rs_self.head + rs_self.n); + + data[j*n_rs + i] = seq.tail - rs_self.head; } else { - data[j*n_kv + i] = -1; + data[j*n_rs + i] = -1; } } } @@ -10129,7 +10676,8 @@ static int llama_decode_internal( //ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads); #endif - auto & kv_self = lctx.kv_self; + auto & kv_self = lctx.cache.kv; + auto & rs_self = lctx.cache.rs; const int64_t n_embd = hparams.n_embd; const int64_t n_vocab = hparams.n_vocab; @@ -10245,17 +10793,11 @@ static int llama_decode_internal( if (hparams.causal_attn) { llama_kv_cache_update(&lctx); - // if we have enough unused cells before the current head -> - // better to start searching from the beginning of the cache, hoping to fill it - if (kv_self.head > kv_self.used + 2*n_tokens) { - kv_self.head = 0; - } - - if (!llama_kv_cache_find_slot(kv_self, u_batch)) { + if (!llama_kv_cache_find_slot(lctx.cache, u_batch)) { return 1; } - if (!kv_self.recurrent) { + if (kv_self.size > 0) { // a heuristic, to avoid 
attending the full cache if it is not yet utilized // after enough generations, the benefit from this heuristic disappears // if we start defragmenting the cache, the benefit from this will be more important @@ -10329,11 +10871,15 @@ static int llama_decode_internal( // update the kv ring buffer { kv_self.head += n_tokens; + rs_self.head += rs_self.n; // Ensure kv cache head points to a valid index. if (kv_self.head >= kv_self.size) { kv_self.head = 0; } + if (rs_self.head >= rs_self.size) { + rs_self.head = 0; + } } #ifdef GGML_PERF @@ -10430,7 +10976,7 @@ static int llama_decode_internal( // find holes from the beginning of the KV cache and fill them by moving data from the end of the cache static void llama_kv_cache_defrag_internal(struct llama_context & lctx) { - auto & kv_self = lctx.kv_self; + auto & kv_self = lctx.cache.kv; const auto & hparams = lctx.model.hparams; @@ -10651,7 +11197,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) { bool need_reserve = false; // apply K-shift if needed - if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) { + if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.cache.kv.has_shift) { { ggml_backend_sched_reset(lctx.sched); @@ -10667,7 +11213,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) { } { - auto & kv_self = lctx.kv_self; + auto & kv_self = lctx.cache.kv; kv_self.has_shift = false; @@ -10677,39 +11223,13 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) { } } - if (lctx.kv_self.recurrent && lctx.kv_self.do_copy) { - { - ggml_backend_sched_reset(lctx.sched); - - ggml_cgraph * gf = llama_build_graph_s_copy(lctx); - - ggml_backend_sched_alloc_graph(lctx.sched, gf); - - llama_set_s_copy(lctx); - - llama_graph_compute(lctx, gf, lctx.cparams.n_threads); - - need_reserve = true; - } - - { - auto & kv_self = lctx.kv_self; - - kv_self.do_copy = false; - - for (uint32_t i = 0; i < kv_self.size; ++i) { - kv_self.cells[i].src = i; - } - } - } - // defragment the KV cache if needed - if (lctx.kv_self.do_defrag) { + if (lctx.cache.kv.do_defrag) { llama_kv_cache_defrag_internal(lctx); need_reserve = true; - lctx.kv_self.do_defrag = false; + lctx.cache.kv.do_defrag = false; } // reserve a worst case graph again @@ -14258,18 +14778,8 @@ struct llama_context * llama_new_context_with_model( ctx->rng = std::mt19937(params.seed); ctx->logits_all = params.logits_all; - uint32_t kv_size = cparams.n_ctx; - ggml_type type_k = params.type_k; - ggml_type type_v = params.type_v; - - // Mamba only needs a constant number of KV cache cells per sequence - if (model->arch == LLM_ARCH_MAMBA) { - // Mamba needs at least as many KV cells as there are sequences kept at any time - kv_size = std::max((uint32_t) 1, params.n_seq_max); - // it's probably best to keep as much precision as possible for the states - type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states - type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states - } + const ggml_type type_k = params.type_k; + const ggml_type type_v = params.type_v; GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0); GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0); @@ -14377,25 +14887,42 @@ struct llama_context * llama_new_context_with_model( } ctx->backends.push_back(ctx->backend_cpu); - if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, kv_size, cparams.offload_kqv)) { + if (!llama_cache_init(ctx->cache, ctx->model, type_k, 
type_v, cparams.n_ctx, cparams.n_seq_max, cparams.offload_kqv)) { LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); llama_free(ctx); return nullptr; } - { + if (ctx->cache.rs.size > 0) { + size_t memory_size_r = 0; + size_t memory_size_s = 0; + + for (auto & r : ctx->cache.rs.r_l) { + memory_size_r += ggml_nbytes(r); + } + + for (auto & s : ctx->cache.rs.s_l) { + memory_size_s += ggml_nbytes(s); + } + + LLAMA_LOG_INFO("%s: SSM state size = %7.2f MiB, R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__, + (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), + ggml_type_name(GGML_TYPE_F32), (float)memory_size_r / (1024.0f * 1024.0f), + ggml_type_name(GGML_TYPE_F32), (float)memory_size_s / (1024.0f * 1024.0f)); + } + if (ctx->cache.kv.size > 0) { size_t memory_size_k = 0; size_t memory_size_v = 0; - for (auto & k : ctx->kv_self.k_l) { + for (auto & k : ctx->cache.kv.k_l) { memory_size_k += ggml_nbytes(k); } - for (auto & v : ctx->kv_self.v_l) { + for (auto & v : ctx->cache.kv.v_l) { memory_size_v += ggml_nbytes(v); } - LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); @@ -14513,7 +15040,11 @@ uint32_t llama_n_ubatch(const struct llama_context * ctx) { } uint32_t llama_n_seq_max(const struct llama_context * ctx) { - return ctx->kv_self.size; + if (ctx->cache.rs.size > 0) { + return ctx->cache.rs.size; + } else { + return ctx->cache.kv.size; + } } enum llama_vocab_type llama_vocab_type(const struct llama_model * model) { @@ -14799,8 +15330,9 @@ void llama_kv_cache_view_free(struct llama_kv_cache_view * view) { } void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view) { - if (uint32_t(view->n_cells) < ctx->kv_self.size || view->cells == nullptr) { - view->n_cells = int32_t(ctx->kv_self.size); + const llama_kv_cache & kv_self = ctx->cache.kv; + if (uint32_t(view->n_cells) < kv_self.size || view->cells == nullptr) { + view->n_cells = int32_t(kv_self.size); void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells); GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells"); view->cells = (struct llama_kv_cache_view_cell *)p; @@ -14809,7 +15341,7 @@ void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_k view->cells_sequences = (llama_seq_id *)p; } - const std::vector & kv_cells = ctx->kv_self.cells; + const std::vector & kv_cells = kv_self.cells; llama_kv_cache_view_cell * c_curr = view->cells; llama_seq_id * cs_curr = view->cells_sequences; int32_t used_cells = 0; @@ -14818,7 +15350,7 @@ void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_k uint32_t max_contig = 0; int32_t max_contig_idx = -1; - for (int32_t i = 0; i < int32_t(ctx->kv_self.size); i++, c_curr++, cs_curr += view->n_seq_max) { + for (int32_t i = 0; i < int32_t(kv_self.size); i++, c_curr++, cs_curr += view->n_seq_max) { const size_t curr_size = kv_cells[i].seq_id.size(); token_count += curr_size; c_curr->pos = kv_cells[i].pos + kv_cells[i].delta; @@ -14856,67 +15388,77 @@ void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_k view->max_contiguous_idx = max_contig_idx; view->token_count = 
token_count; view->used_cells = used_cells; - if (uint32_t(used_cells) != ctx->kv_self.used) { + if (uint32_t(used_cells) != kv_self.used) { LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n", - __func__, ctx->kv_self.used, used_cells); + __func__, kv_self.used, used_cells); } } int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx) { int result = 0; - for (uint32_t i = 0; i < ctx->kv_self.size; i++) { - result += ctx->kv_self.cells[i].seq_id.size(); + for (uint32_t i = 0; i < ctx->cache.kv.size; i++) { + result += ctx->cache.kv.cells[i].seq_id.size(); } return result; } int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx) { - return ctx->kv_self.used; + return ctx->cache.kv.used; } void llama_kv_cache_clear(struct llama_context * ctx) { - llama_kv_cache_clear(ctx->kv_self); + llama_cache_clear(ctx->cache); } bool llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - return llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1); + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return false; } + llama_pos n_past = llama_cache_seq_rm(ctx->cache, seq_id, p0, p1); + return n_past >= p0; } void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + uint32_t n_seq_max = llama_n_seq_max(ctx); + if (seq_id_src < 0 || seq_id_dst < 0 || (uint32_t) seq_id_src >= n_seq_max || (uint32_t) seq_id_dst >= n_seq_max) { + return; + } if (seq_id_src == seq_id_dst) { return; } - llama_kv_cache_seq_cp(ctx->kv_self, seq_id_src, seq_id_dst, p0, p1); + llama_cache_seq_cp(ctx->cache, seq_id_src, seq_id_dst, p0, p1); } void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) { - llama_kv_cache_seq_keep(ctx->kv_self, seq_id); + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } + llama_cache_seq_keep(ctx->cache, seq_id); } void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { if (delta == 0) { return; } + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } - llama_kv_cache_seq_add(ctx->kv_self, seq_id, p0, p1, delta); + llama_cache_seq_add(ctx->cache, seq_id, p0, p1, delta); } void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { if (d == 1) { return; } + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } - llama_kv_cache_seq_div(ctx->kv_self, seq_id, p0, p1, d); + llama_cache_seq_div(ctx->cache, seq_id, p0, p1, d); } llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) { - return llama_kv_cache_seq_pos_max(ctx->kv_self, seq_id); + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return 0; } + return llama_cache_seq_pos_max(ctx->cache, seq_id); } void llama_kv_cache_defrag(struct llama_context * ctx) { - llama_kv_cache_defrag(ctx->kv_self); + llama_kv_cache_defrag(ctx->cache.kv); } void llama_kv_cache_update(struct llama_context * ctx) { @@ -14944,9 +15486,10 @@ size_t llama_get_state_size(const struct llama_context * ctx) { const size_t s_kv_head = sizeof(uint32_t); const size_t s_kv_size = sizeof(uint32_t); const size_t s_kv_used = sizeof(uint32_t); - const size_t s_kv = ctx->kv_self.total_size(); + const size_t s_kv = ctx->cache.kv.total_size(); const size_t s_kv_cell = sizeof(llama_pos) + sizeof(size_t) + cparams.n_seq_max*sizeof(llama_seq_id); - 
const size_t s_kv_cells = ctx->kv_self.size * s_kv_cell; + const size_t s_kv_cells = ctx->cache.kv.size * s_kv_cell; + // TODO: rs cache cells const size_t s_total = ( + s_rng_size @@ -15241,14 +15784,15 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { } } + // FIXME: set rs cache too // set kv cache { - const auto & kv_self = ctx->kv_self; + const auto & kv_self = ctx->cache.kv; const auto & hparams = ctx->model.hparams; const uint32_t n_layer = hparams.n_layer; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); size_t kv_buf_size; uint32_t kv_head; @@ -15279,16 +15823,6 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { ggml_backend_tensor_set(kv_self.k_l[il], inp, 0, k_size); inp += k_size; - if (kv_self.recurrent) { - // v is contiguous for recurrent models - // TODO: use other tensors for state models than k and v - const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head); - - ggml_backend_tensor_set(kv_self.v_l[il], inp, 0, v_size); - inp += v_size; - continue; - } - // v is not contiguous, copy row by row const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head); const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_self.size); @@ -15303,8 +15837,8 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { llama_kv_cache_clear(ctx); - ctx->kv_self.head = kv_head; - ctx->kv_self.used = kv_used; + ctx->cache.kv.head = kv_head; + ctx->cache.kv.used = kv_used; for (uint32_t i = 0; i < kv_head; ++i) { llama_pos pos; @@ -15313,13 +15847,13 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { memcpy(&pos, inp, sizeof(pos)); inp += sizeof(pos); memcpy(&seq_id_size, inp, sizeof(seq_id_size)); inp += sizeof(seq_id_size); - ctx->kv_self.cells[i].pos = pos; + ctx->cache.kv.cells[i].pos = pos; llama_seq_id seq_id; for (size_t j = 0; j < seq_id_size; ++j) { memcpy(&seq_id, inp, sizeof(seq_id)); inp += sizeof(seq_id); - ctx->kv_self.cells[i].seq_id.insert(seq_id); + ctx->cache.kv.cells[i].seq_id.insert(seq_id); } } } From 8db1e4d45fb27a5e76ac55559a008a425e00fbac Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Thu, 4 Apr 2024 10:46:43 -0400 Subject: [PATCH 02/28] llama : use std::find for seq_nodes in llama_rs_cache --- llama.cpp | 153 ++++++++++++++++++++++-------------------------------- 1 file changed, 61 insertions(+), 92 deletions(-) diff --git a/llama.cpp b/llama.cpp index 9ca8ca0f41320..6dc310bf94c6c 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1962,11 +1962,12 @@ struct llama_rs_seq_node { llama_seq_id seq_id = -1; int32_t next_cell = -1; - // needed for automatic typecasting with .find() + // needed for automatic typecasting from a llama_seq_id llama_rs_seq_node(const llama_seq_id s = -1, int32_t i = -1) : seq_id(s), next_cell(i) {} - bool operator<(const llama_rs_seq_node & other) const { - return seq_id < other.seq_id; + // needed for more convenient std::find + bool operator==(const llama_rs_seq_node & other) const { + return seq_id == other.seq_id; } bool is_tail() const { @@ -1989,48 +1990,18 @@ struct llama_rs_cell { // seq_ids by insertion order, to simplify updating n_cells compared to a set std::vector seq_nodes; - llama_rs_seq_node * get_node(const llama_seq_id & id) { - for (size_t i = 0; 
i < seq_nodes.size(); ++i) { - if (seq_nodes[i].seq_id == id) { - return &seq_nodes[i]; - } - } - return nullptr; - } - void insert_node(const llama_rs_seq_node & node) { - llama_rs_seq_node * node_dest = get_node(node.seq_id); - if (node_dest == nullptr) { + auto node_dest = std::find(seq_nodes.begin(), seq_nodes.end(), node); + if (node_dest == seq_nodes.end()) { seq_nodes.push_back(node); } else { + // overwrite the pre-existing node with the same seq_id if it exists *node_dest = node; } } - bool remove_node(llama_rs_seq_node * node_ptr) { - if (node_ptr != nullptr && seq_nodes.data() <= node_ptr) { - size_t offset = node_ptr - seq_nodes.data(); - if (offset % sizeof(llama_rs_seq_node) == 0) { - offset /= sizeof(llama_rs_seq_node); - if (offset < seq_nodes.size()) { - for (size_t i = offset + 1; i < seq_nodes.size(); ++i) { - seq_nodes[i - 1] = seq_nodes[i]; - } - seq_nodes.resize(seq_nodes.size() - 1); - return true; - } - } - } - return false; - } - bool has_seq_id(const llama_seq_id & id) const { - for (size_t i = 0; i < seq_nodes.size(); ++i) { - if (seq_nodes[i].seq_id == id) { - return true; - } - } - return false; + return std::find(seq_nodes.begin(), seq_nodes.end(), id) != seq_nodes.end(); } bool is_empty() const { @@ -2132,67 +2103,65 @@ struct llama_rs_cache { bool remove_seq_from_cell(uint32_t i_cell, const llama_seq_id & id) { if (i_cell < size && (size_t) id < size) { llama_rs_cell & rs_cell = cells[i_cell]; - auto * node_ptr = rs_cell.get_node(id); // search once - if (node_ptr != nullptr) { + auto node_iter = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), id); // search once + if (node_iter != rs_cell.seq_nodes.end()) { if (rs_cell.seq_nodes.size() == 1) { return clear_cell(i_cell); - } else { - // update tree - llama_rs_seq_node node = *node_ptr; - if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) { - cells[node.next_cell].prev = rs_cell.prev; + } + // else update tree + llama_rs_seq_node node = *node_iter; + if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) { + cells[node.next_cell].prev = rs_cell.prev; + } + if ((uint32_t) node.seq_id < seq_tails.size()) { + auto & seq = seq_tails[node.seq_id]; + bool other_no_longer_shared = rs_cell.seq_nodes.size() == 2; + if (node.is_tail()) { + seq.tail = rs_cell.prev; + if (seq.tail >= 0 && (uint32_t) seq.tail < size) { + llama_rs_cell & new_tail = cells[seq.tail]; + new_tail.insert_node(node.seq_id); // ensures next_cell == -1 + new_tail.tail_rc += 1; + seq.shared = cells[seq.tail].seq_nodes.size() > 1; + } else { + seq.shared = false; + } + GGML_ASSERT(rs_cell.tail_rc > 0); + rs_cell.tail_rc -= 1; } - if ((uint32_t) node.seq_id < seq_tails.size()) { - auto & seq = seq_tails[node.seq_id]; - bool other_no_longer_shared = rs_cell.seq_nodes.size() == 2; - if (node.is_tail()) { - seq.tail = rs_cell.prev; - if (seq.tail >= 0 && (uint32_t) seq.tail < size) { - llama_rs_cell & new_tail = cells[seq.tail]; - new_tail.insert_node(node.seq_id); // ensures next_cell == -1 - new_tail.tail_rc += 1; - seq.shared = cells[seq.tail].seq_nodes.size() > 1; - } else { - seq.shared = false; - } - GGML_ASSERT(rs_cell.tail_rc > 0); - rs_cell.tail_rc -= 1; + if (node_iter == rs_cell.seq_nodes.begin()) { + // this seq_id was the first in the list + seq.n_cells -= 1; + if (seq.n_cells == 0) { + n_seqs -= 1; } - if (node_ptr == rs_cell.seq_nodes.data()) { - // this seq_id was the first in the list - seq.n_cells -= 1; - if (seq.n_cells == 0) { - n_seqs -= 1; - } - // the next node is the new first one, so update its 
n_cells - // (will never be out-of-bounds because the size is > 1) - llama_rs_seq_node next_node = node_ptr[1]; - if ((uint32_t) next_node.seq_id < seq_tails.size()) { - auto & next_seq = seq_tails[next_node.seq_id]; - next_seq.n_cells += 1; - if (next_seq.n_cells == 1) { - n_seqs += 1; - } - if (other_no_longer_shared) { - next_seq.shared = false; - } - } else { - GGML_ASSERT(false && "invalid seq_id"); + // the next node is the new first one, so update its n_cells + // (will never be out-of-bounds because the size is > 1) + llama_rs_seq_node next_node = *(std::next(node_iter)); + if ((uint32_t) next_node.seq_id < seq_tails.size()) { + auto & next_seq = seq_tails[next_node.seq_id]; + next_seq.n_cells += 1; + if (next_seq.n_cells == 1) { + n_seqs += 1; } - } else if (other_no_longer_shared) { - llama_rs_seq_node first_node = rs_cell.seq_nodes[0]; - if ((uint32_t) first_node.seq_id < seq_tails.size()) { - seq_tails[first_node.seq_id].shared = false; - } else { - GGML_ASSERT(false && "invalid seq_id"); + if (other_no_longer_shared) { + next_seq.shared = false; } + } else { + GGML_ASSERT(false && "invalid seq_id"); + } + } else if (other_no_longer_shared) { + llama_rs_seq_node first_node = rs_cell.seq_nodes[0]; + if ((uint32_t) first_node.seq_id < seq_tails.size()) { + seq_tails[first_node.seq_id].shared = false; + } else { + GGML_ASSERT(false && "invalid seq_id"); } - } else { - GGML_ASSERT(false && "invalid seq_id"); } - const bool removed = rs_cell.remove_node(node_ptr); - GGML_ASSERT(removed); + } else { + GGML_ASSERT(false && "invalid seq_id"); } + rs_cell.seq_nodes.erase(node_iter); } } return false; @@ -2215,8 +2184,8 @@ struct llama_rs_cache { if (prev >= 0 && (uint32_t) prev < size) { // the targeted cell has a previous cell llama_rs_cell & prev_cell = cells[prev]; - llama_rs_seq_node * prev_node = prev_cell.get_node(id); - GGML_ASSERT(prev_node != nullptr); // TODO: recursive insert instead of failing + auto prev_node = std::find(prev_cell.seq_nodes.begin(), prev_cell.seq_nodes.end(), id); + GGML_ASSERT(prev_node != prev_cell.seq_nodes.end()); // TODO: recursive insert instead of failing GGML_ASSERT(prev_node->next_cell == -1); // or else a chain is broken if (rs_cell.pos < 0) { GGML_ASSERT(rs_cell.is_empty()); @@ -2267,7 +2236,7 @@ struct llama_rs_cache { int32_t n_system_seqs = 0; int32_t n_system_cells = 0; for (size_t i = 0; i < seq_tails.size(); ++i) { - auto & seq = seq_tails[i]; + const auto & seq = seq_tails[i]; if (seq.tail >= 0 && (size_t) seq.tail < size) { if (seq.shared && seq.n_cells > 0) { n_system_seqs += 1; From 0028010d01447c079f98bc33f06fca691fc99905 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 8 Apr 2024 09:54:35 -0400 Subject: [PATCH 03/28] llama : state checkpoints for recurrent models --- ggml.c | 96 +++---- llama.cpp | 751 +++++++++++++++++++++++++++++++++++++++--------------- 2 files changed, 585 insertions(+), 262 deletions(-) diff --git a/ggml.c b/ggml.c index c9b0a6a0ef776..7a3f1b7a2f882 100644 --- a/ggml.c +++ b/ggml.c @@ -6335,19 +6335,18 @@ struct ggml_tensor * ggml_ssm_conv( GGML_ASSERT(ggml_is_3d(s)); GGML_ASSERT(ggml_is_matrix(x)); GGML_ASSERT(ggml_is_matrix(c)); - GGML_ASSERT(ggml_is_matrix(sq)); + GGML_ASSERT(ggml_is_vector(sq)); GGML_ASSERT(sq->type == GGML_TYPE_I32); const int64_t d_conv = c->ne[0]; const int64_t d_inner = c->ne[1]; const int64_t n_tokens = x->ne[1]; - const int64_t n_kv = s->ne[2]; + const int64_t n_rs = s->ne[2]; GGML_ASSERT( s->ne[0] == d_conv - 1); GGML_ASSERT( s->ne[1] == d_inner); GGML_ASSERT( x->ne[0] 
== d_inner); - GGML_ASSERT(sq->ne[0] == n_kv); - GGML_ASSERT(sq->ne[1] == n_tokens); + GGML_ASSERT(sq->ne[0] == n_tokens); bool is_node = false; @@ -6356,8 +6355,8 @@ struct ggml_tensor * ggml_ssm_conv( is_node = true; } - // 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_kv} - struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_kv)); + // 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_rs} + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_rs)); result->op = GGML_OP_SSM_CONV; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; @@ -6410,7 +6409,7 @@ struct ggml_tensor * ggml_ssm_scan( is_node = true; } - // 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_kv} + // 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_rs} struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s)); result->op = GGML_OP_SSM_SCAN; @@ -15087,9 +15086,9 @@ static void ggml_compute_forward_ssm_conv_f32( const int nc = src2->ne[0]; // d_conv const int nr = src0->ne[1]; // d_inner const int n_t = src1->ne[1]; // n_tokens - const int n_kv = src0->ne[2]; // max number of sequences in the batch + const int n_rs = src0->ne[2]; // max number of sequences in the batch - GGML_ASSERT((nr*n_t) + (nc*nr*n_kv) == ggml_nelements(dst)); + GGML_ASSERT((nr*n_t) + (nc*nr*n_rs) == ggml_nelements(dst)); GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src1->nb[0] == sizeof(float)); GGML_ASSERT(src2->nb[0] == sizeof(float)); @@ -15106,10 +15105,12 @@ static void ggml_compute_forward_ssm_conv_f32( const int ir1 = MIN(ir0 + dr, nr); const int ir = ir1 - ir0; - if (n_kv > 1) { + const int32_t * sq = src3->data; // {n_tokens} + + if (n_rs > 1) { // multiple sequences means it's hard to know when it's the first time a state is read, // so copy them all over to the destination, just to be sure. 
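// A minimal standalone sketch of the "2-in-1" result layout relied on here, assuming
// contiguous f32 storage (the helper name below is hypothetical): the per-token x block
// {d_inner, n_tokens} is packed first, followed by the conv states {d_conv, d_inner, n_rs},
// so the state slab for token i's sequence slot sq[i] sits at a fixed offset past the x block.
#include <cstdint>

static inline float * ssm_conv_state_ptr(float * dst_data, const int32_t * sq, int i_token,
                                         int64_t d_conv, int64_t d_inner, int64_t n_tokens) {
    // skip the x block, then index the state slab of this token's sequence slot
    return dst_data + d_inner*n_tokens + (int64_t) sq[i_token]*d_conv*d_inner;
}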
- for (int i3 = 0; i3 < n_kv; ++i3) { + for (int i3 = 0; i3 < n_rs; ++i3) { float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + i3*(src2->nb[2]) + nr*n_t*sizeof(float)); // can't use memcpy because of d_conv vs d_conv - 1 @@ -15123,19 +15124,19 @@ static void ggml_compute_forward_ssm_conv_f32( } for (int i2 = 0; i2 < n_t; ++i2) { - int32_t * sq = (int32_t *) ((char *) src3->data + i2*(src3->nb[1])); // {n_kv, n_tokens} - float * x = (float *) ((char *) dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens} - float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + sq[0]*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_kv} - float * s0; // {d_conv - 1, d_inner, n_kv} - float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} - float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner} + int32_t sq_i = sq[i2]; + float * x = (float *) ((char *) dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens} + float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + sq_i*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_rs} + float * s0; // {d_conv - 1, d_inner, n_rs} + float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} + float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner} int ne0s0; - GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv); + GGML_ASSERT(0 <= sq_i && sq_i < n_rs); // avoid needing to copy the state for the first token if (i2 == 0) { - s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_conv - 1, d_inner, n_kv} + s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq_i*(src0->nb[2])); // {d_conv - 1, d_inner, n_rs} ne0s0 = src0->ne[0]; } else { // the source is the last (d_conv - 1) columns of the destination @@ -15153,18 +15154,6 @@ static void ggml_compute_forward_ssm_conv_f32( s[(nc - 1) + i1*nc] = x0[i1]; } - // handle copies when there are multiple output states - for (int i3 = 1; i3 < n_kv; ++i3) { - int32_t seq = sq[i3]; - if (0 <= seq && seq < n_kv) { - float * s1 = s + (seq - sq[0])*nc*nr; - memcpy(s1, s, nc*ir*sizeof(float)); - } else { - // stop at negative or too big seq_ids - break; - } - } - // it seems a little faster when this is separate from the state shift for (int i1 = 0; i1 < ir; ++i1) { // rowwise dot product @@ -15216,7 +15205,7 @@ static void ggml_compute_forward_ssm_scan_f32( const int64_t nc = src0->ne[0]; // d_state const int64_t nr = src0->ne[1]; // d_inner const int64_t n_t = src1->ne[1]; // number of tokens in the batch - const int64_t n_kv = src0->ne[2]; // max number of sequences in the batch + const int64_t n_rs = src0->ne[2]; // max number of sequences in the batch GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); GGML_ASSERT(src0->nb[0] == sizeof(float)); @@ -15225,6 +15214,7 @@ static void ggml_compute_forward_ssm_scan_f32( GGML_ASSERT(src3->nb[0] == sizeof(float)); GGML_ASSERT(src4->nb[0] == sizeof(float)); GGML_ASSERT(src5->nb[0] == sizeof(float)); + GGML_ASSERT(src6->nb[0] == sizeof(int32_t)); // required for the dot product between s and C, and when copying the states GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); // required for per-sequence offsets for states @@ -15240,10 +15230,12 @@ static void ggml_compute_forward_ssm_scan_f32( const 
int ir1 = MIN(ir0 + dr, nr); const int ir = ir1 - ir0; - if (n_kv > 1) { + const int32_t * sq = src6->data; // {n_tokens} + + if (n_rs > 1) { // it's hard to know if the source states have already been copied // when there are multiple, so copy them already. - for (int i3 = 0; i3 < n_kv; ++i3) { + for (int i3 = 0; i3 < n_rs; ++i3) { float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[2]); memcpy(s, s0, nc*ir*sizeof(float)); @@ -15251,21 +15243,21 @@ static void ggml_compute_forward_ssm_scan_f32( } for (int i2 = 0; i2 < n_t; ++i2) { - int32_t * sq = (int32_t *) ((char *) src6->data + i2*(src6->nb[1])); // {n_kv, n_tokens} - float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} - float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_kv} - float * s0; - float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} - float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens} - float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} - float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens} - float * C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens} - - GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv); + int32_t sq_i = sq[i2]; + float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} + float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + sq_i*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_rs} + float * s0; + float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} + float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens} + float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} + float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens} + float * C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens} + + GGML_ASSERT(0 <= sq_i && sq_i < n_rs); // avoid needing to copy the state for the first token if (i2 == 0) { - s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_state, d_inner, n_kv} + s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq_i*(src0->nb[2])); // {d_state, d_inner, n_rs} } else { // otherwise the source is the same as the destination s0 = s; @@ -15288,18 +15280,6 @@ static void ggml_compute_forward_ssm_scan_f32( } y[i1] = sumf; } - - // handle copies when there are multiple output states - for (int i3 = 1; i3 < n_kv; ++i3) { - int32_t seq = sq[i3]; - if (0 <= seq && seq < n_kv) { - float * s1 = s + (seq - sq[0])*nc*nr; - memcpy(s1, s, nc*ir*sizeof(float)); - } else { - // stop at negative or too big seq_ids - break; - } - } } } diff --git a/llama.cpp b/llama.cpp index 6dc310bf94c6c..d561f80f62b6d 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2016,11 +2016,13 @@ struct llama_rs_seq_meta { // number of cells for which this seq_id is the first // (useful to know if cells in this sequence should be pruned) int32_t n_cells = 0; - // whether the tail is a cell part of multiple sequences - bool shared = false; + // changing the tail cell of a sequence can only be done at batch boundary, 
+ // this guards against changing the cell when it shouldn't be; + // should be cleared when done finding a slot + bool in_ubatch = false; }; -// ring-buffer of cached recurrent state data +// ring-buffered tree of cached recurrent state data struct llama_rs_cache { bool do_copy = false; @@ -2032,8 +2034,10 @@ struct llama_rs_cache { uint32_t n = 0; // range of states used for the last slot // useful to know the minimum reserved cell count per seq_id - // only counts sequences with n_cells > 0 + // only counts sequences with n_cells > 0 AND which have a non-shared tail uint32_t n_seqs = 0; + // cells part of multiple sequences AND which have at least one tail + uint32_t n_shared_tail_cells = 0; // with state models, a cell can hold the state for more than one past token // TODO: it's probably not possible to always use contiguous cells @@ -2047,127 +2051,332 @@ struct llama_rs_cache { std::vector r_l; // rolling/shift states std::vector s_l; // ssm (recurrent) states - // returns whether or not a cell was freed - bool clear_cell(uint32_t i) { - if (i < size) { - llama_rs_cell & rs_cell = cells[i]; - if (!rs_cell.is_empty()) { - // update sequence tree links - bool first = true; - for (const llama_rs_seq_node & node : rs_cell.seq_nodes) { - if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) { - // NOTE: if all next cells are the same cell, this should still work - cells[node.next_cell].prev = rs_cell.prev; + // TODO: maybe use a simpler data structure than a tree + + // Inefficient, but thorough verification and rebuilding of the rs cache + // from only the cells list with `pos` and seq_ids. + // Should not be called in a hot loop except when desperate and/or debugging. + bool rebuild(bool debug) { + bool was_valid = true; + // the source of truth is the cells list + // buffer sizes + if (size != cells.size()) { + if (debug) { + LLAMA_LOG_ERROR("%s: cells has wrong size (%zu instead of %u)\n", + __func__, cells.size(), size); + } + cells.resize(size); + was_valid = false; + } + if (size != seq_tails.size()) { + if (debug) { + LLAMA_LOG_ERROR("%s: seq_tails has wrong size (%zu instead of %u)\n", + __func__, seq_tails.size(), size); + } + seq_tails.resize(size); + was_valid = false; + } + // cells consistency + uint32_t used_verif = 0; + for (uint32_t cell_id = 0; cell_id < size; ++cell_id) { + llama_rs_cell & cell = cells[cell_id]; + if (cell.seq_nodes.empty()) { + if (cell.pos >= 0) { + cell.pos = -1; + was_valid = false; + } + } + if (cell.pos < 0) { + if (cell.pos != -1) { + cell.pos = -1; + was_valid = false; + } + if (!cell.seq_nodes.empty()) { + cell.seq_nodes.clear(); + was_valid = false; + } + cell.src = -1; + if (cell.prev != -1) { + cell.prev = -1; + was_valid = false; + } + } else if (!debug) { + // Assuming the cache should be actually rebuilt when not debugging + cell.src = cell_id; + } + if (!cell.seq_nodes.empty()) { + used_verif += 1; + } + } + if (used != used_verif) { + if (debug) { + LLAMA_LOG_ERROR("%s: invalid used cell count (%u instead of %u)\n", + __func__, used, used_verif); + } + used = used_verif; + was_valid = false; + } + // tail verification + std::vector> seq_cells; + for (llama_seq_id seq_id = 0; (uint32_t) seq_id < size; ++seq_id) { + auto & seq = seq_tails[seq_id]; + seq_cells.clear(); + for (uint32_t cell_id = 0; cell_id < size; ++cell_id) { + llama_rs_cell & cell = cells[cell_id]; + if (cell.has_seq_id(seq_id)) { + seq_cells.push_back({cell.pos, cell_id}); + } + } + // sort by pos and then by cell_id + std::sort(seq_cells.begin(), 
seq_cells.end()); + int32_t tail = seq_cells.empty() ? -1 : seq_cells[seq_cells.size() - 1].second; + if (tail != seq.tail) { + if (debug) { + LLAMA_LOG_ERROR("%s: invalid tail for seq_id %d (%d instead of %d)\n", + __func__, seq_id, seq.tail, tail); + } + seq.tail = tail; + was_valid = false; + } + int32_t prev = -1; + for (size_t i = 0; i < seq_cells.size(); ++i) { + uint32_t cell_id = seq_cells[i].second; + llama_rs_cell & cell = cells[cell_id]; + if (cell.prev != prev) { + // TODO: relax the error when multiple cells have the same pos + if (debug) { + LLAMA_LOG_ERROR("%s: invalid prev cell for cells[%u] (%d instead of %d)\n", + __func__, cell_id, cell.prev, prev); } - if ((uint32_t) node.seq_id < seq_tails.size()) { - auto & seq = seq_tails[node.seq_id]; - // update tail - if (node.is_tail()) { - seq.tail = rs_cell.prev; - if (seq.tail >= 0 && (uint32_t) seq.tail < size) { - llama_rs_cell & new_tail = cells[seq.tail]; - new_tail.insert_node(node.seq_id); // ensures next_cell == -1 - new_tail.tail_rc += 1; - seq.shared = new_tail.seq_nodes.size() > 1; - } else { - seq.shared = false; - } - } - // cell counts - if (first) { - seq.n_cells -= 1; - if (seq.n_cells == 0) { - GGML_ASSERT(seq.tail < 0); - n_seqs -= 1; - } - first = false; - } + cell.prev = prev; + was_valid = false; + } + prev = cell_id; + } + int32_t n_cells = 0; + int32_t next = -1; + for (size_t i = seq_cells.size(); i-- > 0;) { + uint32_t cell_id = seq_cells[i].second; + llama_rs_cell & cell = cells[cell_id]; + // assuming it's always found, because how else would it end up in the list of cells for this seq_id? + auto seq_node = std::find(cell.seq_nodes.begin(), cell.seq_nodes.end(), seq_id); + if (seq_node == cell.seq_nodes.begin()) { + n_cells += 1; + } + if (seq_node->next_cell != next) { + // TODO: relax the error when multiple cells have the same pos + if (debug) { + LLAMA_LOG_ERROR("%s: invalid next cell for cells[%u] (%d instead of %d)\n", + __func__, cell_id, seq_node->next_cell, next); + } + seq_node->next_cell = next; + was_valid = false; + } + next = cell_id; + } + if (seq.n_cells != n_cells) { + if (debug) { + LLAMA_LOG_ERROR("%s: invalid n_cells for seq_id %d (%d instead of %d)\n", + __func__, seq_id, seq.n_cells, n_cells); + } + seq.n_cells = n_cells; + } + // in_batch should only be true when in the process of finding a slot + if (seq.in_ubatch != false) { + if (debug) { + LLAMA_LOG_ERROR("%s: in_ubatch was true while it should have been false for seq_id %d\n", + __func__, seq_id); + } + seq.in_ubatch = false; + was_valid = false; + } + } + // tail_rc + for (uint32_t cell_id = 0; cell_id < size; ++cell_id) { + llama_rs_cell & cell = cells[cell_id]; + uint32_t tail_rc = 0; + for (llama_seq_id seq_id = 0; (uint32_t) seq_id < size; ++seq_id) { + auto & seq = seq_tails[seq_id]; + if (seq.tail >= 0 && (uint32_t) seq.tail == cell_id) { + tail_rc += 1; + } + } + if (cell.tail_rc != tail_rc) { + if (debug) { + LLAMA_LOG_ERROR("%s: invalid tail_rc for cells[%u] (%u instead of %u)\n", + __func__, cell_id, cell.tail_rc, tail_rc); + } + cell.tail_rc = tail_rc; + was_valid = false; + } + } + // n_seqs + uint32_t n_seqs_verif = 0; + uint32_t n_shared_tail_cells_verif = 0; + for (llama_seq_id seq_id = 0; (uint32_t) seq_id < size; ++seq_id) { + auto & seq = seq_tails[seq_id]; + if (seq.tail >= 0) { + llama_rs_cell & tail_cell = cells[seq.tail]; + // NOTE: could also have checked if n_cells > 0 + if (!tail_cell.seq_nodes.empty() && tail_cell.seq_nodes[0].seq_id == seq_id) { + if (tail_cell.seq_nodes.size() > 1) { + 
n_shared_tail_cells_verif += 1; } else { - GGML_ASSERT(false && "invalid seq_id"); + n_seqs_verif += 1; } } - rs_cell.pos = -1; - rs_cell.src = -1; - rs_cell.prev = -1; - rs_cell.tail_rc = 0; - rs_cell.seq_nodes.clear(); - used -= 1; - return true; } } - return false; + if (n_seqs != n_seqs_verif) { + if (debug) { + LLAMA_LOG_ERROR("%s: wrong n_seqs (%u instead of %u)\n", + __func__, n_seqs, n_seqs_verif); + } + n_seqs = n_seqs_verif; + was_valid = false; + } + if (n_shared_tail_cells != n_shared_tail_cells_verif) { + if (debug) { + LLAMA_LOG_ERROR("%s: wrong n_shared_tail_cells (%u instead of %u)\n", + __func__, n_shared_tail_cells, n_shared_tail_cells_verif); + } + n_shared_tail_cells = n_shared_tail_cells_verif; + was_valid = false; + } + return was_valid; } - // TODO: maybe use a simpler data structure than a tree // returns whether or not a cell was freed - bool remove_seq_from_cell(uint32_t i_cell, const llama_seq_id & id) { - if (i_cell < size && (size_t) id < size) { - llama_rs_cell & rs_cell = cells[i_cell]; - auto node_iter = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), id); // search once - if (node_iter != rs_cell.seq_nodes.end()) { - if (rs_cell.seq_nodes.size() == 1) { - return clear_cell(i_cell); - } - // else update tree - llama_rs_seq_node node = *node_iter; + void clear_cell(llama_rs_cell & rs_cell) { + GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size()); + if (!rs_cell.is_empty()) { + // update sequence tree links + bool first = true; + for (const llama_rs_seq_node & node : rs_cell.seq_nodes) { if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) { + // NOTE: if all next cells are the same cell, this should still work cells[node.next_cell].prev = rs_cell.prev; } + // next_cell of the nodes of the previous cell + if (rs_cell.prev >= 0 && (uint32_t) rs_cell.prev < size) { + llama_rs_cell & prev_cell = cells[rs_cell.prev]; + auto prev_node = std::find(prev_cell.seq_nodes.begin(), prev_cell.seq_nodes.end(), node); + // assuming the previous node is always found + GGML_ASSERT(prev_node != prev_cell.seq_nodes.end()); + prev_node->next_cell = node.next_cell; + if (node.is_tail()) { + prev_cell.tail_rc += 1; + } + } if ((uint32_t) node.seq_id < seq_tails.size()) { auto & seq = seq_tails[node.seq_id]; - bool other_no_longer_shared = rs_cell.seq_nodes.size() == 2; + // update tail if (node.is_tail()) { seq.tail = rs_cell.prev; - if (seq.tail >= 0 && (uint32_t) seq.tail < size) { - llama_rs_cell & new_tail = cells[seq.tail]; - new_tail.insert_node(node.seq_id); // ensures next_cell == -1 - new_tail.tail_rc += 1; - seq.shared = cells[seq.tail].seq_nodes.size() > 1; - } else { - seq.shared = false; - } - GGML_ASSERT(rs_cell.tail_rc > 0); - rs_cell.tail_rc -= 1; } - if (node_iter == rs_cell.seq_nodes.begin()) { - // this seq_id was the first in the list + // cell counts + if (first) { seq.n_cells -= 1; - if (seq.n_cells == 0) { - n_seqs -= 1; - } - // the next node is the new first one, so update its n_cells - // (will never be out-of-bounds because the size is > 1) - llama_rs_seq_node next_node = *(std::next(node_iter)); - if ((uint32_t) next_node.seq_id < seq_tails.size()) { - auto & next_seq = seq_tails[next_node.seq_id]; - next_seq.n_cells += 1; - if (next_seq.n_cells == 1) { - n_seqs += 1; - } - if (other_no_longer_shared) { - next_seq.shared = false; + if (rs_cell.tail_rc > 0 && seq.tail < 0) { + // last tail cell + if (rs_cell.seq_nodes.size() > 1) { + n_shared_tail_cells -= 1; + } else { + n_seqs -= 1; } - } else { - 
GGML_ASSERT(false && "invalid seq_id"); - } - } else if (other_no_longer_shared) { - llama_rs_seq_node first_node = rs_cell.seq_nodes[0]; - if ((uint32_t) first_node.seq_id < seq_tails.size()) { - seq_tails[first_node.seq_id].shared = false; - } else { - GGML_ASSERT(false && "invalid seq_id"); } + first = false; } } else { GGML_ASSERT(false && "invalid seq_id"); } - rs_cell.seq_nodes.erase(node_iter); } + rs_cell.pos = -1; + rs_cell.src = -1; + rs_cell.prev = -1; + rs_cell.tail_rc = 0; + rs_cell.seq_nodes.clear(); + used -= 1; + } + } + + // returns an iterator to the seq_node after the removed one, or the same which was passed if it wasn't removed. + std::vector::iterator remove_seq_node_from_cell(llama_rs_cell & rs_cell, std::vector::iterator node_iter) { + GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size()); + // TODO: assert the iterator points inside the correct vector + if (node_iter != rs_cell.seq_nodes.end()) { + if (rs_cell.seq_nodes.size() == 1) { + clear_cell(rs_cell); + return rs_cell.seq_nodes.end(); + } + // else update tree + llama_rs_seq_node node = *node_iter; + if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) { + cells[node.next_cell].prev = rs_cell.prev; + } + if (rs_cell.prev >= 0 && (uint32_t) rs_cell.prev < size) { + llama_rs_cell & prev_cell = cells[rs_cell.prev]; + auto prev_node = std::find(prev_cell.seq_nodes.begin(), prev_cell.seq_nodes.end(), node); + // assuming the previous node is always found + GGML_ASSERT(prev_node != prev_cell.seq_nodes.end()); + prev_node->next_cell = node.next_cell; + if (node.is_tail()) { + prev_cell.tail_rc += 1; + } + } + if ((uint32_t) node.seq_id < seq_tails.size()) { + auto & seq = seq_tails[node.seq_id]; + if (node.is_tail()) { + seq.tail = rs_cell.prev; + if (seq.tail < 0 && rs_cell.tail_rc == 1) { + // assuming the previous cell of a shared cell is also shared, + // (no need to update the shared tail cells count elsewhere, then) + // this was a shared tail cell, but will no longer be a tail cell + n_shared_tail_cells -= 1; + } + GGML_ASSERT(rs_cell.tail_rc > 0); + rs_cell.tail_rc -= 1; + } + if (node_iter == rs_cell.seq_nodes.begin()) { + // this seq_id was the first in the list + seq.n_cells -= 1; + + // the next node is the new first one, so update its n_cells + // (will never be out-of-bounds because the size is > 1) + llama_rs_seq_node next_node = *(std::next(node_iter)); + if ((uint32_t) next_node.seq_id < seq_tails.size()) { + auto & next_seq = seq_tails[next_node.seq_id]; + next_seq.n_cells += 1; + // only the tail ref count from the other seq_ids are left in tail_rc + if (rs_cell.tail_rc > 0) { + // will become a non-shared cell + if (rs_cell.seq_nodes.size() == 2) { + n_seqs += 1; + } + } + } else { + GGML_ASSERT(false && "invalid seq_id"); + } + } + } else { + GGML_ASSERT(false && "invalid seq_id"); + } + return rs_cell.seq_nodes.erase(node_iter); + } + return node_iter; + } + + // returns whether or not the seq_id was removed + bool remove_seq_from_cell_id(uint32_t i_cell, const llama_seq_id & id) { + if (i_cell < size && (size_t) id < size) { + llama_rs_cell & rs_cell = cells[i_cell]; + auto node_iter = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), id); // search once + return node_iter != remove_seq_node_from_cell(rs_cell, node_iter); } return false; } - bool insert_seq_tail_to_cell(uint32_t i_cell, const llama_seq_id & id) { + bool insert_seq_tail_to_cell_id(uint32_t i_cell, const llama_seq_id & id) { if (i_cell < size && (size_t) id < seq_tails.size()) { 
llama_rs_cell & rs_cell = cells[i_cell]; auto & seq = seq_tails[id]; @@ -2194,10 +2403,11 @@ struct llama_rs_cache { } prev_cell.tail_rc -= 1; prev_node->next_cell = i_cell; + rs_cell.prev = prev; } if (rs_cell.is_empty()) { - // only add after potential failures above - if (seq.n_cells == 0) { + // either the sequence didn't own any cells or had a shared tail cell + if (seq.n_cells == 0 || (seq.tail >= 0 && cells[seq.tail].seq_nodes.size() > 1)) { n_seqs += 1; } seq.n_cells += 1; @@ -2206,12 +2416,40 @@ struct llama_rs_cache { rs_cell.pos = 0; rs_cell.src = -1; } + used += 1; + } else if (rs_cell.seq_nodes.size() == 1 && rs_cell.tail_rc == 1) { + // don't count shared-cell tails + // FIXME: make this saner + n_seqs -= 1; + n_shared_tail_cells += 1; + } else if (rs_cell.tail_rc == 0) { + // shared cell without a tail gets a tail; + // FIXME: don't prune, in case this is used in llama_cache_seq_cp + GGML_ASSERT(false); // make sure we don't get here by accident + // prune the other sequences out of this cell + // NOTE: have to inline the removal because the state tree is partially invalid + bool first = true; + for (auto & node : rs_cell.seq_nodes) { + GGML_ASSERT(node.seq_id != id); + GGML_ASSERT(node.next_cell >= 0); + // easy removal, none of the nodes are tails + llama_rs_cell & next_cell = cells[node.next_cell]; + next_cell.prev = rs_cell.prev; + if (first) { + auto & first_seq = seq_tails[node.seq_id]; + first_seq.n_cells -= 1; + first = false; + } + } + rs_cell.seq_nodes.clear(); + } else if (rs_cell.seq_nodes.size() != rs_cell.tail_rc) { + // this is correct as long as this isn't called when trying to find a slot + // TODO: find a way to assert this } // the target cell was not already a tail of this seq_id rs_cell.insert_node(id); // next_cell == -1 by default rs_cell.tail_rc += 1; seq.tail = i_cell; - seq.shared = rs_cell.seq_nodes.size() > 1; return true; } return false; @@ -2219,33 +2457,12 @@ struct llama_rs_cache { // each seq_id should have access to at least this many cells // (to use when pruning (to avoid over-pruning)) - // (but this over-prunes when the system prompt doesn't take lots of cells) - // Hmm. The system prompt does not need checkpoints... - size_t min_cells_per_seq() const { - return size / (n_seqs > 0 ? n_seqs : 1); - } - - // each seq_id can have at most this many cells - // (ignoring seqs which behave as a shared prompt) - // TODO: avoid recalculating system seq_ids - // (to use when pruning (to avoid over-pruning)) - // NOTE: this also limits the shared prompt to at most half the cells - // (but the shared prompt technically needs only one cell...) - // (IDEA: keep only one cell when `llama_kv_cache_seq_cp` is called on a sequence) - size_t max_cells_per_seq() const { - int32_t n_system_seqs = 0; - int32_t n_system_cells = 0; - for (size_t i = 0; i < seq_tails.size(); ++i) { - const auto & seq = seq_tails[i]; - if (seq.tail >= 0 && (size_t) seq.tail < size) { - if (seq.shared && seq.n_cells > 0) { - n_system_seqs += 1; - n_system_cells += seq.n_cells; - } - } + size_t min_cells_per_seq(const llama_rs_seq_meta & new_seq) const { + uint32_t seqs = n_seqs; + if (new_seq.tail < 0 || new_seq.n_cells == 0) { + seqs += 1; } - int32_t n_other_seqs = n_seqs - n_system_seqs; - return (size - n_system_cells) / (n_other_seqs > 0 ? n_other_seqs : 1); + return (size - n_shared_tail_cells) / (seqs > 0 ? 
seqs : 1); } size_t total_size() const { @@ -2528,7 +2745,7 @@ struct llama_context { struct ggml_tensor * inp_cls; // I32 [n_batch] struct ggml_tensor * inp_s_copy; // I32 [n_rs] struct ggml_tensor * inp_s_mask; // F32 [1, n_rs] - struct ggml_tensor * inp_s_seq; // I32 [n_rs, n_batch] + struct ggml_tensor * inp_s_seq; // I32 [n_batch] // control vectors struct llama_control_vector cvec; @@ -2657,7 +2874,7 @@ static bool llama_cache_init( return false; } ggml_backend_buffer_clear(buf, 0); - LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); + LLAMA_LOG_INFO("%s: %10s ctx buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); cache.bufs.push_back(buf); } @@ -2678,54 +2895,170 @@ static bool llama_kv_cache_find_slot( if (rs_size > 0) { // For recurrent state architectures (like Mamba), // each cache cell can store the state for a whole sequence. - // TODO: real ring-buffer of states - // TODO: state chekpoints (multiple cells per sequence) // TODO: find a way to always make the rs slot contiguous - // Okay, need to find a slot. Everything should fit assuming the biggest seq_id < rs_size - - - llama_seq_id min = cache.rs.size - 1; - llama_seq_id max = 0; + llama_seq_id min_seq = cache.rs.size - 1; + llama_seq_id max_seq = 0; + uint32_t min_cell = cache.rs.size - 1; + uint32_t max_cell = 0; for (uint32_t i = 0; i < n_tokens; ++i) { - for (int32_t j = 0; j < batch.n_seq_id[i]; ++j) { + int32_t target_cell = -1; // ensure all the sequences of a token get the same cell + int32_t n_seq_ids = batch.n_seq_id[i]; + for (int32_t j = 0; j < n_seq_ids; ++j) { llama_seq_id seq_id = batch.seq_id[i][j]; - // make sure it's a valid seq_id + bool need_new_cell = false; + // Everything should fit assuming the biggest seq_id < rs_size if ((uint32_t) seq_id < rs_size) { - if (seq_id > max) { - max = seq_id; + llama_rs_seq_meta & seq = cache.rs.seq_tails[seq_id]; + if (seq_id > max_seq) { max_seq = seq_id; } + if (seq_id < min_seq) { min_seq = seq_id; } + + if (!seq.in_ubatch && target_cell >= 0) { + // never saw this seq_id before, + // but there's already a cell reserved for this token, use it + cache.rs.insert_seq_tail_to_cell_id(target_cell, seq_id); + } else if (seq.tail < 0) { + need_new_cell = true; + } else { + llama_rs_cell & tail = cache.rs.cells[seq.tail]; + if (seq.in_ubatch) { + // this seq_id was already seen before in the batch + // assuming the tail cell already "has" this seq_id + tail.pos += 1; + target_cell = seq.tail; + } else { + // first time this sequence is seen, + // there's no reserved cell yet; + // if it's not the first sequence of the token, how could it even get here? 
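// Condensed sketch of the reuse-vs-checkpoint decision applied just below, assuming
// llama_pos (an int32_t) positions; the helper name is hypothetical, and the interval
// of 16 mirrors the hard-coded value in this patch (see the TODO below about making
// it configurable):
#include <cstdint>

static bool rs_needs_new_cell(bool has_same_seqs, int32_t tail_prev,
                              int32_t tail_pos, int32_t prev_pos) {
    const int32_t checkpoint_interval = 16;
    // a new cell (i.e. a state checkpoint) is reserved when the tail cannot be re-used in place
    return !has_same_seqs || tail_prev < 0 || tail_pos - prev_pos >= checkpoint_interval;
}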
+ GGML_ASSERT(j == 0); + + bool has_same_seqs = tail.seq_nodes.size() == (size_t) n_seq_ids; + if (has_same_seqs) { + // the tail cell of a seq_id is assumed to already be part of the seq_id, + // hence the skip of the first seq_id + for (int32_t k = 1; k < n_seq_ids; ++k) { + if (batch.seq_id[i][k] != tail.seq_nodes[k].seq_id) { + has_same_seqs = false; + } + } + } + + // TODO: make the checkpoint interval configurable + if (!has_same_seqs || tail.prev < 0 || tail.pos - cache.rs.cells[tail.prev].pos >= 16) { + // a checkpoint should be saved + need_new_cell = true; + } else { + // re-use last tail + tail.pos += 1; + target_cell = seq.tail; + } + } } - if (seq_id < min) { - min = seq_id; + + if (need_new_cell && target_cell < 0) { + const int32_t min_cells_per_seq = cache.rs.min_cells_per_seq(seq); + + uint32_t cell_id = cache.rs.size; + bool looped_once = false; + + while (true) { + if (cache.rs.head >= cache.rs.size) { + cache.rs.head = 0; + if (looped_once) { + // avoid infinite loop + // NOTE: this should not happen, but gracefully fail anyway + LLAMA_LOG_ERROR("%s: recurrent state cache seems full, but should not. This is a bug.\n", __func__); + return false; + } + looped_once = true; + } + cell_id = cache.rs.head; + llama_rs_cell & candidate = cache.rs.cells[cell_id]; + if (candidate.is_empty()) { break; } + if (candidate.tail_rc == 1 && seq.tail == (int32_t) cell_id) { + if (candidate.seq_nodes.size() > 1) { + // prune out the other seq_ids, because they diverge + // TODO(maybe): hande this in insert_seq_tail_to_cell_id + // (hopefully doesn't happen too often) + for (auto node_iter = candidate.seq_nodes.begin(); node_iter != candidate.seq_nodes.end();) { + if (node_iter->seq_id == seq_id) { + node_iter = std::next(node_iter); + } else { + node_iter = cache.rs.remove_seq_node_from_cell(candidate, node_iter); + } + } + } + // re-use the tail cell to avoid not finding anything + candidate.pos += 1; + break; + } + if (candidate.tail_rc > 0) { + // skip tails of other sequences + cache.rs.head += 1; + continue; + } + if (candidate.seq_nodes.size() > 1) { + // shared prompts are not usually backtracked, so they can be pruned + cache.rs.clear_cell(candidate); + break; + } + + // prune too-long sequences + llama_seq_id seq_id_to_prune = candidate.seq_nodes[0].seq_id; + if (seq_id_to_prune == seq_id) { + // TODO: selectively skip some cells to keep older states + cache.rs.clear_cell(candidate); + break; + } + GGML_ASSERT((size_t) seq_id_to_prune < cache.rs.seq_tails.size()); + auto & seq_to_prune = cache.rs.seq_tails[seq_id_to_prune]; + if (seq_to_prune.n_cells > min_cells_per_seq) { + cache.rs.clear_cell(candidate); + break; + } + cache.rs.head += 1; + } + if (cell_id < cache.rs.size) { + cache.rs.insert_seq_tail_to_cell_id(cell_id, seq_id); + target_cell = cell_id; + } } + + if (seq.tail >= 0) { + if (min_cell > (uint32_t) seq.tail) { min_cell = seq.tail; } + if (max_cell < (uint32_t) seq.tail) { max_cell = seq.tail; } + seq.in_ubatch = true; + } + // Assuming the tokens are in-order - if (batch.pos[i] != cache.rs.cells[seq_id].pos + 1) { + if (batch.pos[i] != cache.rs.cells[seq.tail].pos) { // What should happen when the pos backtracks or skips a value? // Clearing the state mid-batch would require special-casing which isn't done. 
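// Caller-side sketch of the contract warned about here: with recurrent-state models,
// each sequence's positions must advance by exactly 1 per token, within and across
// batches (hypothetical minimal check, assuming a single sequence of int32_t positions):
#include <cstdint>

static bool rs_positions_are_consecutive(const int32_t * pos, int n_tokens, int32_t last_pos) {
    for (int i = 0; i < n_tokens; ++i) {
        if (pos[i] != last_pos + 1) {
            return false; // the position backtracked or skipped a value
        }
        last_pos = pos[i];
    }
    return true;
}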
LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n", - __func__, batch.pos[i], cache.rs.cells[seq_id].pos, seq_id); - } - if (cache.rs.cells[seq_id].pos < 0 && 0 <= batch.pos[i]) { - cache.rs.used += 1; + __func__, batch.pos[i], cache.rs.cells[cache.rs.head].pos - 1, seq_id); } - cache.rs.cells[seq_id].pos = batch.pos[i]; - cache.rs.cells[seq_id].seq_nodes.insert(seq_id); } else { // too big seq_id - // TODO: would it be possible to resize the KV cache size instead? + // TODO: would it be possible to resize the rs cache size instead? LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.rs.size); return false; } } + cache.rs.head = target_cell + 1; + } + + for (llama_seq_id i = min_seq; i <= max_seq; ++i) { + // make sure it's cleared for next time + cache.rs.seq_tails[i].in_ubatch = false; } // allow getting the range of used cells, from head to head + n - cache.rs.head = min; - cache.rs.n = max - min + 1; + cache.rs.head = min_cell; + cache.rs.n = max_cell - min_cell + 1; // sanity check - if (max < min) { + if (max_seq < min_seq || max_cell < min_cell) { return false; } } @@ -2799,6 +3132,7 @@ static uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) { return 0; } +// find how many recurrent state cells are currently in use static uint32_t llama_rs_cache_cell_max(const struct llama_rs_cache & cache) { for (uint32_t i = cache.size; i > 0; --i) { const llama_rs_cell & cell = cache.cells[i - 1]; @@ -2829,12 +3163,15 @@ static void llama_cache_clear(struct llama_cache & cache) { llama_rs_cell & rs_cell = cache.rs.cells[i]; rs_cell.pos = -1; rs_cell.src = -1; + rs_cell.prev = -1; + rs_cell.tail_rc = 0; rs_cell.seq_nodes.clear(); } cache.rs.do_copy = false; cache.rs.head = 0; cache.rs.used = 0; cache.rs.n_seqs = 0; + cache.rs.n_shared_tail_cells = 0; cache.rs.seq_tails.clear(); cache.rs.seq_tails.resize(cache.rs.size); } @@ -2846,8 +3183,8 @@ static llama_pos llama_cache_seq_rm( llama_pos p0, llama_pos p1) { - if (p0 < 0) p0 = 0; - if (p1 < 0) p1 = std::numeric_limits::max(); + if (p0 < 0) { p0 = 0; } + if (p1 < 0) { p1 = std::numeric_limits::max(); } llama_pos n_past = p0; @@ -2863,7 +3200,9 @@ static llama_pos llama_cache_seq_rm( for (uint32_t i = 0; i < cache.rs.size; ++i) { llama_rs_cell & rs_cell = cache.rs.cells[i]; - if (seq_id < 0 || rs_cell.has_seq_id(seq_id)) { + auto seq_node = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), seq_id); + + if (seq_id < 0 || seq_node != rs_cell.seq_nodes.end()) { if (rs_cell.pos < p0) { // move forward the new p0 further if (rs_cell.pos >= new_p0) { @@ -2879,9 +3218,9 @@ static llama_pos llama_cache_seq_rm( } } else { // (rs_cell.pos >= p0 && rs_cell.pos < p1) if (seq_id < 0) { - cache.rs.clear_cell(i); + cache.rs.clear_cell(rs_cell); } else { // (rs_cell.has_seq_id(seq_id)) - cache.rs.remove_seq_from_cell(i, seq_id); + cache.rs.remove_seq_node_from_cell(rs_cell, seq_node); } if (rs_cell.is_empty() && new_head == cache.rs.size) { new_head = i; @@ -2943,11 +3282,12 @@ static llama_pos llama_cache_seq_cp( llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { - if (p0 < 0) p0 = 0; - if (p1 < 0) p1 = std::numeric_limits::max(); + if (p0 < 0) { p0 = 0; } + if (p1 < 0) { p1 = std::numeric_limits::max(); } // TODO: in practice this seems to be only used on whole sequences; - // should partial sequence copy be removed? + // should partial sequence copy support be removed? + // TODO: What if the destination sequence is not empty? 
llama_pos n_past = 0; @@ -2973,11 +3313,11 @@ static llama_pos llama_cache_seq_cp( if (rs_cell.pos >= p0 && rs_cell.pos < p1 && rs_cell.has_seq_id(seq_id_src)) { if (i == (uint32_t) src_tail) { // need to be inserted in order, but there's only one - cache.rs.insert_seq_tail_to_cell(i, seq_id_dst); + cache.rs.insert_seq_tail_to_cell_id(i, seq_id_dst); } else { // keep only the tail cell of the source // assuming a copy means no rollback will be attempted afterwards - cache.rs.remove_seq_from_cell(i, seq_id_src); + cache.rs.remove_seq_from_cell_id(i, seq_id_src); if (new_head == cache.rs.size) { new_head = i; } @@ -3009,16 +3349,41 @@ static llama_pos llama_cache_seq_cp( } static void llama_cache_seq_keep(struct llama_cache & cache, llama_seq_id seq_id) { + if (cache.rs.size > 0) { + uint32_t new_head = cache.rs.size; + + for (uint32_t i = 0; i < cache.rs.size; ++i) { + llama_rs_cell & rs_cell = cache.rs.cells[i]; + if (!rs_cell.seq_nodes.empty()) { + for (auto node_iter = rs_cell.seq_nodes.begin(); node_iter != rs_cell.seq_nodes.end();) { + if (node_iter->seq_id != seq_id) { + node_iter = cache.rs.remove_seq_node_from_cell(rs_cell, node_iter); + } else { + node_iter = std::next(node_iter); + } + } + if (new_head == cache.rs.size && rs_cell.is_empty()) { + new_head = i; + } + } + } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cache.rs.size && new_head < cache.rs.head) { + cache.rs.head = new_head; + } + } + if (cache.kv.size > 0) { uint32_t new_head = cache.kv.size; for (uint32_t i = 0; i < cache.kv.size; ++i) { llama_kv_cell & kv_cell = cache.kv.cells[i]; if (!kv_cell.has_seq_id(seq_id)) { - if (kv_cell.pos >= 0) cache.kv.used--; + if (kv_cell.pos >= 0) { cache.kv.used--; } kv_cell.pos = -1; kv_cell.seq_id.clear(); - if (new_head == cache.kv.size) new_head = i; + if (new_head == cache.kv.size) { new_head = i; } } else { kv_cell.seq_id.clear(); kv_cell.seq_id.insert(seq_id); @@ -3052,13 +3417,12 @@ static llama_pos llama_cache_seq_add( while (cell_id >= 0) { GGML_ASSERT((uint32_t) cell_id < cache.rs.size); llama_rs_cell & rs_cell = cache.rs.cells[cell_id]; - int32_t i = cell_id; cell_id = rs_cell.prev; if (rs_cell.pos >= p0 && rs_cell.pos < p1) { rs_cell.pos += delta; if (rs_cell.pos < 0) { // NOTE: this affects the other sequences which share the cell - cache.rs.clear_cell(i); + cache.rs.clear_cell(rs_cell); // TODO: update cache.rs.head } } @@ -6787,7 +7151,7 @@ struct llm_build_context { } struct ggml_tensor * build_inp_s_seq() { - lctx.inp_s_seq = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_rs, n_tokens); + lctx.inp_s_seq = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); cb(lctx.inp_s_seq, "inp_s_seq", -1); ggml_set_input(lctx.inp_s_seq); return lctx.inp_s_seq; @@ -10482,26 +10846,15 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_seq->buffer)); int32_t * data = (int32_t *) lctx.inp_s_seq->data; - for (int j = 0; j < n_tokens; ++j) { - const int32_t n_seq = batch.n_seq_id[j]; - GGML_ASSERT(0 < n_seq); // a token should be part of at least 1 sequence - - for (int i = 0; i < n_rs; ++i) { - if (i < n_seq) { - llama_seq_id seq_id = batch.seq_id[j][i]; - GGML_ASSERT((uint32_t) seq_id < rs_self.seq_tails.size()); - const auto & seq = rs_self.seq_tails[seq_id]; - // all sequences of this batch should already be initialized - GGML_ASSERT(seq.tail >= 0); - // ensure the relative cell id will be positive but not too big - GGML_ASSERT((uint32_t) seq.tail >= 
rs_self.head); - GGML_ASSERT((uint32_t) seq.tail < rs_self.head + rs_self.n); - - data[j*n_rs + i] = seq.tail - rs_self.head; - } else { - data[j*n_rs + i] = -1; - } - } + for (int i = 0; i < n_tokens; ++i) { + const llama_seq_id seq_id = batch.seq_id[i][0]; + GGML_ASSERT((uint32_t) seq_id < rs_self.seq_tails.size()); + const auto & seq = rs_self.seq_tails[seq_id]; + // ensure the relative cell id will be positive but not too big + GGML_ASSERT((uint32_t) seq.tail >= rs_self.head); + GGML_ASSERT((uint32_t) seq.tail < rs_self.head + rs_self.n); + + data[i] = seq.tail - rs_self.head; } } } @@ -14874,7 +15227,7 @@ struct llama_context * llama_new_context_with_model( memory_size_s += ggml_nbytes(s); } - LLAMA_LOG_INFO("%s: SSM state size = %7.2f MiB, R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__, + LLAMA_LOG_INFO("%s: SSM state size = %8.2f MiB, R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__, (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), ggml_type_name(GGML_TYPE_F32), (float)memory_size_r / (1024.0f * 1024.0f), ggml_type_name(GGML_TYPE_F32), (float)memory_size_s / (1024.0f * 1024.0f)); @@ -14891,7 +15244,7 @@ struct llama_context * llama_new_context_with_model( memory_size_v += ggml_nbytes(v); } - LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + LLAMA_LOG_INFO("%s: KV cache size = %8.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); @@ -15458,7 +15811,7 @@ size_t llama_get_state_size(const struct llama_context * ctx) { const size_t s_kv = ctx->cache.kv.total_size(); const size_t s_kv_cell = sizeof(llama_pos) + sizeof(size_t) + cparams.n_seq_max*sizeof(llama_seq_id); const size_t s_kv_cells = ctx->cache.kv.size * s_kv_cell; - // TODO: rs cache cells + // FIXME: rs cache cells const size_t s_total = ( + s_rng_size @@ -15606,14 +15959,15 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat } } + // FIXME: copy rs cache // copy kv cache { - const auto & kv_self = ctx->kv_self; + const auto & kv_self = ctx->cache.kv; const auto & hparams = ctx->model.hparams; const uint32_t n_layer = hparams.n_layer; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); // NOTE: kv_size and kv_buf_size are mostly used for sanity checks const uint32_t kv_head = llama_kv_cache_cell_max(kv_self); @@ -15637,17 +15991,6 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), 0, tmp_buf.size()); data_ctx->write(tmp_buf.data(), tmp_buf.size()); - if (kv_self.recurrent) { - // v is contiguous for recurrent models - // TODO: use other tensors for state models than k and v - const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head); - - tmp_buf.resize(v_size); - ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), 0, tmp_buf.size()); - data_ctx->write(tmp_buf.data(), tmp_buf.size()); - continue; - } - // v is not contiguous, copy row by row const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head); const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size); @@ -15753,7 
+16096,7 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { } } - // FIXME: set rs cache too + // FIXME: set rs cache // set kv cache { const auto & kv_self = ctx->cache.kv; From 0c8b3b20956521acc8f1f297cb58ab3172b3c3e7 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Tue, 9 Apr 2024 17:35:22 -0400 Subject: [PATCH 04/28] llama : correctly handle more edge cases for the rs cache --- llama.cpp | 407 ++++++++++++++++++++++++++++-------------------------- 1 file changed, 211 insertions(+), 196 deletions(-) diff --git a/llama.cpp b/llama.cpp index d561f80f62b6d..5433bde86796a 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2034,7 +2034,7 @@ struct llama_rs_cache { uint32_t n = 0; // range of states used for the last slot // useful to know the minimum reserved cell count per seq_id - // only counts sequences with n_cells > 0 AND which have a non-shared tail + // only counts sequences which have a non-shared tail uint32_t n_seqs = 0; // cells part of multiple sequences AND which have at least one tail uint32_t n_shared_tail_cells = 0; @@ -2082,21 +2082,37 @@ struct llama_rs_cache { llama_rs_cell & cell = cells[cell_id]; if (cell.seq_nodes.empty()) { if (cell.pos >= 0) { + if (debug) { + LLAMA_LOG_ERROR("%s: cells[%d].pos is %d while it's empty (should be -1)\n", + __func__, cell_id, cell.pos); + } cell.pos = -1; was_valid = false; } } if (cell.pos < 0) { if (cell.pos != -1) { + if (debug) { + LLAMA_LOG_ERROR("%s: cells[%d].pos is %d while it's empty (should be -1)\n", + __func__, cell_id, cell.pos); + } cell.pos = -1; was_valid = false; } if (!cell.seq_nodes.empty()) { + if (debug) { + LLAMA_LOG_ERROR("%s: cells[%d] has %zu seq_ids while it's empty (should have none)\n", + __func__, cell_id, cell.seq_nodes.size()); + } cell.seq_nodes.clear(); was_valid = false; } cell.src = -1; if (cell.prev != -1) { + if (debug) { + LLAMA_LOG_ERROR("%s: cells[%d].prev is %d while it's empty (should be -1)\n", + __func__, cell_id, cell.prev); + } cell.prev = -1; was_valid = false; } @@ -2213,17 +2229,15 @@ struct llama_rs_cache { // n_seqs uint32_t n_seqs_verif = 0; uint32_t n_shared_tail_cells_verif = 0; - for (llama_seq_id seq_id = 0; (uint32_t) seq_id < size; ++seq_id) { - auto & seq = seq_tails[seq_id]; - if (seq.tail >= 0) { - llama_rs_cell & tail_cell = cells[seq.tail]; - // NOTE: could also have checked if n_cells > 0 - if (!tail_cell.seq_nodes.empty() && tail_cell.seq_nodes[0].seq_id == seq_id) { - if (tail_cell.seq_nodes.size() > 1) { - n_shared_tail_cells_verif += 1; - } else { + for (uint32_t cell_id = 0; (uint32_t) cell_id < size; ++cell_id) { + llama_rs_cell & rs_cell = cells[cell_id]; + if (!rs_cell.seq_nodes.empty()) { + if (rs_cell.seq_nodes.size() == 1) { + if (rs_cell.tail_rc == 1) { n_seqs_verif += 1; } + } else if (rs_cell.tail_rc > 0) { + n_shared_tail_cells_verif += 1; } } } @@ -2246,72 +2260,15 @@ struct llama_rs_cache { return was_valid; } - // returns whether or not a cell was freed - void clear_cell(llama_rs_cell & rs_cell) { - GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size()); - if (!rs_cell.is_empty()) { - // update sequence tree links - bool first = true; - for (const llama_rs_seq_node & node : rs_cell.seq_nodes) { - if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) { - // NOTE: if all next cells are the same cell, this should still work - cells[node.next_cell].prev = rs_cell.prev; - } - // next_cell of the nodes of the previous cell - if (rs_cell.prev >= 0 && (uint32_t) rs_cell.prev < size) { - 
llama_rs_cell & prev_cell = cells[rs_cell.prev]; - auto prev_node = std::find(prev_cell.seq_nodes.begin(), prev_cell.seq_nodes.end(), node); - // assuming the previous node is always found - GGML_ASSERT(prev_node != prev_cell.seq_nodes.end()); - prev_node->next_cell = node.next_cell; - if (node.is_tail()) { - prev_cell.tail_rc += 1; - } - } - if ((uint32_t) node.seq_id < seq_tails.size()) { - auto & seq = seq_tails[node.seq_id]; - // update tail - if (node.is_tail()) { - seq.tail = rs_cell.prev; - } - // cell counts - if (first) { - seq.n_cells -= 1; - if (rs_cell.tail_rc > 0 && seq.tail < 0) { - // last tail cell - if (rs_cell.seq_nodes.size() > 1) { - n_shared_tail_cells -= 1; - } else { - n_seqs -= 1; - } - } - first = false; - } - } else { - GGML_ASSERT(false && "invalid seq_id"); - } - } - rs_cell.pos = -1; - rs_cell.src = -1; - rs_cell.prev = -1; - rs_cell.tail_rc = 0; - rs_cell.seq_nodes.clear(); - used -= 1; - } - } - // returns an iterator to the seq_node after the removed one, or the same which was passed if it wasn't removed. std::vector::iterator remove_seq_node_from_cell(llama_rs_cell & rs_cell, std::vector::iterator node_iter) { GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size()); // TODO: assert the iterator points inside the correct vector if (node_iter != rs_cell.seq_nodes.end()) { - if (rs_cell.seq_nodes.size() == 1) { - clear_cell(rs_cell); - return rs_cell.seq_nodes.end(); - } - // else update tree + // update the tree llama_rs_seq_node node = *node_iter; if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) { + // NOTE: because of this, partially removing seq_ids from cells should only be done from the tail cells[node.next_cell].prev = rs_cell.prev; } if (rs_cell.prev >= 0 && (uint32_t) rs_cell.prev < size) { @@ -2321,6 +2278,14 @@ struct llama_rs_cache { GGML_ASSERT(prev_node != prev_cell.seq_nodes.end()); prev_node->next_cell = node.next_cell; if (node.is_tail()) { + if (prev_cell.seq_nodes.size() > 1) { + if (prev_cell.tail_rc == 0) { + n_shared_tail_cells += 1; + } + if (rs_cell.seq_nodes.size() == 1) { + n_seqs -= 1; + } + } prev_cell.tail_rc += 1; } } @@ -2328,11 +2293,15 @@ struct llama_rs_cache { auto & seq = seq_tails[node.seq_id]; if (node.is_tail()) { seq.tail = rs_cell.prev; - if (seq.tail < 0 && rs_cell.tail_rc == 1) { - // assuming the previous cell of a shared cell is also shared, - // (no need to update the shared tail cells count elsewhere, then) - // this was a shared tail cell, but will no longer be a tail cell - n_shared_tail_cells -= 1; + if (rs_cell.tail_rc == 1) { + if (rs_cell.seq_nodes.size() > 1) { + // assuming the previous cell of a shared cell is also shared, + // this was a shared tail cell, but will no longer be a tail cell + n_shared_tail_cells -= 1; + } else if (seq.tail < 0) { + // no more tail, no more sequence + n_seqs -= 1; + } } GGML_ASSERT(rs_cell.tail_rc > 0); rs_cell.tail_rc -= 1; @@ -2341,21 +2310,30 @@ struct llama_rs_cache { // this seq_id was the first in the list seq.n_cells -= 1; - // the next node is the new first one, so update its n_cells - // (will never be out-of-bounds because the size is > 1) - llama_rs_seq_node next_node = *(std::next(node_iter)); - if ((uint32_t) next_node.seq_id < seq_tails.size()) { - auto & next_seq = seq_tails[next_node.seq_id]; - next_seq.n_cells += 1; - // only the tail ref count from the other seq_ids are left in tail_rc - if (rs_cell.tail_rc > 0) { - // will become a non-shared cell - if (rs_cell.seq_nodes.size() == 2) { - n_seqs += 1; + auto next_node = 
std::next(node_iter); + if (next_node != rs_cell.seq_nodes.end()) { + // the next node is the new first one, so update its n_cells + if ((uint32_t) next_node->seq_id < seq_tails.size()) { + auto & next_seq = seq_tails[next_node->seq_id]; + next_seq.n_cells += 1; + // only the tail ref count from the other seq_ids are left in tail_rc + if (rs_cell.tail_rc > 0) { + // will become a non-shared cell + if (rs_cell.seq_nodes.size() == 2) { + n_shared_tail_cells -= 1; + n_seqs += 1; + } } + } else { + GGML_ASSERT(false && "invalid seq_id"); } } else { - GGML_ASSERT(false && "invalid seq_id"); + // this was the last seq_id of the cell + used -= 1; + rs_cell.pos = -1; + rs_cell.src = -1; + rs_cell.prev = -1; + // the other fields *should* have already been updated elsewhere } } } else { @@ -2366,6 +2344,13 @@ struct llama_rs_cache { return node_iter; } + void clear_cell(llama_rs_cell & rs_cell) { + GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size()); + for (auto node_iter = rs_cell.seq_nodes.begin(); node_iter != rs_cell.seq_nodes.end();) { + node_iter = remove_seq_node_from_cell(rs_cell, node_iter); + } + } + // returns whether or not the seq_id was removed bool remove_seq_from_cell_id(uint32_t i_cell, const llama_seq_id & id) { if (i_cell < size && (size_t) id < size) { @@ -2404,47 +2389,63 @@ struct llama_rs_cache { prev_cell.tail_rc -= 1; prev_node->next_cell = i_cell; rs_cell.prev = prev; + if (seq.tail == prev) { + // What to do when the tail moves... + // from unique to shared (n_seqs--) + // if the new cell has one seq_id or has no tails (n_shared_tail_cells++) + // if the new cell has one seq_id and a tail (n_seqs-- (yes, another time)) + // from unique to unique (seq.n_cells++) + // from empty to unique (seq.n_cells++, n_seqs++) + // from empty to shared + // if the new cell only has one seq_id or has no tail (n_shared_tail_cells++) + // if the new cell only has one seq_id and has one tail (n_seqs--) + // from shared to shared + // if the last cell has no tails (n_shared_tail_cells--) + // if the new cell has no tails or has one seq_id (n_shared_tail_cells++) + // if the new cell only has one seq_id and has one tail (n_seqs--) + // from shared to unique (seq.n_cells++) + // if this seq_id was not the first of the last cell (n_seqs++) + // if the last cell has no tails (n_shared_tail_cells--) + if (prev_cell.seq_nodes.size() > 1) { + // from shared + if (rs_cell.is_empty()) { + // to unique + if (prev_cell.seq_nodes[0].seq_id != id) { + n_seqs += 1; + } + } + // the previous cell is no longer a shared tail + if (prev_cell.tail_rc == 0) { + n_shared_tail_cells -= 1; + } + } else if (!rs_cell.is_empty()) { + // from unique to shared + n_seqs -= 1; + } + } } if (rs_cell.is_empty()) { - // either the sequence didn't own any cells or had a shared tail cell - if (seq.n_cells == 0 || (seq.tail >= 0 && cells[seq.tail].seq_nodes.size() > 1)) { - n_seqs += 1; - } + // to unique seq.n_cells += 1; - // set pos if still unset - if (rs_cell.pos < 0) { + if (seq.tail < 0) { + // from empty to unique + n_seqs += 1; + // pos was not yet set rs_cell.pos = 0; rs_cell.src = -1; } used += 1; - } else if (rs_cell.seq_nodes.size() == 1 && rs_cell.tail_rc == 1) { - // don't count shared-cell tails - // FIXME: make this saner - n_seqs -= 1; - n_shared_tail_cells += 1; - } else if (rs_cell.tail_rc == 0) { - // shared cell without a tail gets a tail; - // FIXME: don't prune, in case this is used in llama_cache_seq_cp - GGML_ASSERT(false); // make sure we don't get here by accident - // 
prune the other sequences out of this cell - // NOTE: have to inline the removal because the state tree is partially invalid - bool first = true; - for (auto & node : rs_cell.seq_nodes) { - GGML_ASSERT(node.seq_id != id); - GGML_ASSERT(node.next_cell >= 0); - // easy removal, none of the nodes are tails - llama_rs_cell & next_cell = cells[node.next_cell]; - next_cell.prev = rs_cell.prev; - if (first) { - auto & first_seq = seq_tails[node.seq_id]; - first_seq.n_cells -= 1; - first = false; + } else { + // to shared + if (rs_cell.seq_nodes.size() == 1) { + // a lone tail becomes a shared cell + if (rs_cell.tail_rc > 0) { + n_seqs -= 1; } + n_shared_tail_cells += 1; + } else if (rs_cell.tail_rc == 0) { + n_shared_tail_cells += 1; } - rs_cell.seq_nodes.clear(); - } else if (rs_cell.seq_nodes.size() != rs_cell.tail_rc) { - // this is correct as long as this isn't called when trying to find a slot - // TODO: find a way to assert this } // the target cell was not already a tail of this seq_id rs_cell.insert_node(id); // next_cell == -1 by default @@ -2977,6 +2978,7 @@ static bool llama_kv_cache_find_slot( llama_rs_cell & candidate = cache.rs.cells[cell_id]; if (candidate.is_empty()) { break; } if (candidate.tail_rc == 1 && seq.tail == (int32_t) cell_id) { + // the candidate is the old tail if (candidate.seq_nodes.size() > 1) { // prune out the other seq_ids, because they diverge // TODO(maybe): hande this in insert_seq_tail_to_cell_id @@ -3198,40 +3200,42 @@ static llama_pos llama_cache_seq_rm( llama_pos new_p0 = 0; llama_pos new_p1 = std::numeric_limits::max(); - for (uint32_t i = 0; i < cache.rs.size; ++i) { - llama_rs_cell & rs_cell = cache.rs.cells[i]; - auto seq_node = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), seq_id); + // partial seq_id removal has to happen from the tail + llama_rs_seq_meta & seq = cache.rs.seq_tails[seq_id]; + int32_t cell_id = seq.tail; + + while (cell_id >= 0) { + llama_rs_cell & rs_cell = cache.rs.cells[cell_id]; + // copy before the cell is potentially changed + int32_t prev_id = rs_cell.prev; + if (rs_cell.pos >= p1 && rs_cell.seq_nodes.size() > 1) { + // non-tail removal for shared cells can only be done when clearing a cell + // (i.e. 
when the next cell's link to the previous cell can be safely changed) + p1 = rs_cell.pos + 1; + } + if (rs_cell.pos >= p0 && rs_cell.pos < p1) { + auto node_iter = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), seq_id); + // if the node isn't found, the sequence tree is malformed + GGML_ASSERT(node_iter != rs_cell.seq_nodes.end()); + cache.rs.remove_seq_node_from_cell(rs_cell, node_iter); + // get the smallest removed cell id + if (new_head > (uint32_t) cell_id) { new_head = cell_id; } + } else { + // one more than the biggest non-removed cell of this sequence + if (rs_cell.pos >= n_past) { n_past = rs_cell.pos + 1; } - if (seq_id < 0 || seq_node != rs_cell.seq_nodes.end()) { if (rs_cell.pos < p0) { - // move forward the new p0 further - if (rs_cell.pos >= new_p0) { - new_p0 = rs_cell.pos + 1; - } - } else if (rs_cell.pos >= p1) { - // move back the new p1 further - if (rs_cell.pos < new_p1) { - new_p1 = rs_cell.pos; - } - if (rs_cell.pos >= n_past) { - n_past = rs_cell.pos + 1; - } - } else { // (rs_cell.pos >= p0 && rs_cell.pos < p1) - if (seq_id < 0) { - cache.rs.clear_cell(rs_cell); - } else { // (rs_cell.has_seq_id(seq_id)) - cache.rs.remove_seq_node_from_cell(rs_cell, seq_node); - } - if (rs_cell.is_empty() && new_head == cache.rs.size) { - new_head = i; - } + // new_p0 should be right after the max pos in the states before p0 + if (rs_cell.pos >= new_p0) { new_p0 = rs_cell.pos + 1; } + } else { // (rs_cell.pos >= p1) + // new_p1 should be the min pos in the states after p1 + if (rs_cell.pos < new_p1) { new_p1 = rs_cell.pos; } } } + cell_id = prev_id; } p0 = new_p0; p1 = new_p1; - // correctly set n_past when there's nothing after p1 - if (n_past < p0) { n_past = p0; } // If we freed up a slot, set head to it so searching can start there. 
if (new_head != cache.rs.size && new_head < cache.rs.head) { @@ -3259,10 +3263,8 @@ static llama_pos llama_cache_seq_rm( kv_cell.pos = -1; if (new_head == cache.kv.size) { new_head = i; } } - } else { - if (kv_cell.pos >= n_past) { - n_past = kv_cell.pos + 1; - } + } else if (kv_cell.pos >= n_past) { + n_past = kv_cell.pos + 1; } } } @@ -3292,42 +3294,37 @@ static llama_pos llama_cache_seq_cp( llama_pos n_past = 0; if (cache.rs.size > 0) { - // have to start from beginning for recurrent models + // have to start from the beginning for recurrent models p0 = 0; if ((uint32_t) seq_id_dst < cache.rs.size && (uint32_t) seq_id_src < cache.rs.size) { - auto seq_src = cache.rs.seq_tails[seq_id_src]; - int32_t src_tail = seq_src.tail; - // find the last tail of src in the pos range - while (src_tail >= 0 && (uint32_t) src_tail < cache.rs.size) { - llama_rs_cell & tail_cell = cache.rs.cells[src_tail]; - if (tail_cell.pos < p1) { - break; - } - src_tail = tail_cell.prev; - } - - uint32_t new_head = cache.rs.size; - + int32_t src_head = -1; + int32_t head_pos = p1; + int32_t src_next = -1; + // find the start of the sequence for (uint32_t i = 0; i < cache.rs.size; ++i) { llama_rs_cell & rs_cell = cache.rs.cells[i]; - if (rs_cell.pos >= p0 && rs_cell.pos < p1 && rs_cell.has_seq_id(seq_id_src)) { - if (i == (uint32_t) src_tail) { - // need to be inserted in order, but there's only one - cache.rs.insert_seq_tail_to_cell_id(i, seq_id_dst); - } else { - // keep only the tail cell of the source - // assuming a copy means no rollback will be attempted afterwards - cache.rs.remove_seq_from_cell_id(i, seq_id_src); - if (new_head == cache.rs.size) { - new_head = i; - } + if (!rs_cell.is_empty() && rs_cell.prev < 0) { + auto seq_node = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), seq_id_src); + if (seq_node != rs_cell.seq_nodes.end()) { + src_head = i; + head_pos = rs_cell.pos; + src_next = seq_node->next_cell; + break; } } } - - // If we freed up a slot, set head to it so searching can start there. 
- if (new_head != cache.rs.size && new_head < cache.rs.head) { - cache.rs.head = new_head; + while (src_head >= 0 && head_pos < p1) { + cache.rs.insert_seq_tail_to_cell_id(src_head, seq_id_dst); + src_head = src_next; + if (head_pos >= n_past) { n_past = head_pos + 1; } + if (src_next >= 0) { + llama_rs_cell & rs_cell = cache.rs.cells[src_next]; + auto seq_node = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), seq_id_src); + head_pos = rs_cell.pos; + // it should always be found if the seq tree is valid + GGML_ASSERT(seq_node != rs_cell.seq_nodes.end()); + src_next = seq_node->next_cell; + } } } p1 = n_past; @@ -3338,9 +3335,7 @@ static llama_pos llama_cache_seq_cp( llama_kv_cell & kv_cell = cache.kv.cells[i]; if (kv_cell.pos >= p0 && kv_cell.pos < p1 && kv_cell.has_seq_id(seq_id_src)) { kv_cell.seq_id.insert(seq_id_dst); - if (kv_cell.pos >= n_past) { - n_past = kv_cell.pos + 1; - } + if (kv_cell.pos >= n_past) { n_past = kv_cell.pos + 1; } } } } @@ -3352,18 +3347,19 @@ static void llama_cache_seq_keep(struct llama_cache & cache, llama_seq_id seq_id if (cache.rs.size > 0) { uint32_t new_head = cache.rs.size; - for (uint32_t i = 0; i < cache.rs.size; ++i) { - llama_rs_cell & rs_cell = cache.rs.cells[i]; - if (!rs_cell.seq_nodes.empty()) { - for (auto node_iter = rs_cell.seq_nodes.begin(); node_iter != rs_cell.seq_nodes.end();) { - if (node_iter->seq_id != seq_id) { - node_iter = cache.rs.remove_seq_node_from_cell(rs_cell, node_iter); - } else { - node_iter = std::next(node_iter); - } - } - if (new_head == cache.rs.size && rs_cell.is_empty()) { - new_head = i; + // partial seq_id removal has to happen from the tail(s) + for (uint32_t i = 0; i < cache.rs.seq_tails.size(); ++i) { + if (i == (uint32_t) seq_id) { continue; } + llama_rs_seq_meta & seq = cache.rs.seq_tails[i]; + int32_t cell_id = seq.tail; + while (cell_id >= 0) { + llama_rs_cell & rs_cell = cache.rs.cells[cell_id]; + auto node_iter = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), i); + GGML_ASSERT(node_iter != rs_cell.seq_nodes.end()); + cache.rs.remove_seq_node_from_cell(rs_cell, node_iter); + cell_id = rs_cell.prev; + if (new_head > (uint32_t) cell_id && rs_cell.is_empty()) { + new_head = cell_id; } } } @@ -3414,6 +3410,7 @@ static llama_pos llama_cache_seq_add( auto & seq = cache.rs.seq_tails[seq_id]; // follow the sequence from its tail int32_t cell_id = seq.tail; + uint32_t new_head = cache.rs.size; while (cell_id >= 0) { GGML_ASSERT((uint32_t) cell_id < cache.rs.size); llama_rs_cell & rs_cell = cache.rs.cells[cell_id]; @@ -3423,13 +3420,19 @@ static llama_pos llama_cache_seq_add( if (rs_cell.pos < 0) { // NOTE: this affects the other sequences which share the cell cache.rs.clear_cell(rs_cell); - // TODO: update cache.rs.head + if (new_head > (uint32_t) cell_id) { + new_head = cell_id; + } } } if (n_past <= rs_cell.pos) { n_past = rs_cell.pos + 1; } } + + // If we freed up a slot, set head to it so searching can start there. + // Otherwise we just start the next search from the beginning. + cache.rs.head = new_head != cache.rs.size ? 
new_head : 0; } if (cache.kv.size > 0) { @@ -3474,8 +3477,8 @@ static llama_pos llama_cache_seq_div( llama_pos p0, llama_pos p1, int d) { - if (p0 < 0) p0 = 0; - if (p1 < 0) p1 = std::numeric_limits::max(); + if (p0 < 0) { p0 = 0; } + if (p1 < 0) { p1 = std::numeric_limits::max(); } llama_pos n_past = p0; @@ -11275,6 +11278,10 @@ static int llama_decode_internal( } } n_outputs_prev += lctx.n_outputs; + +#ifndef NDEBUG + GGML_ASSERT(lctx.cache.rs.rebuild(true)); +#endif } // wait for the computation to finish (automatically done when obtaining the model output) @@ -16332,11 +16339,19 @@ void llama_batch_free(struct llama_batch batch) { int32_t llama_decode( struct llama_context * ctx, struct llama_batch batch) { +#ifndef NDEBUG + GGML_ASSERT(ctx->cache.rs.rebuild(true)); +#endif + const int ret = llama_decode_internal(*ctx, batch); if (ret < 0) { LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); } +#ifndef NDEBUG + GGML_ASSERT(ctx->cache.rs.rebuild(true)); +#endif + return ret; } From a09db95eabb5f75a5534f804882cf82e1bb5cadd Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 29 Apr 2024 10:24:45 -0400 Subject: [PATCH 05/28] llama : rename many llama_kv_cache_* functions --- llama.cpp | 97 +++++++++++++++++++++++++++++++++++++++---------------- llama.h | 72 ++++++++++++++++++++++++++++++++++------- 2 files changed, 131 insertions(+), 38 deletions(-) diff --git a/llama.cpp b/llama.cpp index 9d887c6dbfe29..f972c3472a278 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2032,7 +2032,6 @@ struct llama_rs_seq_meta { // ring-buffered tree of cached recurrent state data struct llama_rs_cache { - bool do_copy = false; uint32_t head = 0; // first state used for the last slot uint32_t size = 0; @@ -2769,7 +2768,7 @@ struct llama_context { }; // -// kv cache helpers +// kv and rs cache helpers // static bool llama_cache_init( @@ -2898,7 +2897,7 @@ static bool llama_cache_init( // updates the cache head // Note: On success, it's important that cache.head points // to the first cell of the slot. 
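The llama_cache_seq_rm, llama_cache_seq_keep and llama_cache_seq_add hunks above all walk a sequence starting at its tail and following the prev links, since partial seq_id removal is only safe from the tail. A minimal standalone sketch of that traversal, using cut-down stand-ins for llama_rs_cell and llama_rs_seq_meta (the simplified types and the helper name are illustrative assumptions, not code from the patch):

    #include <cstdint>
    #include <vector>

    // simplified stand-ins for the patch's llama_rs_cell / llama_rs_seq_meta
    struct rs_cell_sketch {
        int32_t pos  = -1;
        int32_t prev = -1; // previous cell of the same sequence, -1 at the root
    };

    struct rs_seq_meta_sketch {
        int32_t tail = -1; // latest cell of this sequence, -1 if the seq is empty
    };

    // visit every cell of a sequence, newest first, the way the
    // llama_cache_seq_* helpers above walk from seq.tail through prev
    template <typename F>
    void for_each_cell_from_tail(const std::vector<rs_cell_sketch> & cells,
                                 const rs_seq_meta_sketch & seq, F && visit) {
        for (int32_t cell_id = seq.tail; cell_id >= 0; ) {
            const rs_cell_sketch & cell = cells[(size_t) cell_id];
            int32_t prev_id = cell.prev; // copy before the cell is potentially changed
            visit(cell_id, cell);
            cell_id = prev_id;
        }
    }

Copying prev before visiting mirrors the "copy before the cell is potentially changed" step in llama_cache_seq_rm above, so the walk stays valid even if the visited cell is cleared.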
-static bool llama_kv_cache_find_slot( +static bool llama_cache_find_slot( struct llama_cache & cache, const struct llama_batch & batch) { const uint32_t kv_size = cache.kv.size; @@ -3181,7 +3180,6 @@ static void llama_cache_clear(struct llama_cache & cache) { rs_cell.tail_rc = 0; rs_cell.seq_nodes.clear(); } - cache.rs.do_copy = false; cache.rs.head = 0; cache.rs.used = 0; cache.rs.n_seqs = 0; @@ -3412,8 +3410,8 @@ static llama_pos llama_cache_seq_add( llama_pos p1, llama_pos delta) { - if (p0 < 0) p0 = 0; - if (p1 < 0) p1 = std::numeric_limits::max(); + if (p0 < 0) { p0 = 0; } + if (p1 < 0) { p1 = std::numeric_limits::max(); } llama_pos n_past = p0; @@ -3535,7 +3533,7 @@ static llama_pos llama_cache_seq_div( } static llama_pos llama_cache_seq_pos_max(struct llama_cache & cache, llama_seq_id seq_id) { - llama_pos result = 0; + llama_pos result = -1; if (cache.rs.size > 0) { int32_t cell_id = cache.rs.seq_tails[seq_id].tail; @@ -11174,7 +11172,7 @@ static int llama_decode_internal( if (hparams.causal_attn) { llama_kv_cache_update(&lctx); - if (!llama_kv_cache_find_slot(lctx.cache, u_batch)) { + if (!llama_cache_find_slot(lctx.cache, u_batch)) { return 1; } @@ -15790,6 +15788,10 @@ void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_k } } +bool llama_rs_cache_rebuild(struct llama_context * ctx, bool debug) { + return ctx->cache.rs.rebuild(debug); +} + int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx) { int result = 0; @@ -15804,55 +15806,96 @@ int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx) { return ctx->cache.kv.used; } -void llama_kv_cache_clear(struct llama_context * ctx) { +int32_t llama_get_rs_cache_used_cells(const struct llama_context * ctx) { + return ctx->cache.rs.used; +} + +void llama_cache_clear(struct llama_context * ctx) { llama_cache_clear(ctx->cache); } +// deprecated +void llama_kv_cache_clear(struct llama_context * ctx) { + llama_cache_clear(ctx); +} + +llama_pos llama_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return 0; } + return llama_cache_seq_rm(ctx->cache, seq_id, p0, p1); +} + +// deprecated bool llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return false; } - llama_pos n_past = llama_cache_seq_rm(ctx->cache, seq_id, p0, p1); + llama_pos n_past = llama_cache_seq_rm(ctx, seq_id, p0, p1); return n_past >= p0; } -void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + +llama_pos llama_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { uint32_t n_seq_max = llama_n_seq_max(ctx); if (seq_id_src < 0 || seq_id_dst < 0 || (uint32_t) seq_id_src >= n_seq_max || (uint32_t) seq_id_dst >= n_seq_max) { - return; + return 0; } if (seq_id_src == seq_id_dst) { - return; + return llama_cache_seq_pos_max(ctx->cache, seq_id_dst) + 1; } - llama_cache_seq_cp(ctx->cache, seq_id_src, seq_id_dst, p0, p1); + return llama_cache_seq_cp(ctx->cache, seq_id_src, seq_id_dst, p0, p1); } -void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) { +// deprecated +void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + llama_cache_seq_cp(ctx, 
seq_id_src, seq_id_dst, p0, p1); +} + +void llama_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) { if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } llama_cache_seq_keep(ctx->cache, seq_id); } -void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { +// deprecated +void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) { + llama_cache_seq_keep(ctx, seq_id); +} + +llama_pos llama_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return 0; } if (delta == 0) { - return; + return llama_cache_seq_pos_max(ctx->cache, seq_id) + 1; } - if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } - llama_cache_seq_add(ctx->cache, seq_id, p0, p1, delta); + return llama_cache_seq_add(ctx->cache, seq_id, p0, p1, delta); } -void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { +// deprecated +void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { + llama_cache_seq_add(ctx, seq_id, p0, p1, delta); +} + +llama_pos llama_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return 0; } if (d == 1) { - return; + return llama_cache_seq_pos_max(ctx->cache, seq_id) + 1; } - if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } - llama_cache_seq_div(ctx->cache, seq_id, p0, p1, d); + return llama_cache_seq_div(ctx->cache, seq_id, p0, p1, d); } -llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) { - if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return 0; } +// deprecated +void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + llama_cache_seq_div(ctx, seq_id, p0, p1, d); +} + +llama_pos llama_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) { + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return -1; } return llama_cache_seq_pos_max(ctx->cache, seq_id); } +// deprecated +llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) { + llama_pos max_pos = llama_cache_seq_pos_max(ctx, seq_id); + return max_pos < 0 ? 0 : max_pos; +} + void llama_kv_cache_defrag(struct llama_context * ctx) { llama_kv_cache_defrag(ctx->cache.kv); } @@ -16597,7 +16640,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, batch.n_seq_id[i] = 1; batch.seq_id[i][0] = dest_seq_id; } - if (!llama_kv_cache_find_slot(cache, batch)) { + if (!llama_cache_find_slot(cache, batch)) { llama_batch_free(batch); LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); return 0; diff --git a/llama.h b/llama.h index b770a275ff02f..c211ca592a5df 100644 --- a/llama.h +++ b/llama.h @@ -515,6 +515,12 @@ extern "C" { // Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes) LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view); + // Rebuild and check the validity of the recurrent state cache's tree of sequences. + // (slow, use only for debugging purposes) + // Returns whether or not the rs cache was valid. 
+ // The errors are always corrected, but only logged when debug is true. + LLAMA_API bool llama_rs_cache_rebuild(struct llama_context * ctx, bool debug); + // Returns the number of tokens in the KV cache (slow, use only for debug) // If a KV cell has multiple sequences assigned to it, it will be counted multiple times LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx); @@ -522,36 +528,60 @@ extern "C" { // Returns the number of used KV cells (i.e. have at least one sequence assigned to them) LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx); - // Clear the KV cache - LLAMA_API void llama_kv_cache_clear( + // Returns the number of used recurrent state cells (i.e. have at least one sequence assigned to them) + LLAMA_API int32_t llama_get_rs_cache_used_cells(const struct llama_context * ctx); + + // Clear the KV and recurrent state caches + LLAMA_API void llama_cache_clear( struct llama_context * ctx); + LLAMA_API DEPRECATED(void llama_kv_cache_clear( + struct llama_context * ctx), + "use llama_cache_clear instead"); // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) - // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails // seq_id < 0 : match any sequence // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API bool llama_kv_cache_seq_rm( + // Returns n_past + LLAMA_API llama_pos llama_cache_seq_rm( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1); + LLAMA_API DEPRECATED(bool llama_kv_cache_seq_rm( + struct llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1), + "use llama_cache_seq_rm instead, and handle its return value for partial removals"); // Copy all tokens that belong to the specified sequence to another sequence - // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence + // Note that this does not allocate extra KV or RS cache memory - it simply assigns the tokens to the new sequence // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API void llama_kv_cache_seq_cp( + // Returns n_past + LLAMA_API llama_pos llama_cache_seq_cp( struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1); + LLAMA_API DEPRECATED(void llama_kv_cache_seq_cp( + struct llama_context * ctx, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1), + "use llama_cache_seq_cp instead, and handle its return value for partial copies"); // Removes all tokens that do not belong to the specified sequence - LLAMA_API void llama_kv_cache_seq_keep( + LLAMA_API void llama_cache_seq_keep( struct llama_context * ctx, llama_seq_id seq_id); + LLAMA_API DEPRECATED(void llama_kv_cache_seq_keep( + struct llama_context * ctx, + llama_seq_id seq_id), + "use llama_cache_seq_keep instead"); // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) // If the KV cache is RoPEd, the KV data is updated accordingly: @@ -559,12 +589,20 @@ extern "C" { // - explicitly with llama_kv_cache_update() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API void llama_kv_cache_seq_add( + // Returns n_past + LLAMA_API llama_pos llama_cache_seq_add( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta); + LLAMA_API DEPRECATED(void llama_kv_cache_seq_add( + struct llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos 
p1, + llama_pos delta), + "use llama_cache_seq_add instead"); // Integer division of the positions by factor of `d > 1` // If the KV cache is RoPEd, the KV data is updated accordingly: @@ -572,17 +610,29 @@ extern "C" { // - explicitly with llama_kv_cache_update() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API void llama_kv_cache_seq_div( + // Returns n_past + LLAMA_API llama_pos llama_cache_seq_div( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d); + LLAMA_API DEPRECATED(void llama_kv_cache_seq_div( + struct llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + int d), + "use llama_cache_seq_div instead"); - // Returns the largest position present in the KV cache for the specified sequence - LLAMA_API llama_pos llama_kv_cache_seq_pos_max( + // Returns the largest position present in the KV and/or RS cache for the specified sequence + LLAMA_API llama_pos llama_cache_seq_pos_max( struct llama_context * ctx, llama_seq_id seq_id); + LLAMA_API DEPRECATED(llama_pos llama_kv_cache_seq_pos_max( + struct llama_context * ctx, + llama_seq_id seq_id), + "use llama_cache_seq_pos_max instead, which also now returns -1 instead of 0 when the seq_id has no cells"); // Defragment the KV cache // This will be applied: From b6fafd174721c930e89b27df7de6ee776ace9ade Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 29 Apr 2024 12:59:43 -0400 Subject: [PATCH 06/28] llama : remove useless return value for some llama_cache_* functions --- llama.cpp | 47 ++++++++++++----------------------------------- llama.h | 14 +++++++------- 2 files changed, 19 insertions(+), 42 deletions(-) diff --git a/llama.cpp b/llama.cpp index 92bff6b907b8f..15f7ca43a6dc8 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2887,7 +2887,6 @@ static bool llama_cache_init( bool offload) { const struct llama_hparams & hparams = model.hparams; - // TODO: per layer n_embd_* const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); @@ -3010,6 +3009,8 @@ static bool llama_cache_find_slot( const uint32_t rs_size = cache.rs.size; const uint32_t n_tokens = batch.n_tokens; + // FIXME: on failure, leave all caches in a consistent state. + if (rs_size > 0) { // For recurrent state architectures (like Mamba), // each cache cell can store the state for a whole sequence. @@ -3509,7 +3510,7 @@ static void llama_cache_seq_keep(struct llama_cache & cache, llama_seq_id seq_id } } -static llama_pos llama_cache_seq_add( +static void llama_cache_seq_add( struct llama_cache & cache, llama_seq_id seq_id, llama_pos p0, @@ -3519,8 +3520,6 @@ static llama_pos llama_cache_seq_add( if (p0 < 0) { p0 = 0; } if (p1 < 0) { p1 = std::numeric_limits::max(); } - llama_pos n_past = p0; - if (cache.rs.size > 0) { // for Mamba-like models, only the pos needs to be shifted auto & seq = cache.rs.seq_tails[seq_id]; @@ -3541,9 +3540,6 @@ static llama_pos llama_cache_seq_add( } } } - if (n_past <= rs_cell.pos) { - n_past = rs_cell.pos + 1; - } } // If we freed up a slot, set head to it so searching can start there. @@ -3573,9 +3569,6 @@ static llama_pos llama_cache_seq_add( } } } - if (n_past <= kv_cell.pos) { - n_past = kv_cell.pos + 1; - } } } @@ -3583,11 +3576,9 @@ static llama_pos llama_cache_seq_add( // Otherwise we just start the next search from the beginning. cache.kv.head = new_head != cache.kv.size ? 
new_head : 0; } - - return n_past; } -static llama_pos llama_cache_seq_div( +static void llama_cache_seq_div( struct llama_cache & cache, llama_seq_id seq_id, llama_pos p0, @@ -3596,8 +3587,6 @@ static llama_pos llama_cache_seq_div( if (p0 < 0) { p0 = 0; } if (p1 < 0) { p1 = std::numeric_limits::max(); } - llama_pos n_past = p0; - if (cache.rs.size > 0) { // for Mamba-like models, only the pos needs to be changed auto & seq = cache.rs.seq_tails[seq_id]; @@ -3609,9 +3598,6 @@ static llama_pos llama_cache_seq_div( rs_cell.pos /= d; } cell_id = rs_cell.prev; - if (n_past <= rs_cell.pos) { - n_past = rs_cell.pos + 1; - } } } @@ -3628,14 +3614,9 @@ static llama_pos llama_cache_seq_div( kv_cell.delta += kv_cell.pos - p_old; } } - if (n_past <= kv_cell.pos) { - n_past = kv_cell.pos + 1; - } } } } - - return n_past; } static llama_pos llama_cache_seq_pos_max(struct llama_cache & cache, llama_seq_id seq_id) { @@ -16935,13 +16916,11 @@ void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) { llama_cache_seq_keep(ctx, seq_id); } -llama_pos llama_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { - if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return 0; } - if (delta == 0) { - return llama_cache_seq_pos_max(ctx->cache, seq_id) + 1; - } +void llama_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } + if (delta == 0) { return; } - return llama_cache_seq_add(ctx->cache, seq_id, p0, p1, delta); + llama_cache_seq_add(ctx->cache, seq_id, p0, p1, delta); } // deprecated @@ -16949,13 +16928,11 @@ void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, lla llama_cache_seq_add(ctx, seq_id, p0, p1, delta); } -llama_pos llama_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { - if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return 0; } - if (d == 1) { - return llama_cache_seq_pos_max(ctx->cache, seq_id) + 1; - } +void llama_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } + if (d == 1) { return; } - return llama_cache_seq_div(ctx->cache, seq_id, p0, p1, d); + llama_cache_seq_div(ctx->cache, seq_id, p0, p1, d); } // deprecated diff --git a/llama.h b/llama.h index fa6d0b58625be..bf0f4a9e140d6 100644 --- a/llama.h +++ b/llama.h @@ -562,7 +562,8 @@ extern "C" { // seq_id < 0 : match any sequence // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - // Returns n_past + // Returns n_past (one more than the largest remaining pos in the seq_id) + // which is only meaningful to handle for partial removals. LLAMA_API llama_pos llama_cache_seq_rm( struct llama_context * ctx, llama_seq_id seq_id, @@ -579,7 +580,8 @@ extern "C" { // Note that this does not allocate extra KV or RS cache memory - it simply assigns the tokens to the new sequence // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - // Returns n_past + // Returns n_past (one more than the largest remaining pos in the destination seq_id) + // which is only meaningful to handle when partially copying. 
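Since the deprecation notes above ask callers to handle the new return value of llama_cache_seq_rm, here is a hedged caller-side sketch; truncate_sequence and truncate_to are illustrative names, only the llama_cache_seq_rm call itself comes from this patch:

    #include "llama.h"

    // truncate one sequence and resume from wherever the cache could actually cut
    static llama_pos truncate_sequence(struct llama_context * ctx, llama_seq_id seq_id, llama_pos truncate_to) {
        // request removal of [truncate_to, inf) for this sequence
        const llama_pos n_past = llama_cache_seq_rm(ctx, seq_id, truncate_to, -1);
        // n_past is one more than the largest pos still cached for seq_id; it can
        // exceed truncate_to when part of the range could not be removed (for
        // example cells shared with other sequences), so resume decoding from it
        return n_past;
    }

The deprecated bool wrapper keeps old callers compiling, but new code can use n_past directly instead of assuming the requested cut point was honored.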
LLAMA_API llama_pos llama_cache_seq_cp( struct llama_context * ctx, llama_seq_id seq_id_src, @@ -609,8 +611,7 @@ extern "C" { // - explicitly with llama_kv_cache_update() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - // Returns n_past - LLAMA_API llama_pos llama_cache_seq_add( + LLAMA_API void llama_cache_seq_add( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, @@ -630,8 +631,7 @@ extern "C" { // - explicitly with llama_kv_cache_update() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - // Returns n_past - LLAMA_API llama_pos llama_cache_seq_div( + LLAMA_API void llama_cache_seq_div( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, @@ -652,7 +652,7 @@ extern "C" { LLAMA_API DEPRECATED(llama_pos llama_kv_cache_seq_pos_max( struct llama_context * ctx, llama_seq_id seq_id), - "use llama_cache_seq_pos_max instead, which also now returns -1 instead of 0 when the seq_id has no cells"); + "use llama_cache_seq_pos_max instead, which now returns -1 instead of 0 when the seq_id has no cells"); // Defragment the KV cache // This will be applied: From 7e13f19fb527b62ca87930841608b7369d86173a Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Fri, 24 May 2024 16:19:25 -0400 Subject: [PATCH 07/28] llama : rethink recurrent state cell counts * llama : begin work on support for variable GQA This will also be useful for Jamba if we consider the Mamba layers to have 0 KV heads. * llama : gracefully fail when not finding hybrid slot --- llama.cpp | 586 ++++++++++++++++++++++++++++-------------------------- 1 file changed, 307 insertions(+), 279 deletions(-) diff --git a/llama.cpp b/llama.cpp index 3501163ba2542..969249126c186 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1753,6 +1753,9 @@ struct llama_hparams { uint32_t n_expert_used = 0; uint32_t n_vocab_type = 0; // for BERT-style token types + // TODO: find a more compact way to add more per-layer hyper-parameters + std::vector n_head_kv_vec; + float f_norm_eps; float f_norm_rms_eps; @@ -1793,6 +1796,8 @@ struct llama_hparams { if (this->n_expert != other.n_expert) return true; if (this->n_expert_used != other.n_expert_used) return true; + if (this->n_head_kv_vec != other.n_head_kv_vec) return true; + if (this->rope_finetuned != other.rope_finetuned) return true; if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true; @@ -1812,29 +1817,46 @@ struct llama_hparams { return false; } - uint32_t n_gqa() const { + uint32_t n_head_kv_l(uint32_t layer) const { + if (layer < n_head_kv_vec.size()) { + int32_t n_hkv_l = n_head_kv_vec[layer]; + // TODO: what should happen when it's negative? 
+ GGML_ASSERT(n_hkv_l >= 0); + return n_hkv_l; + } + return n_head_kv; + } + + uint32_t n_gqa(uint32_t layer = 0) const { + uint32_t n_head_kv = n_head_kv_l(layer); if (n_head_kv == 0) { return 0; } return n_head/n_head_kv; } - uint32_t n_embd_k_gqa() const { // dimension of key embeddings across all k-v heads + uint32_t n_embd_k_gqa(uint32_t layer = 0) const { // dimension of key embeddings across all k-v heads + uint32_t n_head_kv = n_head_kv_l(layer); return n_embd_head_k * n_head_kv; } - uint32_t n_embd_v_gqa() const { // dimension of value embeddings across all k-v heads + uint32_t n_embd_v_gqa(uint32_t layer = 0) const { // dimension of value embeddings across all k-v heads + uint32_t n_head_kv = n_head_kv_l(layer); return n_embd_head_v * n_head_kv; } - uint32_t n_embd_r() const { // dimension of the rolling state embeddings + uint32_t n_embd_r(uint32_t layer) const { // dimension of the rolling state embeddings + // TODO: support using an SSM in place of the MLP of a Transformer + if (n_head_kv_l(layer) != 0) { return 0; } // corresponds to Mamba's conv_states size // TODO: maybe support other convolution strides than 1 // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner; } - uint32_t n_embd_s() const { // dimension of the recurrent state embeddings + uint32_t n_embd_s(uint32_t layer) const { // dimension of the recurrent state embeddings + // TODO: support using an SSM in place of the MLP of a Transformer + if (n_head_kv_l(layer) != 0) { return 0; } // corresponds to Mamba's ssm_states size return ssm_d_state * ssm_d_inner; } @@ -2078,10 +2100,12 @@ struct llama_rs_cache { // computed when finding a slot uint32_t n = 0; // range of states used for the last slot - // useful to know the minimum reserved cell count per seq_id - // only counts sequences which have a non-shared tail + // only counts cells which are tails of all of their sequences. + // useful to know the minimum reserved cell count per seq_id. uint32_t n_seqs = 0; - // cells part of multiple sequences AND which have at least one tail + // cells part of multiple sequences, + // but which are only the tail of some of them. + // useful to dismiss sequences used as a shared prompt uint32_t n_shared_tail_cells = 0; // with state models, a cell can hold the state for more than one past token @@ -2279,10 +2303,8 @@ struct llama_rs_cache { for (uint32_t cell_id = 0; (uint32_t) cell_id < size; ++cell_id) { llama_rs_cell & rs_cell = cells[cell_id]; if (!rs_cell.seq_nodes.empty()) { - if (rs_cell.seq_nodes.size() == 1) { - if (rs_cell.tail_rc == 1) { - n_seqs_verif += 1; - } + if (rs_cell.seq_nodes.size() == rs_cell.tail_rc) { + n_seqs_verif += 1; } else if (rs_cell.tail_rc > 0) { n_shared_tail_cells_verif += 1; } @@ -2308,9 +2330,11 @@ struct llama_rs_cache { } // returns an iterator to the seq_node after the removed one, or the same which was passed if it wasn't removed. + // Why an iterator? Because it allows using std::vector::erase. 
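The redefined n_seqs and n_shared_tail_cells counters above classify cells by how many of their seq_nodes are tails, and the rebuild() verification loop recounts them from scratch (which is what llama_rs_cache_rebuild exposes for debugging). A self-contained sketch of that recount, with a cut-down cell type standing in for llama_rs_cell:

    #include <cstdint>
    #include <vector>

    // simplified mirror of the verification loop in llama_rs_cache::rebuild above
    struct rs_cell_sketch {
        uint32_t tail_rc     = 0; // how many of this cell's seq_nodes are tails
        size_t   n_seq_nodes = 0; // stand-in for seq_nodes.size()
    };

    struct rs_counts {
        uint32_t n_seqs              = 0; // cells that are tails of *all* their sequences
        uint32_t n_shared_tail_cells = 0; // cells that are tails of only *some* of them
    };

    static rs_counts recount(const std::vector<rs_cell_sketch> & cells) {
        rs_counts c;
        for (const auto & cell : cells) {
            if (cell.n_seq_nodes == 0) { continue; } // empty cell
            if (cell.tail_rc == cell.n_seq_nodes) {
                c.n_seqs += 1;
            } else if (cell.tail_rc > 0) {
                c.n_shared_tail_cells += 1;
            }
            // cells with tail_rc == 0 are plain history and count toward neither
        }
        return c;
    }

A cell where every node is a tail counts toward n_seqs, a cell where only some nodes are tails counts toward n_shared_tail_cells, and shared-prompt history counts toward neither.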
std::vector::iterator remove_seq_node_from_cell(llama_rs_cell & rs_cell, std::vector::iterator node_iter) { GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size()); - // TODO: assert the iterator points inside the correct vector + // The iterator needs to point inside the correct vector + GGML_ASSERT(node_iter.base() >= rs_cell.seq_nodes.data() && node_iter.base() < rs_cell.seq_nodes.data() + rs_cell.seq_nodes.size()); if (node_iter != rs_cell.seq_nodes.end()) { // update the tree llama_rs_seq_node node = *node_iter; @@ -2325,12 +2349,20 @@ struct llama_rs_cache { GGML_ASSERT(prev_node != prev_cell.seq_nodes.end()); prev_node->next_cell = node.next_cell; if (node.is_tail()) { + // move the tail back to the previous cell if (prev_cell.seq_nodes.size() > 1) { - if (prev_cell.tail_rc == 0) { - n_shared_tail_cells += 1; - } - if (rs_cell.seq_nodes.size() == 1) { - n_seqs -= 1; + if (rs_cell.tail_rc == rs_cell.seq_nodes.size()) { + if (prev_cell.tail_rc == 0) { + n_shared_tail_cells += 1; + } + + // o oo oo + // |/ -> o/ + // | | + // e.g. when removing the leaf with a single tail + if (rs_cell.tail_rc == 1 && prev_cell.tail_rc != prev_cell.seq_nodes.size()) { + n_seqs -= 1; + } } } prev_cell.tail_rc += 1; @@ -2341,17 +2373,22 @@ struct llama_rs_cache { if (node.is_tail()) { seq.tail = rs_cell.prev; if (rs_cell.tail_rc == 1) { - if (rs_cell.seq_nodes.size() > 1) { - // assuming the previous cell of a shared cell is also shared, - // this was a shared tail cell, but will no longer be a tail cell - n_shared_tail_cells -= 1; - } else if (seq.tail < 0) { + if (seq.tail < 0) { // no more tail, no more sequence - n_seqs -= 1; + if (rs_cell.seq_nodes.size() > 1) { + n_shared_tail_cells -= 1; + } else { + n_seqs -= 1; + } } } GGML_ASSERT(rs_cell.tail_rc > 0); rs_cell.tail_rc -= 1; + } else if (rs_cell.tail_rc == rs_cell.seq_nodes.size() - 1) { + // will fully become a tail cell + if (rs_cell.tail_rc > 0) { + n_seqs += 1; + } } if (node_iter == rs_cell.seq_nodes.begin()) { // this seq_id was the first in the list @@ -2363,14 +2400,6 @@ struct llama_rs_cache { if ((uint32_t) next_node->seq_id < seq_tails.size()) { auto & next_seq = seq_tails[next_node->seq_id]; next_seq.n_cells += 1; - // only the tail ref count from the other seq_ids are left in tail_rc - if (rs_cell.tail_rc > 0) { - // will become a non-shared cell - if (rs_cell.seq_nodes.size() == 2) { - n_shared_tail_cells -= 1; - n_seqs += 1; - } - } } else { GGML_ASSERT(false && "invalid seq_id"); } @@ -2433,43 +2462,41 @@ struct llama_rs_cache { rs_cell.pos = prev_cell.pos + 1; rs_cell.src = prev_cell.src; } - prev_cell.tail_rc -= 1; prev_node->next_cell = i_cell; rs_cell.prev = prev; if (seq.tail == prev) { // What to do when the tail moves... 
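remove_seq_node_from_cell above returns an iterator precisely so callers like clear_cell can erase while iterating. A tiny sketch of that pattern in isolation, with plain ints standing in for llama_rs_seq_node:

    #include <vector>

    // removal helper that hands back the iterator to the element after the
    // removed one, the same shape as remove_seq_node_from_cell above
    static std::vector<int>::iterator remove_node(std::vector<int> & nodes,
                                                  std::vector<int>::iterator it) {
        // ... the real code updates prev/next links and counters on *it here ...
        return nodes.erase(it); // valid iterator to the next element
    }

    // equivalent of the clear_cell loop: no manual ++, the erase already advanced us
    static void clear_all(std::vector<int> & nodes) {
        for (auto it = nodes.begin(); it != nodes.end(); ) {
            it = remove_node(nodes, it);
        }
    }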
- // from unique to shared (n_seqs--) - // if the new cell has one seq_id or has no tails (n_shared_tail_cells++) - // if the new cell has one seq_id and a tail (n_seqs-- (yes, another time)) - // from unique to unique (seq.n_cells++) - // from empty to unique (seq.n_cells++, n_seqs++) - // from empty to shared - // if the new cell only has one seq_id or has no tail (n_shared_tail_cells++) - // if the new cell only has one seq_id and has one tail (n_seqs--) - // from shared to shared - // if the last cell has no tails (n_shared_tail_cells--) - // if the new cell has no tails or has one seq_id (n_shared_tail_cells++) - // if the new cell only has one seq_id and has one tail (n_seqs--) - // from shared to unique (seq.n_cells++) - // if this seq_id was not the first of the last cell (n_seqs++) - // if the last cell has no tails (n_shared_tail_cells--) - if (prev_cell.seq_nodes.size() > 1) { - // from shared - if (rs_cell.is_empty()) { - // to unique - if (prev_cell.seq_nodes[0].seq_id != id) { - n_seqs += 1; - } + // (Legend: tail: O, one or more non-tails: o, one or more tails O+, empty: _) + // O -> oO (n_seqs--, n_shared_tail_cells++) + // O -> O (seq.n_cells++) + // OO+ -> oO (n_seqs--, n_shared_tail_cells += 2) + // OO+ -> O+ (n_shared_tail_cells++ (the previous cell becomes oO+)) + // _ -> oO (n_shared_tail_cells++) + // _ -> O (seq.n_cells++, n_seqs++) + // Oo -> O (seq.n_cells++, n_seqs++, n_shared_tail_cell--) + // Oo -> OO+ (n_shared_tail_cell--) + // OOo -> O (seq.n_cells++, n_seqs++) + if (prev_cell.seq_nodes.size() == prev_cell.tail_rc) { + // from fully tail + if (prev_cell.tail_rc > 1) { + // the previous tail becomes shared with a non-tail + n_shared_tail_cells += 1; + } + if (!rs_cell.is_empty() && rs_cell.tail_rc == 0) { + // the new tail cell was previously a fully non-tail cell + n_shared_tail_cells += 1; + n_seqs -= 1; } - // the previous cell is no longer a shared tail - if (prev_cell.tail_rc == 0) { + } else if (rs_cell.is_empty()) { + // from shared to unique + n_seqs += 1; + if (prev_cell.tail_rc == 1) { + // it was the last tail of the previous cell n_shared_tail_cells -= 1; } - } else if (!rs_cell.is_empty()) { - // from unique to shared - n_seqs -= 1; } } + prev_cell.tail_rc -= 1; } if (rs_cell.is_empty()) { // to unique @@ -2482,15 +2509,10 @@ struct llama_rs_cache { rs_cell.src = -1; } used += 1; - } else { + } else if (rs_cell.tail_rc == 0) { // to shared - if (rs_cell.seq_nodes.size() == 1) { - // a lone tail becomes a shared cell - if (rs_cell.tail_rc > 0) { - n_seqs -= 1; - } - n_shared_tail_cells += 1; - } else if (rs_cell.tail_rc == 0) { + if (seq.tail < 0) { + // from empty to shared n_shared_tail_cells += 1; } } @@ -2910,26 +2932,18 @@ static bool llama_cache_init( const llama_context * ctx, ggml_type type_k, ggml_type type_v, - uint32_t n_ctx, - uint32_t n_seq_max, bool offload) { const llama_model & model = ctx->model; const llama_cparams & cparams = ctx->cparams; const struct llama_hparams & hparams = model.hparams; - // TODO: per layer n_embd_* - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); - const uint32_t n_embd_r = hparams.n_embd_r(); - const uint32_t n_embd_s = hparams.n_embd_s(); - const bool has_kv = hparams.n_head != 0 && hparams.causal_attn; - const bool has_r = n_embd_r != 0; - const bool has_s = n_embd_s != 0; + const bool has_kv = hparams.n_head_kv != 0 && hparams.causal_attn; + const bool has_r = hparams.ssm_d_conv != 0 && hparams.ssm_d_inner != 0; + const bool has_s = 
hparams.ssm_d_state != 0 && hparams.ssm_d_inner != 0; const bool has_rs = has_r || has_s; - const uint32_t kv_size = has_kv ? n_ctx : 0; - const uint32_t rs_size = has_rs ? n_seq_max : 0; - // TODO: per cache type layer count + const uint32_t kv_size = has_kv ? cparams.n_ctx : 0; + const uint32_t rs_size = has_rs ? cparams.n_seq_max : 0; const int64_t n_layer = hparams.n_layer; cache.kv.size = kv_size; @@ -2967,6 +2981,7 @@ static bool llama_cache_init( std::map ctx_map; for (auto & it : buft_layer_count) { int n_layers = it.second; + // TODO: for mixed architectures, avoid allocating empty recurrent state or kv cache tensors struct ggml_init_params params = { /*.mem_size =*/ (2*has_kv + has_r+has_s)*n_layers*ggml_tensor_overhead(), /*.mem_buffer =*/ NULL, @@ -2995,20 +3010,20 @@ static bool llama_cache_init( for (int i = 0; i < (int) n_layer; i++) { struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front(); if (has_kv) { - ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); - ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); + ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, hparams.n_embd_k_gqa(i)*kv_size); + ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, hparams.n_embd_v_gqa(i)*kv_size); ggml_format_name(k, "cache_k_l%d", i); ggml_format_name(v, "cache_v_l%d", i); cache.kv.k_l.push_back(k); cache.kv.v_l.push_back(v); } if (has_r) { - ggml_tensor * r = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd_r*rs_size); + ggml_tensor * r = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd_r(i)*rs_size); ggml_format_name(r, "cache_r_l%d", i); cache.rs.r_l.push_back(r); } if (has_s) { - ggml_tensor * s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd_s*rs_size); + ggml_tensor * s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd_s(i)*rs_size); ggml_format_name(s, "cache_s_l%d", i); cache.rs.s_l.push_back(s); } @@ -3024,7 +3039,7 @@ static bool llama_cache_init( return false; } ggml_backend_buffer_clear(buf, 0); - LLAMA_LOG_INFO("%s: %10s ctx buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); + LLAMA_LOG_INFO("%s: %10s cache buf size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); cache.bufs.push_back(buf); } @@ -3042,177 +3057,21 @@ static bool llama_cache_find_slot( const uint32_t rs_size = cache.rs.size; const uint32_t n_tokens = batch.n_tokens; - // FIXME: on failure, leave all caches in a consistent state. - + // only check first, to allow failing gracefully if (rs_size > 0) { - // For recurrent state architectures (like Mamba), - // each cache cell can store the state for a whole sequence. 
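With the per-layer n_embd_k_gqa(i), n_embd_r(i) and n_embd_s(i) used by llama_cache_init above, a hybrid (Jamba-like) model only allocates the cache each layer actually needs. A toy sketch of how those per-layer sizes fall out of n_head_kv_vec; the concrete numbers and the layer layout are made up for illustration, and the formulas are simplified versions of the ones in llama_hparams:

    #include <cstdint>
    #include <vector>

    struct hparams_sketch {
        uint32_t n_embd_head_k = 64;
        uint32_t ssm_d_conv    = 4;
        uint32_t ssm_d_inner   = 1536;
        uint32_t ssm_d_state   = 16;
        // hypothetical hybrid layout: 0 head_kv = Mamba-style layer, 8 = attention layer
        std::vector<int32_t> n_head_kv_vec = {0, 0, 8, 0, 0, 8};

        uint32_t n_embd_k_gqa(uint32_t il) const { return n_embd_head_k * (uint32_t) n_head_kv_vec[il]; }
        uint32_t n_embd_r(uint32_t il) const { return n_head_kv_vec[il] ? 0 : (ssm_d_conv - 1) * ssm_d_inner; }
        uint32_t n_embd_s(uint32_t il) const { return n_head_kv_vec[il] ? 0 : ssm_d_state * ssm_d_inner; }
    };

Attention layers (non-zero head count) contribute K/V rows and no recurrent state; Mamba-style layers (zero head count) contribute r/s states and no K/V, which is how the mixed per-layer allocation above stays compact.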
- // TODO: find a way to always make the rs slot contiguous - - llama_seq_id min_seq = cache.rs.size - 1; - llama_seq_id max_seq = 0; - uint32_t min_cell = cache.rs.size - 1; - uint32_t max_cell = 0; - + // everything should fit if all seq_ids are smaller than the max for (uint32_t i = 0; i < n_tokens; ++i) { - int32_t target_cell = -1; // ensure all the sequences of a token get the same cell - int32_t n_seq_ids = batch.n_seq_id[i]; - for (int32_t j = 0; j < n_seq_ids; ++j) { + int32_t n_seq_id = batch.n_seq_id[i]; + for (int32_t j = 0; j < n_seq_id; ++j) { llama_seq_id seq_id = batch.seq_id[i][j]; - bool need_new_cell = false; - // Everything should fit assuming the biggest seq_id < rs_size - if ((uint32_t) seq_id < rs_size) { - llama_rs_seq_meta & seq = cache.rs.seq_tails[seq_id]; - if (seq_id > max_seq) { max_seq = seq_id; } - if (seq_id < min_seq) { min_seq = seq_id; } - - if (!seq.in_ubatch && target_cell >= 0) { - // never saw this seq_id before, - // but there's already a cell reserved for this token, use it - cache.rs.insert_seq_tail_to_cell_id(target_cell, seq_id); - } else if (seq.tail < 0) { - need_new_cell = true; - } else { - llama_rs_cell & tail = cache.rs.cells[seq.tail]; - if (seq.in_ubatch) { - // this seq_id was already seen before in the batch - // assuming the tail cell already "has" this seq_id - tail.pos += 1; - target_cell = seq.tail; - } else { - // first time this sequence is seen, - // there's no reserved cell yet; - // if it's not the first sequence of the token, how could it even get here? - GGML_ASSERT(j == 0); - - bool has_same_seqs = tail.seq_nodes.size() == (size_t) n_seq_ids; - if (has_same_seqs) { - // the tail cell of a seq_id is assumed to already be part of the seq_id, - // hence the skip of the first seq_id - for (int32_t k = 1; k < n_seq_ids; ++k) { - if (batch.seq_id[i][k] != tail.seq_nodes[k].seq_id) { - has_same_seqs = false; - } - } - } - - // TODO: make the checkpoint interval configurable - if (!has_same_seqs || tail.prev < 0 || tail.pos - cache.rs.cells[tail.prev].pos >= 16) { - // a checkpoint should be saved - need_new_cell = true; - } else { - // re-use last tail - tail.pos += 1; - target_cell = seq.tail; - } - } - } - if (need_new_cell && target_cell < 0) { - const int32_t min_cells_per_seq = cache.rs.min_cells_per_seq(seq); - - uint32_t cell_id = cache.rs.size; - bool looped_once = false; - - while (true) { - if (cache.rs.head >= cache.rs.size) { - cache.rs.head = 0; - if (looped_once) { - // avoid infinite loop - // NOTE: this should not happen, but gracefully fail anyway - LLAMA_LOG_ERROR("%s: recurrent state cache seems full, but should not. 
This is a bug.\n", __func__); - return false; - } - looped_once = true; - } - cell_id = cache.rs.head; - llama_rs_cell & candidate = cache.rs.cells[cell_id]; - if (candidate.is_empty()) { break; } - if (candidate.tail_rc == 1 && seq.tail == (int32_t) cell_id) { - // the candidate is the old tail - if (candidate.seq_nodes.size() > 1) { - // prune out the other seq_ids, because they diverge - // TODO(maybe): hande this in insert_seq_tail_to_cell_id - // (hopefully doesn't happen too often) - for (auto node_iter = candidate.seq_nodes.begin(); node_iter != candidate.seq_nodes.end();) { - if (node_iter->seq_id == seq_id) { - node_iter = std::next(node_iter); - } else { - node_iter = cache.rs.remove_seq_node_from_cell(candidate, node_iter); - } - } - } - // re-use the tail cell to avoid not finding anything - candidate.pos += 1; - break; - } - if (candidate.tail_rc > 0) { - // skip tails of other sequences - cache.rs.head += 1; - continue; - } - if (candidate.seq_nodes.size() > 1) { - // shared prompts are not usually backtracked, so they can be pruned - cache.rs.clear_cell(candidate); - break; - } - - // prune too-long sequences - llama_seq_id seq_id_to_prune = candidate.seq_nodes[0].seq_id; - if (seq_id_to_prune == seq_id) { - // TODO: selectively skip some cells to keep older states - cache.rs.clear_cell(candidate); - break; - } - GGML_ASSERT((size_t) seq_id_to_prune < cache.rs.seq_tails.size()); - auto & seq_to_prune = cache.rs.seq_tails[seq_id_to_prune]; - if (seq_to_prune.n_cells > min_cells_per_seq) { - cache.rs.clear_cell(candidate); - break; - } - cache.rs.head += 1; - } - if (cell_id < cache.rs.size) { - cache.rs.insert_seq_tail_to_cell_id(cell_id, seq_id); - target_cell = cell_id; - } - } - - if (seq.tail >= 0) { - if (min_cell > (uint32_t) seq.tail) { min_cell = seq.tail; } - if (max_cell < (uint32_t) seq.tail) { max_cell = seq.tail; } - seq.in_ubatch = true; - } - - // Assuming the tokens are in-order - if (batch.pos[i] != cache.rs.cells[seq.tail].pos) { - // What should happen when the pos backtracks or skips a value? - // Clearing the state mid-batch would require special-casing which isn't done. - LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n", - __func__, batch.pos[i], cache.rs.cells[cache.rs.head].pos - 1, seq_id); - } - } else { + if (seq_id < 0 || (uint32_t) seq_id >= rs_size) { // too big seq_id // TODO: would it be possible to resize the rs cache size instead? LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.rs.size); return false; } } - cache.rs.head = target_cell + 1; - } - - for (llama_seq_id i = min_seq; i <= max_seq; ++i) { - // make sure it's cleared for next time - cache.rs.seq_tails[i].in_ubatch = false; - } - - // allow getting the range of used cells, from head to head + n - cache.rs.head = min_cell; - cache.rs.n = max_cell - min_cell + 1; - - // sanity check - if (max_seq < min_seq || max_cell < min_cell) { - return false; } } @@ -3257,7 +3116,174 @@ static bool llama_cache_find_slot( return false; } } + } + + // now modification can be done, and should NOT fail + + if (rs_size > 0) { + // For recurrent state architectures (like Mamba), + // each cache cell can store the state for a whole sequence. 
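The slot search is now split into a check-only pass ("only check first, to allow failing gracefully") followed by a mutation pass ("now modification can be done, and should NOT fail"). A generic sketch of that two-phase shape with a minimal stand-in batch type; the real checks and mutations are of course the ones in llama_cache_find_slot:

    #include <cstdint>
    #include <vector>

    struct batch_sketch {
        std::vector<int32_t> seq_ids;
    };

    static bool find_slot_two_phase(const batch_sketch & batch, uint32_t rs_size) {
        // phase 1: only check, so a failure leaves the caches untouched
        for (int32_t seq_id : batch.seq_ids) {
            if (seq_id < 0 || (uint32_t) seq_id >= rs_size) {
                return false; // nothing was modified
            }
        }
        // phase 2: modification can be done, and should not fail
        for (int32_t seq_id : batch.seq_ids) {
            // ... reserve/update cells for seq_id here ...
            (void) seq_id;
        }
        return true;
    }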
+ // TODO: find a way to always make the rs slot contiguous + + llama_seq_id min_seq = cache.rs.size - 1; + llama_seq_id max_seq = 0; + uint32_t min_cell = cache.rs.size - 1; + uint32_t max_cell = 0; + + for (uint32_t i = 0; i < n_tokens; ++i) { + int32_t target_cell = -1; // ensure all the sequences of a token get the same cell + int32_t n_seq_ids = batch.n_seq_id[i]; + for (int32_t j = 0; j < n_seq_ids; ++j) { + llama_seq_id seq_id = batch.seq_id[i][j]; + bool need_new_cell = false; + // Everything should fit assuming the biggest seq_id < rs_size + GGML_ASSERT((uint32_t) seq_id < rs_size); + llama_rs_seq_meta & seq = cache.rs.seq_tails[seq_id]; + if (seq_id > max_seq) { max_seq = seq_id; } + if (seq_id < min_seq) { min_seq = seq_id; } + + if (!seq.in_ubatch && target_cell >= 0) { + // never saw this seq_id before, + // but there's already a cell reserved for this token, use it + cache.rs.insert_seq_tail_to_cell_id(target_cell, seq_id); + } else if (seq.tail < 0) { + // this seq_id has no tail (and is empty) + need_new_cell = true; + } else { + llama_rs_cell & tail = cache.rs.cells[seq.tail]; + if (seq.in_ubatch) { + // this seq_id was already seen before in the batch + // assuming the tail cell already "has" this seq_id + tail.pos += 1; + target_cell = seq.tail; + } else { + // first time this sequence is seen, + // there's no reserved cell yet; + // if it's not the first sequence of the token, how could it even get here? + GGML_ASSERT(j == 0); + + bool has_same_seqs = tail.seq_nodes.size() == (size_t) n_seq_ids; + if (has_same_seqs) { + // the tail cell of a seq_id is assumed to already be part of the seq_id, + // hence the skip of the first seq_id + for (int32_t k = 1; k < n_seq_ids; ++k) { + if (batch.seq_id[i][k] != tail.seq_nodes[k].seq_id) { + has_same_seqs = false; + } + } + } + + // TODO: make the checkpoint interval configurable + if (!has_same_seqs || tail.prev < 0 || tail.pos - cache.rs.cells[tail.prev].pos >= 16) { + // a checkpoint should be saved + need_new_cell = true; + } else { + // re-use last tail + tail.pos += 1; + target_cell = seq.tail; + } + } + } + + // reserve a cell for this seq_id + if (need_new_cell && target_cell < 0) { + const int32_t min_cells_per_seq = cache.rs.min_cells_per_seq(seq); + uint32_t cell_id = cache.rs.size; + bool looped_once = false; + + while (true) { + if (cache.rs.head >= cache.rs.size) { + cache.rs.head = 0; + // avoid infinite loop + // NOTE: this should not fail; if it does, it's a bug. 
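+                        // (the cache is assumed to be sized so that a single pass over the cells
+                        //  always finds one that is empty, prunable, or the re-usable tail of this seq_id,
+                        //  even after skipping the tails reserved by the other sequences)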
+ GGML_ASSERT(!looped_once && "recurrent state cache seems full, but should not."); + looped_once = true; + } + cell_id = cache.rs.head; + llama_rs_cell & candidate = cache.rs.cells[cell_id]; + if (candidate.is_empty()) { break; } + if (candidate.tail_rc == 1 && seq.tail == (int32_t) cell_id) { + // the candidate is the old tail + if (candidate.seq_nodes.size() > 1) { + // prune out the other seq_ids, because they diverge + // TODO(maybe): hande this in insert_seq_tail_to_cell_id + // (hopefully doesn't happen too often) + for (auto node_iter = candidate.seq_nodes.begin(); node_iter != candidate.seq_nodes.end();) { + if (node_iter->seq_id == seq_id) { + node_iter = std::next(node_iter); + } else { + node_iter = cache.rs.remove_seq_node_from_cell(candidate, node_iter); + } + } + } + // re-use the tail cell to avoid not finding anything + candidate.pos += 1; + break; + } + if (candidate.tail_rc > 0) { + // skip tails of other sequences + cache.rs.head += 1; + continue; + } + if (candidate.seq_nodes.size() > 1) { + // shared prompts are not usually backtracked, so they can be pruned + cache.rs.clear_cell(candidate); + break; + } + + // prune too-long sequences + llama_seq_id seq_id_to_prune = candidate.seq_nodes[0].seq_id; + if (seq_id_to_prune == seq_id) { + // TODO: selectively skip some cells to keep older states + cache.rs.clear_cell(candidate); + break; + } + GGML_ASSERT((size_t) seq_id_to_prune < cache.rs.seq_tails.size()); + auto & seq_to_prune = cache.rs.seq_tails[seq_id_to_prune]; + if (seq_to_prune.n_cells > min_cells_per_seq) { + cache.rs.clear_cell(candidate); + break; + } + cache.rs.head += 1; + } + if (cell_id < cache.rs.size) { + cache.rs.insert_seq_tail_to_cell_id(cell_id, seq_id); + target_cell = cell_id; + } + } + + if (seq.tail >= 0) { + if (min_cell > (uint32_t) seq.tail) { min_cell = seq.tail; } + if (max_cell < (uint32_t) seq.tail) { max_cell = seq.tail; } + seq.in_ubatch = true; + } + + // Assuming the tokens are in-order + if (batch.pos[i] != cache.rs.cells[seq.tail].pos) { + // What should happen when the pos backtracks or skips a value? + // Clearing the state mid-batch would require special-casing which isn't done. 
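+                // (e.g. a sequence whose last stored pos is 11 being given a token at pos 15,
+                //  or at pos 11 again; for now the state is kept as-is and only a warning is logged)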
+ LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n", + __func__, batch.pos[i], cache.rs.cells[cache.rs.head].pos - 1, seq_id); + } + } + cache.rs.head = target_cell + 1; + } + + for (llama_seq_id i = min_seq; i <= max_seq; ++i) { + // make sure it's cleared for next time + cache.rs.seq_tails[i].in_ubatch = false; + } + + // allow getting the range of used cells, from head to head + n + cache.rs.head = min_cell; + cache.rs.n = max_cell - min_cell + 1; + + // sanity check + GGML_ASSERT(min_seq <= max_seq && min_cell <= max_cell); + } + + if (kv_size > 0) { for (uint32_t i = 0; i < n_tokens; i++) { cache.kv.cells[cache.kv.head + i].pos = batch.pos[i]; @@ -4194,9 +4220,9 @@ struct llama_model_loader { bool get_arr(const std::string & key, std::vector & result, const bool required = true) { const int kid = gguf_find_key(meta, key.c_str()); - if (kid < 0) { + if (kid < 0 || gguf_get_kv_type(meta, kid) != GGUF_TYPE_ARRAY) { if (required) { - throw std::runtime_error(format("key not found in model: %s", key.c_str())); + throw std::runtime_error(format("array key not found in model: %s", key.c_str())); } return false; } @@ -4204,16 +4230,17 @@ struct llama_model_loader { struct GGUFMeta::ArrayInfo arr_info = GGUFMeta::GKV::get_kv(meta, kid); - if (arr_info.gt != GGUF_TYPE_FLOAT32 && arr_info.gt != GGUF_TYPE_INT32) { - throw std::runtime_error(format("%s is not a float32 or int32 array", key.c_str())); - } - + // TODO: allow ANY lossless cast // GGML_ASSERT(gguf_type_size(arr_info.gt) == sizeof(T)); - GGML_ASSERT((arr_info.gt != GGUF_TYPE_FLOAT32 || std::is_same::value)); - GGML_ASSERT((arr_info.gt != GGUF_TYPE_INT32 || std::is_same::value)); + switch (arr_info.gt) { + case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same::value)); break; + case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same::value)); break; + default: + throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str())); + } - result.resize(arr_info.length); - result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length); + result.reserve(arr_info.length); + result.assign((const T *)arr_info.data, (const T *)arr_info.data + arr_info.length); return true; } @@ -4750,7 +4777,12 @@ static void llm_load_hparams( // n_head_kv is optional, default to n_head hparams.n_head_kv = hparams.n_head; - ml.get_key(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv, false); + + // per-layer n_head_kv + if (!ml.get_arr(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_vec, false)) { + // global/fallback n_head_kv + ml.get_key(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv, false); + } bool rope_finetuned = false; ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); @@ -6704,10 +6736,7 @@ static bool llm_load_tensors( model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading - const int64_t n_ff = hparams.n_ff; const int64_t n_embd_head_k = hparams.n_embd_head_k; - const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); - const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); for (uint32_t i = 0; i < n_layer; ++i) { ggml_context * ctx_layer = ctx_for_layer(i); @@ -7198,8 +7227,8 @@ static void llm_build_kv_store( int64_t il) { const int64_t n_ctx = cparams.n_ctx; - const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); - const int64_t n_embd_v_gqa = 
hparams.n_embd_v_gqa(); + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); GGML_ASSERT(kv.size == n_ctx); @@ -7465,9 +7494,9 @@ static struct ggml_tensor * llm_build_kqv( int il) { const int64_t n_ctx = cparams.n_ctx; const int64_t n_head = hparams.n_head; - const int64_t n_head_kv = hparams.n_head_kv; + const int64_t n_head_kv = hparams.n_head_kv_l(il); const int64_t n_embd_head_k = hparams.n_embd_head_k; - const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); const int64_t n_embd_head_v = hparams.n_embd_head_v; const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); @@ -7619,9 +7648,7 @@ struct llm_build_context { const int64_t n_head; const int64_t n_head_kv; const int64_t n_embd_head_k; - const int64_t n_embd_k_gqa; const int64_t n_embd_head_v; - const int64_t n_embd_v_gqa; const int64_t n_expert; const int64_t n_expert_used; @@ -7673,9 +7700,7 @@ struct llm_build_context { n_head (hparams.n_head), n_head_kv (hparams.n_head_kv), n_embd_head_k (hparams.n_embd_head_k), - n_embd_k_gqa (hparams.n_embd_k_gqa()), n_embd_head_v (hparams.n_embd_head_v), - n_embd_v_gqa (hparams.n_embd_v_gqa()), n_expert (hparams.n_expert), n_expert_used (hparams.n_expert_used), freq_base (cparams.rope_freq_base), @@ -7746,9 +7771,9 @@ struct llm_build_context { // we rotate only the first n_rot dimensions ggml_rope_ext_inplace(ctx0, ggml_view_3d(ctx0, kv_self.k_l[il], - n_embd_head_k, n_head_kv, n_ctx, + n_embd_head_k, hparams.n_head_kv_l(il), n_ctx, ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k), - ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), + ggml_row_size(kv_self.k_l[il]->type, hparams.n_embd_k_gqa(il)), 0), lctx.inp_K_shift, rope_factors, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -7777,6 +7802,9 @@ struct llm_build_context { } for (int il = 0; il < n_layer; ++il) { + int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self.k_l[il], n_embd_k_gqa, nm, ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), @@ -11014,8 +11042,8 @@ struct llm_build_context { struct ggml_tensor * state_seq = build_inp_s_seq(); for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, rs_self.r_l[il], hparams.n_embd_r(), rs_self.size); - struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, rs_self.s_l[il], hparams.n_embd_s(), rs_self.size); + struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, rs_self.r_l[il], hparams.n_embd_r(il), rs_self.size); + struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, rs_self.s_l[il], hparams.n_embd_s(il), rs_self.size); // copy states { @@ -16452,7 +16480,7 @@ struct llama_context * llama_new_context_with_model( } ctx->backends.push_back(ctx->backend_cpu); - if (!llama_cache_init(ctx->cache, ctx, type_k, type_v, cparams.n_ctx, cparams.n_seq_max, cparams.offload_kqv)) { + if (!llama_cache_init(ctx->cache, ctx, type_k, type_v, cparams.offload_kqv)) { LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); llama_free(ctx); return nullptr; @@ -17282,7 +17310,7 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data const auto & hparams = ctx->model.hparams; const uint32_t n_layer = hparams.n_layer; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const uint32_t n_embd_k_gqa = 
hparams.n_embd_k_gqa(); // FIXME: per layer const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); // NOTE: kv_size and kv_buf_size are mostly used for sanity checks @@ -17434,7 +17462,7 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) { const auto & hparams = ctx->model.hparams; const uint32_t n_layer = hparams.n_layer; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); // FIXME: per layer const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); size_t kv_buf_size; @@ -17627,7 +17655,7 @@ size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id) const auto & hparams = ctx->model.hparams; const uint32_t n_layer = hparams.n_layer; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); // FIXME: per layer const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); for (uint32_t i = 0; i < kv_self.size; ++i) { @@ -17713,7 +17741,7 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam const auto & hparams = ctx->model.hparams; const uint32_t n_layer = hparams.n_layer; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); // FIXME: per layer const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); // Write the layer count @@ -17843,7 +17871,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, // Sanity check model compatibility const auto & hparams = ctx->model.hparams; const uint32_t n_layer = hparams.n_layer; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); // FIXME: per layer const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); if (n_layer != n_layer_ref) { LLAMA_LOG_ERROR("%s: mismatched n_layer (%d != %d)\n", __func__, n_layer, n_layer_ref); From cbc743e6006349dde61fe214d56c2d6efa34828d Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Fri, 24 May 2024 19:27:27 -0400 Subject: [PATCH 08/28] llama : support Jamba --- convert-hf-to-gguf.py | 103 ++++++- gguf-py/gguf/constants.py | 36 +++ gguf-py/gguf/gguf_writer.py | 7 +- gguf-py/gguf/tensor_mapping.py | 52 +++- llama.cpp | 521 ++++++++++++++++++++++++++------- 5 files changed, 601 insertions(+), 118 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index daad1c4fc7255..83d9b0638f856 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2300,7 +2300,7 @@ def set_gguf_parameters(self): self.gguf_writer.add_embedding_length(d_model) self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading - self.gguf_writer.add_block_count(self.hparams["n_layer"]) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_ssm_conv_kernel(d_conv) self.gguf_writer.add_ssm_inner_size(d_inner) self.gguf_writer.add_ssm_state_size(d_state) @@ -2346,6 +2346,107 @@ def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: i ) +@Model.register("JambaForCausalLM") +class JambaModel(Model): + model_arch = gguf.MODEL_ARCH.JAMBA + + def get_vocab_base_pre(self, tokenizer) -> str: + del tokenizer # unused + + return "gpt-2" + + def set_gguf_parameters(self): + d_model = self.find_hparam(["hidden_size", "mamba_d_model"]) + d_conv = self.find_hparam(["mamba_d_conv"], optional=True) or 4 + d_inner = self.hparams["mamba_expand"] * d_model + d_state = 
self.find_hparam(["mamba_d_state"], optional=True) or 16 + # ceiling division + # ref: https://stackoverflow.com/a/17511341/22827863 + # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58 + dt_rank = self.find_hparam(["mamba_dt_rank"], optional=True) or -(d_model // -16) + rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-6 + n_kv_head = self.hparams["num_key_value_heads"] + attn_offset = self.hparams["attn_layer_offset"] + attn_period = self.hparams["attn_layer_period"] + n_kv_vec = [0 for _ in range(attn_offset)] + [ + n_kv_head if (i - attn_offset) % attn_period == 0 else 0 for i in range(attn_offset, self.block_count) + ] + + self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_block_count(self.block_count) + self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) + self.gguf_writer.add_embedding_length(d_model) + self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) + self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) + self.gguf_writer.add_head_count_kv(n_kv_vec) + self.gguf_writer.add_ssm_conv_kernel(d_conv) + self.gguf_writer.add_ssm_inner_size(d_inner) + self.gguf_writer.add_ssm_state_size(d_state) + self.gguf_writer.add_ssm_time_step_rank(dt_rank) + self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps) + self.gguf_writer.add_expert_count(self.hparams["num_experts"]) + self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"]) + self.gguf_writer.add_file_type(self.ftype) + + _experts: list[dict[str, Tensor]] | None = None + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + + # process the experts separately + if ".feed_forward.experts." 
in name: + n_experts = self.hparams["num_experts"] + + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + + # merge the experts into a single 3d tensor + for wid in ["down_proj", "gate_proj", "up_proj"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.feed_forward.experts.{xid}.{wid}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + + # using the same merged name as qwen2moe + merged_name = f"model.layers.{bid}.mlp.experts.{wid}.weight" + + new_name = self.map_tensor_name(merged_name) + + yield new_name, data_torch + return + + new_name = self.map_tensor_name(name) + + if name.endswith(".A_log"): + logger.debug("A_log --> A ==> " + new_name) + data_torch = -torch.exp(data_torch) + + yield new_name, data_torch + + # same as Mamba + def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool: + del n_dims # unused + + return bid is not None and new_name in ( + self.format_tensor_name(n, bid, ".weight" if name.endswith(".weight") else "") for n in [ + gguf.MODEL_TENSOR.SSM_CONV1D, + gguf.MODEL_TENSOR.SSM_X, + gguf.MODEL_TENSOR.SSM_DT, + gguf.MODEL_TENSOR.SSM_A, + gguf.MODEL_TENSOR.SSM_D, + ] + ) + + @Model.register("CohereForCausalLM") class CommandR2Model(Model): model_arch = gguf.MODEL_ARCH.COMMAND_R diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 42df2e4d00604..3668778be0af1 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -135,6 +135,7 @@ class MODEL_ARCH(IntEnum): GEMMA = auto() STARCODER2 = auto() MAMBA = auto() + JAMBA = auto() XVERSE = auto() COMMAND_R = auto() DBRX = auto() @@ -180,7 +181,10 @@ class MODEL_TENSOR(IntEnum): SSM_CONV1D = auto() SSM_X = auto() SSM_DT = auto() + SSM_DT_NORM = auto() SSM_A = auto() + SSM_B_NORM = auto() + SSM_C_NORM = auto() SSM_D = auto() SSM_OUT = auto() @@ -214,6 +218,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.GEMMA: "gemma", MODEL_ARCH.STARCODER2: "starcoder2", MODEL_ARCH.MAMBA: "mamba", + MODEL_ARCH.JAMBA: "jamba", MODEL_ARCH.XVERSE: "xverse", MODEL_ARCH.COMMAND_R: "command-r", MODEL_ARCH.DBRX: "dbrx", @@ -259,7 +264,10 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d", MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x", MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt", + MODEL_TENSOR.SSM_DT_NORM: "blk.{bid}.ssm_dt_norm", MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a", + MODEL_TENSOR.SSM_B_NORM: "blk.{bid}.ssm_b_norm", + MODEL_TENSOR.SSM_C_NORM: "blk.{bid}.ssm_c_norm", MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", } @@ -678,6 +686,34 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_D, MODEL_TENSOR.SSM_OUT, ], + MODEL_ARCH.JAMBA: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.SSM_IN, + MODEL_TENSOR.SSM_CONV1D, + MODEL_TENSOR.SSM_X, + MODEL_TENSOR.SSM_DT, + MODEL_TENSOR.SSM_DT_NORM, + MODEL_TENSOR.SSM_A, + MODEL_TENSOR.SSM_B_NORM, + MODEL_TENSOR.SSM_C_NORM, + MODEL_TENSOR.SSM_D, + MODEL_TENSOR.SSM_OUT, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + 
MODEL_TENSOR.FFN_UP_EXP, + ], MODEL_ARCH.XVERSE: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 8b41b54eaa5a6..272ef4a8071cd 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -385,8 +385,11 @@ def add_parallel_residual(self, use: bool) -> None: def add_head_count(self, count: int) -> None: self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count) - def add_head_count_kv(self, count: int) -> None: - self.add_uint32(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count) + def add_head_count_kv(self, count: int | Sequence[int]) -> None: + if isinstance(count, int): + self.add_uint32(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count) + else: + self.add_array(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count) def add_key_length(self, length: int) -> None: self.add_uint32(Keys.Attention.KEY_LENGTH.format(arch=self.arch), length) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 8e1cac9152f55..eb60bb8ac01d4 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -206,6 +206,7 @@ class TensorNameMap: "h.{bid}.ln_2", # gpt2 "model.layers.{bid}.ffn_norm", # internlm2 "transformer.decoder_layer.{bid}.rms_norm_2", # Grok + "model.layers.{bid}.pre_ff_layernorm", # jamba ), MODEL_TENSOR.FFN_GATE_INP: ( @@ -214,6 +215,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.gate", # qwen2moe "transformer.decoder_layer.{bid}.router", # Grok "transformer.blocks.{bid}.ffn.router.layer", # dbrx + "model.layers.{bid}.feed_forward.router", # jamba ), MODEL_TENSOR.FFN_GATE_INP_SHEXP: ( @@ -244,6 +246,7 @@ class TensorNameMap: "encoder.layers.{bid}.mlp.fc11", # nomic-bert "model.layers.{bid}.mlp.c_fc", # starcoder2 "encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2 + "model.layers.{bid}.feed_forward.up_proj", # jamba ), MODEL_TENSOR.FFN_UP_EXP: ( @@ -272,6 +275,7 @@ class TensorNameMap: "encoder.layers.{bid}.mlp.fc12", # nomic-bert "encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2 "transformer.h.{bid}.mlp.linear_1", # refact + "model.layers.{bid}.feed_forward.gate_proj", # jamba ), MODEL_TENSOR.FFN_GATE_EXP: ( @@ -306,6 +310,7 @@ class TensorNameMap: "encoder.layers.{bid}.mlp.fc2", # nomic-bert "model.layers.{bid}.mlp.c_proj", # starcoder2 "encoder.layer.{bid}.mlp.wo", # jina-bert-v2 + "model.layers.{bid}.feed_forward.down_proj", # jamba ), MODEL_TENSOR.FFN_DOWN_EXP: ( @@ -347,38 +352,57 @@ class TensorNameMap: ), MODEL_TENSOR.SSM_IN: ( - "model.layers.{bid}.in_proj", - "backbone.layers.{bid}.mixer.in_proj", + "model.layers.{bid}.in_proj", # mamba-hf + "backbone.layers.{bid}.mixer.in_proj", # mamba + "model.layers.{bid}.mamba.in_proj", # jamba ), MODEL_TENSOR.SSM_CONV1D: ( - "model.layers.{bid}.conv1d", - "backbone.layers.{bid}.mixer.conv1d", + "model.layers.{bid}.conv1d", # mamba-hf + "backbone.layers.{bid}.mixer.conv1d", # mamba + "model.layers.{bid}.mamba.conv1d", # jamba ), MODEL_TENSOR.SSM_X: ( - "model.layers.{bid}.x_proj", - "backbone.layers.{bid}.mixer.x_proj", + "model.layers.{bid}.x_proj", # mamba-hf + "backbone.layers.{bid}.mixer.x_proj", # mamba + "model.layers.{bid}.mamba.x_proj", # jamba ), MODEL_TENSOR.SSM_DT: ( - "model.layers.{bid}.dt_proj", - "backbone.layers.{bid}.mixer.dt_proj", + "model.layers.{bid}.dt_proj", # mamba-hf + "backbone.layers.{bid}.mixer.dt_proj", # mamba + "model.layers.{bid}.mamba.dt_proj", # jamba + ), + + MODEL_TENSOR.SSM_DT_NORM: ( + 
"model.layers.{bid}.mamba.dt_layernorm", # jamba ), MODEL_TENSOR.SSM_A: ( - "model.layers.{bid}.A_log", - "backbone.layers.{bid}.mixer.A_log", + "model.layers.{bid}.A_log", # mamba-hf + "backbone.layers.{bid}.mixer.A_log", # mamba + "model.layers.{bid}.mamba.A_log", # jamba + ), + + MODEL_TENSOR.SSM_B_NORM: ( + "model.layers.{bid}.mamba.b_layernorm", # jamba + ), + + MODEL_TENSOR.SSM_C_NORM: ( + "model.layers.{bid}.mamba.c_layernorm", # jamba ), MODEL_TENSOR.SSM_D: ( - "model.layers.{bid}.D", - "backbone.layers.{bid}.mixer.D", + "model.layers.{bid}.D", # mamba-hf + "backbone.layers.{bid}.mixer.D", # mamba + "model.layers.{bid}.mamba.D", # jamba ), MODEL_TENSOR.SSM_OUT: ( - "model.layers.{bid}.out_proj", - "backbone.layers.{bid}.mixer.out_proj", + "model.layers.{bid}.out_proj", # mamba-hf + "backbone.layers.{bid}.mixer.out_proj", # mamba + "model.layers.{bid}.mamba.out_proj", # jamba ), } diff --git a/llama.cpp b/llama.cpp index 969249126c186..3176c8d0d5d64 100644 --- a/llama.cpp +++ b/llama.cpp @@ -221,6 +221,7 @@ enum llm_arch { LLM_ARCH_GEMMA, LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, + LLM_ARCH_JAMBA, LLM_ARCH_XVERSE, LLM_ARCH_COMMAND_R, LLM_ARCH_DBRX, @@ -257,6 +258,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_GEMMA, "gemma" }, { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, + { LLM_ARCH_JAMBA, "jamba" }, { LLM_ARCH_XVERSE, "xverse" }, { LLM_ARCH_COMMAND_R, "command-r" }, { LLM_ARCH_DBRX, "dbrx" }, @@ -472,7 +474,10 @@ enum llm_tensor { LLM_TENSOR_SSM_CONV1D, LLM_TENSOR_SSM_X, LLM_TENSOR_SSM_DT, + LLM_TENSOR_SSM_DT_NORM, LLM_TENSOR_SSM_A, + LLM_TENSOR_SSM_B_NORM, + LLM_TENSOR_SSM_C_NORM, LLM_TENSOR_SSM_D, LLM_TENSOR_SSM_OUT, }; @@ -970,6 +975,37 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, }, }, + { + LLM_ARCH_JAMBA, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_X, "blk.%d.ssm_x" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_DT_NORM, "blk.%d.ssm_dt_norm" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_B_NORM, "blk.%d.ssm_b_norm" }, + { LLM_TENSOR_SSM_C_NORM, "blk.%d.ssm_c_norm" }, + { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, { LLM_ARCH_XVERSE, { @@ -1905,6 +1941,9 @@ struct llama_layer { struct ggml_tensor * attn_k_norm_b; struct ggml_tensor * attn_out_norm; struct ggml_tensor * attn_out_norm_b; + struct ggml_tensor * ssm_dt_norm; + struct ggml_tensor * ssm_b_norm; + struct ggml_tensor * ssm_c_norm; // attention struct ggml_tensor * wq; @@ -5150,6 +5189,22 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_JAMBA: + { + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + 
ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + // TODO: Jamba layers are a bit heterogenous, so naming this is hard. + case 12: // 900M 8x???M + case 32: // 51B 16x?B + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_XVERSE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -6854,6 +6909,118 @@ static bool llm_load_tensors( layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}); } } break; + case LLM_ARCH_JAMBA: + { + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t dt_rank = hparams.ssm_dt_rank; + + // only an expansion factor of 2 is supported for now + GGML_ASSERT(2 * n_embd == d_inner); + GGML_ASSERT((int64_t) hparams.n_head_kv_vec.size() == n_layer); + + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed, duplicated to allow offloading + if (model.output == NULL) { + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + const int64_t n_head_kv = hparams.n_head_kv_vec[i]; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(i); + + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + // norm + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + + if (n_head_kv == 0) { + // Mamba layer + layer.ssm_in = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner}); + + layer.ssm_conv1d = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}); + layer.ssm_conv1d_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner}); + + layer.ssm_x = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state}); + + layer.ssm_dt_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT_NORM, "weight", i), {dt_rank}); + + layer.ssm_dt = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner}); + layer.ssm_dt_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner}); + + layer.ssm_b_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_B_NORM, "weight", i), {d_state}); + layer.ssm_c_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_C_NORM, "weight", i), {d_state}); + + // no "weight" suffix for these + layer.ssm_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}); + layer.ssm_d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_D, i), {d_inner}); + + // out_proj + layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}); + + layer.wq = nullptr; + layer.wk = nullptr; + layer.wv = nullptr; + layer.wo = nullptr; + + } else { + 
// Attention layers + + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + + layer.ssm_in = nullptr; + layer.ssm_conv1d = nullptr; + layer.ssm_conv1d_b = nullptr; + layer.ssm_x = nullptr; + layer.ssm_dt_norm = nullptr; + layer.ssm_dt = nullptr; + layer.ssm_dt_b = nullptr; + layer.ssm_b_norm = nullptr; + layer.ssm_c_norm = nullptr; + layer.ssm_a = nullptr; + layer.ssm_d = nullptr; + layer.ssm_out = nullptr; + } + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + + layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED); + + if (layer.ffn_gate_inp) { + // MoE + layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}); + layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}); + layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}); + + layer.ffn_gate = nullptr; + layer.ffn_down = nullptr; + layer.ffn_up = nullptr; + } else { + // FFN (no MoE) + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + + layer.ffn_gate_exps = nullptr; + layer.ffn_down_exps = nullptr; + layer.ffn_up_exps = nullptr; + } + } + } break; case LLM_ARCH_XVERSE: { model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); @@ -7632,6 +7799,132 @@ static struct ggml_tensor * llm_build_kv( return cur; } +// TODO: split +static struct ggml_tensor * llm_build_mamba( + struct ggml_context * ctx, + const llama_model & model, + const llama_hparams & hparams, + const llama_rs_cache & rs, + struct ggml_cgraph * graph, + struct ggml_tensor * cur, + struct ggml_tensor * state_copy, + struct ggml_tensor * state_mask, + struct ggml_tensor * state_seq, + struct ggml_tensor * w_dt_norm, + struct ggml_tensor * w_b_norm, + struct ggml_tensor * w_c_norm, + int32_t n_tokens, + int32_t rs_head, + int32_t n_rs, + const llm_build_cb & cb, + int il) { + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t dt_rank = hparams.ssm_dt_rank; + + struct ggml_tensor * conv_states = ggml_reshape_2d(ctx, rs.r_l[il], hparams.n_embd_r(il), rs.size); + struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx, rs.s_l[il], hparams.n_embd_s(il), rs.size); + + // copy states + { + // TODO: use some sort of read-only head and n to pass smaller tensors to ggml_get_rows + // NOTE: assuming the copy destinations are ALL contained in the current batch + // this shrinks the tensors's ne[1] to n_rs + conv_states = ggml_get_rows(ctx, conv_states, state_copy); + ssm_states = ggml_get_rows(ctx, ssm_states, state_copy); + } + + // clear states of sequences which are starting at the beginning of this batch + { + conv_states = 
ggml_mul(ctx, conv_states, state_mask); + ssm_states = ggml_mul(ctx, ssm_states, state_mask); + } + + conv_states = ggml_reshape_3d(ctx, conv_states, d_conv - 1, d_inner, n_rs); + ssm_states = ggml_reshape_3d(ctx, ssm_states, d_state, d_inner, n_rs); + + // {n_embd, 2*d_inner} * {n_embd, n_tokens} => {2*d_inner, n_tokens} + struct ggml_tensor * xz = ggml_mul_mat(ctx, model.layers[il].ssm_in, cur); + // split the above in two + // => {d_inner, n_tokens} + struct ggml_tensor * x = ggml_view_2d(ctx, xz, d_inner, xz->ne[1], xz->nb[1], 0); + struct ggml_tensor * z = ggml_view_2d(ctx, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner); + + // conv + { + // Custom operator which is needed only to ease simultaneous sequence processing. + // For a single sequence, the equivalent is to concatenate the columns of conv_states and x, + // then make a self-overlapping view of that over d_conv columns at each stride in the 3rd dimension, + // then element-wise multiply that with the conv1d weigth, + // then sum the elements of each row, + // (the last two steps are a dot product over rows (also doable with mul_mat)) + // then permute away the ne[0] dimension, + // and then you're left with the resulting x tensor. + // The new conv_states is the last (d_conv - 1) columns + // of the last 3rd dimensional "layer" of the self-overlapping view. + // For simultaneous sequences, it's more complicated. + struct ggml_tensor * x_conv = ggml_ssm_conv(ctx, conv_states, x, model.layers[il].ssm_conv1d, state_seq); + + // store last (d_conv - 1) columns of the conv_state part of x_conv back into the KV cache + ggml_build_forward_expand(graph, + ggml_cpy(ctx, + ggml_view_2d(ctx, x_conv, d_conv - 1, d_inner*n_rs, d_conv*ggml_element_size(x_conv), (1+d_inner*n_tokens)*ggml_element_size(x_conv)), + ggml_view_1d(ctx, rs.r_l[il], (d_conv - 1)*(d_inner)*(n_rs), rs_head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv)))); + + // extract x from x_conv + x = ggml_view_2d(ctx, x_conv, d_inner, n_tokens, d_inner*ggml_element_size(x_conv), 0); + + // bias + x = ggml_add(ctx, x, model.layers[il].ssm_conv1d_b); + + x = ggml_silu(ctx, x); + } + + // ssm + { + // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tokens} => {dt_rank + 2*d_state, n_tokens} + struct ggml_tensor * x_db = ggml_mul_mat(ctx, model.layers[il].ssm_x, x); + // split + struct ggml_tensor * dt = ggml_view_2d(ctx, x_db, dt_rank, n_tokens, x_db->nb[1], 0); + struct ggml_tensor * B = ggml_view_2d(ctx, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*dt_rank); + struct ggml_tensor * C = ggml_view_2d(ctx, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state)); + + if (w_dt_norm) { dt = llm_build_norm(ctx, dt, hparams, w_dt_norm, NULL, LLM_NORM_RMS, cb, il); } + if (w_b_norm) { B = llm_build_norm(ctx, B, hparams, w_b_norm, NULL, LLM_NORM_RMS, cb, il); } + if (w_c_norm) { C = llm_build_norm(ctx, C, hparams, w_b_norm, NULL, LLM_NORM_RMS, cb, il); } + + // {dt_rank, d_inner} * {dt_rank, n_tokens} => {d_inner, n_tokens} + dt = ggml_mul_mat(ctx, model.layers[il].ssm_dt, dt); + dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b); + + // Custom operator to optimize the parallel associative scan + // as described in the Annex D of the Mamba paper. + // => {d_inner, n_tokens} and {d_state, d_inner, n_rs} combined, + // because only a single tensor can be returned. 
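+        // Roughly, per token t of each sequence (see ggml_ssm_scan for the exact kernel):
+        //   dt_t = softplus(dt_t)
+        //   h_t  = exp(dt_t * A) * h_{t-1} + (dt_t * x_t) * B_t   // {d_state, d_inner}
+        //   y_t  = sum over d_state of (C_t * h_t)                // {d_inner}
+        // the D skip-connection and the SiLU(z) gate are applied below, outside the scan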
+ struct ggml_tensor * y_ssm_states = ggml_ssm_scan(ctx, ssm_states, x, dt, model.layers[il].ssm_a, B, C, state_seq); + + // store last states (the second part of y_ssm_states) + ggml_build_forward_expand(graph, + ggml_cpy(ctx, + ggml_view_1d(ctx, y_ssm_states, d_state*d_inner*n_rs, d_inner*n_tokens*ggml_element_size(y_ssm_states)), + ggml_view_1d(ctx, rs.s_l[il], d_state*d_inner*n_rs, rs_head*d_state*d_inner*ggml_element_size(ssm_states)))); + + struct ggml_tensor * y = ggml_view_2d(ctx, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0); + + // TODO: skip computing output for unused tokens + + // {d_inner, n_tokens} * {d_inner} => {d_inner, n_tokens} + y = ggml_add(ctx, y, ggml_mul(ctx, x, model.layers[il].ssm_d)); + y = ggml_mul(ctx, y, ggml_silu(ctx, z)); + + // {d_inner, n_embd} * {d_inner, n_tokens} => {n_embd, n_tokens} + cur = ggml_mul_mat(ctx, model.layers[il].ssm_out, y); + } + + return cur; +} + struct llm_build_context { const llama_model & model; llama_context & lctx; @@ -11024,13 +11317,6 @@ struct llm_build_context { struct ggml_cgraph * build_mamba() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); - const int64_t d_model = n_embd; - const int64_t d_conv = hparams.ssm_d_conv; - const int64_t d_inner = hparams.ssm_d_inner; - GGML_ASSERT(2 * d_model == d_inner); - const int64_t d_state = hparams.ssm_d_state; - const int64_t dt_rank = hparams.ssm_dt_rank; - struct ggml_tensor * cur; struct ggml_tensor * inpL; @@ -11042,116 +11328,144 @@ struct llm_build_context { struct ggml_tensor * state_seq = build_inp_s_seq(); for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, rs_self.r_l[il], hparams.n_embd_r(il), rs_self.size); - struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, rs_self.s_l[il], hparams.n_embd_s(il), rs_self.size); + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); - // copy states - { - // TODO: use some sort of read-only head and n to pass smaller tensors to ggml_get_rows - // NOTE: assuming the copy destinations are ALL contained in the current batch - // this shrinks the tensors's ne[1] to n_rs - conv_states = ggml_get_rows(ctx0, conv_states, state_copy); - ssm_states = ggml_get_rows(ctx0, ssm_states, state_copy); - } + cur = llm_build_mamba(ctx0, model, hparams, rs_self, gf, cur, + state_copy, state_mask, state_seq, NULL, NULL, NULL, + n_tokens, rs_head, n_rs, cb, il); - // clear states of sequences which are starting at the beginning of this batch - { - conv_states = ggml_mul(ctx0, conv_states, state_mask); - ssm_states = ggml_mul(ctx0, ssm_states, state_mask); + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); } - conv_states = ggml_reshape_3d(ctx0, conv_states, d_conv - 1, d_inner, n_rs); - ssm_states = ggml_reshape_3d(ctx0, ssm_states, d_state, d_inner, n_rs); + // residual + cur = ggml_add(ctx0, cur, inpL); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + // final rmsnorm + cur = llm_build_norm(ctx0, inpL, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + + struct 
ggml_cgraph * build_jamba() { + + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + const int64_t n_embd_head = hparams.n_embd_head_v; + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + // {n_embd, n_tokens} + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + struct ggml_tensor * state_copy = build_inp_s_copy(); + struct ggml_tensor * state_mask = build_inp_s_mask(); + struct ggml_tensor * state_seq = build_inp_s_seq(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + const int64_t n_head_kv = hparams.n_head_kv_l(il); - // norm cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il); cb(cur, "attn_norm", il); - // {n_embd, 2*d_inner} * {n_embd, n_tokens} => {2*d_inner, n_tokens} - struct ggml_tensor * xz = ggml_mul_mat(ctx0, model.layers[il].ssm_in, cur); - // split the above in two - // => {d_inner, n_tokens} - struct ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0); - struct ggml_tensor * z = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner); + if (n_head_kv == 0) { + // Mamba + cur = llm_build_mamba(ctx0, model, hparams, rs_self, gf, cur, + state_copy, state_mask, state_seq, + model.layers[il].ssm_dt_norm, model.layers[il].ssm_b_norm, model.layers[il].ssm_c_norm, + n_tokens, rs_head, n_rs, cb, il); + } else { + // Attention - // conv - { - // Custom operator which is needed only to ease simultaneous sequence processing. - // For a single sequence, the equivalent is to concatenate the columns of conv_states and x, - // then make a self-overlapping view of that over d_conv columns at each stride in the 3rd dimension, - // then element-wise multiply that with the conv1d weigth, - // then sum the elements of each row, - // (the last two steps are a dot product over rows (also doable with mul_mat)) - // then permute away the ne[0] dimension, - // and then you're left with the resulting x tensor. - // The new conv_states is the last (d_conv - 1) columns - // of the last 3rd dimensional "layer" of the self-overlapping view. - // For simultaneous sequences, it's more complicated. 
- struct ggml_tensor * x_conv = ggml_ssm_conv(ctx0, conv_states, x, model.layers[il].ssm_conv1d, state_seq); - - // store last (d_conv - 1) columns of the conv_state part of x_conv back into the KV cache - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, - ggml_view_2d(ctx0, x_conv, d_conv - 1, d_inner*n_rs, d_conv*ggml_element_size(x_conv), (1+d_inner*n_tokens)*ggml_element_size(x_conv)), - ggml_view_1d(ctx0, rs_self.r_l[il], (d_conv - 1)*(d_inner)*(n_rs), rs_head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv)))); - - // extract x from x_conv - x = ggml_view_2d(ctx0, x_conv, d_inner, n_tokens, d_inner*ggml_element_size(x_conv), 0); - - // bias - x = ggml_add(ctx0, x, model.layers[il].ssm_conv1d_b); - - x = ggml_silu(ctx0, x); - } - - // ssm - { - // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tokens} => {dt_rank + 2*d_state, n_tokens} - struct ggml_tensor * x_db = ggml_mul_mat(ctx0, model.layers[il].ssm_x, x); - // split - struct ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, n_tokens, x_db->nb[1], 0); - struct ggml_tensor * B = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*dt_rank); - struct ggml_tensor * C = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state)); - - // {dt_rank, d_inner} * {dt_rank, n_tokens} => {d_inner, n_tokens} - dt = ggml_mul_mat(ctx0, model.layers[il].ssm_dt, dt); - dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b); - - // Custom operator to optimize the parallel associative scan - // as described in the Annex D of the Mamba paper. - // => {d_inner, n_tokens} and {d_state, d_inner, n_rs} combined, - // because only a single tensor can be returned. - struct ggml_tensor * y_ssm_states = ggml_ssm_scan(ctx0, ssm_states, x, dt, model.layers[il].ssm_a, B, C, state_seq); - - // store last states (the second part of y_ssm_states) - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, - ggml_view_1d(ctx0, y_ssm_states, d_state*d_inner*n_rs, d_inner*n_tokens*ggml_element_size(y_ssm_states)), - ggml_view_1d(ctx0, rs_self.s_l[il], d_state*d_inner*n_rs, rs_head*d_state*d_inner*ggml_element_size(ssm_states)))); - - struct ggml_tensor * y = ggml_view_2d(ctx0, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0); - - if (il == n_layer - 1) { - // skip computing output for unused tokens - struct ggml_tensor * inp_out_ids = build_inp_out_ids(); - x = ggml_get_rows(ctx0, x, inp_out_ids); - y = ggml_get_rows(ctx0, y, inp_out_ids); - z = ggml_get_rows(ctx0, z, inp_out_ids); - inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); - } + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - // {d_inner, n_tokens} * {d_inner} => {d_inner, n_tokens} - y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d)); - y = ggml_mul(ctx0, y, ggml_silu(ctx0, z)); + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); - // {d_inner, n_embd} * {d_inner, n_tokens} => {n_embd, n_tokens} - cur = ggml_mul_mat(ctx0, model.layers[il].ssm_out, y); + // No RoPE :) + + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + model.layers[il].wo, NULL, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 
1.0f/sqrtf(float(n_embd_head)), cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); } // residual - cur = ggml_add(ctx0, cur, inpL); + struct ggml_tensor * ffn_inp = ggml_add(ctx0, inpL, cur); + cb(cur, "ffn_inp", il); + + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + // feed-forward network + if (model.layers[il].ffn_gate_inp == nullptr) { + // FFN + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, + model.layers[il].ffn_gate, NULL, + model.layers[il].ffn_down, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + cur = llm_build_moe_ffn(ctx0, cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + n_expert, n_expert_used, + LLM_FFN_SILU, false, + cb, il); + cb(cur, "ffn_moe_out", il); + } + + // residual + cur = ggml_add(ctx0, ffn_inp, cur); cb(cur, "l_out", il); // input for next layer @@ -11630,6 +11944,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_mamba(); } break; + case LLM_ARCH_JAMBA: + { + result = llm.build_jamba(); + } break; case LLM_ARCH_XVERSE: { result = llm.build_xverse(); @@ -16644,6 +16962,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_REFACT: case LLM_ARCH_BLOOM: case LLM_ARCH_MAMBA: + case LLM_ARCH_JAMBA: case LLM_ARCH_JINA_BERT_V2: return LLAMA_ROPE_TYPE_NONE; From 61a88a1da399be2207c8aa0a8a280dffc3f64887 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Fri, 24 May 2024 22:41:38 -0400 Subject: [PATCH 09/28] llama : fix BERT inference without KV cache --- llama.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/llama.cpp b/llama.cpp index 6bc5167be6f60..678c49094b22e 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3105,6 +3105,10 @@ static bool llama_cache_init( ggml_context * ctx = it.second; ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); if (!buf) { + if (!has_kv && !has_rs) { + // no buffer was needed, so this is fine + return true; + } LLAMA_LOG_ERROR("%s: failed to allocate buffer for kv cache\n", __func__); return false; } From ea2e63e9d2b4d9e60587083b9fc824d9ca342af1 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sat, 25 May 2024 12:54:30 -0400 Subject: [PATCH 10/28] convert-hf : check for unprocessed Jamba experts --- convert-hf-to-gguf.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 971875069dcc3..28a43c54f70f7 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2470,6 +2470,15 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter yield new_name, data_torch + def write_tensors(self): + super().write_tensors() + + if self._experts is not None: + # flatten `list[dict[str, Tensor]]` into `list[str]` + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + # same as Mamba def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool: del n_dims # unused From fc59407efea1d49a3d8338fd20fa38afbe06fdb5 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sat, 25 May 2024 13:55:11 -0400 
Subject: [PATCH 11/28] convert-hf : support Mini-Jamba conversion --- convert-hf-to-gguf.py | 21 ++++++++++++++++++++- gguf-py/gguf/tensor_mapping.py | 3 +++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 28a43c54f70f7..a42458e63d23f 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2393,6 +2393,16 @@ def get_vocab_base_pre(self, tokenizer) -> str: return "gpt-2" + def set_vocab(self): + if (self.dir_model / "tokenizer.model").is_file(): + # Using Jamba's tokenizer.json causes errors on model load + # (something about "byte not found in vocab"), + # but there's a working tokenizer.model + self._set_vocab_sentencepiece() + else: + # Some Jamba models only have a tokenizer.json, which works. + self._set_vocab_gpt2() + def set_gguf_parameters(self): d_model = self.find_hparam(["hidden_size", "mamba_d_model"]) d_conv = self.find_hparam(["mamba_d_conv"], optional=True) or 4 @@ -2412,7 +2422,7 @@ def set_gguf_parameters(self): self.gguf_writer.add_name(self.dir_model.name) self.gguf_writer.add_block_count(self.block_count) - self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) + self.gguf_writer.add_context_length(self.find_hparam(["max_position_embeddings", "n_ctx"])) self.gguf_writer.add_embedding_length(d_model) self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) @@ -2430,6 +2440,15 @@ def set_gguf_parameters(self): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # Mini-Jamba + name = name.replace(".moe.", ".feed_forward.") + if bid is not None: + moe_offset = self.hparams["expert_layer_offset"] + moe_period = self.hparams["expert_layer_period"] + + if not (bid >= moe_offset and (bid - moe_offset) % moe_period == 0): + name = name.replace(".experts.0.", ".") + # process the experts separately if ".feed_forward.experts." 
in name: n_experts = self.hparams["num_experts"] diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index b71bf1ecdd4d4..c81600151b142 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -207,6 +207,7 @@ class TensorNameMap: "model.layers.{bid}.ffn_norm", # internlm2 "transformer.decoder_layer.{bid}.rms_norm_2", # Grok "model.layers.{bid}.pre_ff_layernorm", # jamba + "model.layers.{bid}.pre_moe_layernorm", # mini-jamba ), MODEL_TENSOR.FFN_GATE_INP: ( @@ -390,10 +391,12 @@ class TensorNameMap: MODEL_TENSOR.SSM_B_NORM: ( "model.layers.{bid}.mamba.b_layernorm", # jamba + "model.layers.{bid}.mamba.B_layernorm", # mini-jamba ), MODEL_TENSOR.SSM_C_NORM: ( "model.layers.{bid}.mamba.c_layernorm", # jamba + "model.layers.{bid}.mamba.C_layernorm", # mini-jamba ), MODEL_TENSOR.SSM_D: ( From 181dadf294d9495b54a86a23299fc15b282dac1d Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Tue, 28 May 2024 12:23:05 -0400 Subject: [PATCH 12/28] llama : fix Jamba quantization sanity checks --- llama.cpp | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/llama.cpp b/llama.cpp index 678c49094b22e..4c9ecf018e67f 100644 --- a/llama.cpp +++ b/llama.cpp @@ -16290,11 +16290,18 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer; // sanity checks - // - // - qs.n_attention_wv == 0 for Mamba models - // - qs.n_attention_wv == model.hparams.n_layer for Transformer models - // - GGML_ASSERT((qs.n_attention_wv == 0 || qs.n_attention_wv == (int)model.hparams.n_layer) && "n_attention_wv is unexpected"); + { + const auto & n_head_kv_vec = model.hparams.n_head_kv_vec; + int n_attn_layer; + if (model.hparams.n_head_kv == 0) { + // Mamba models don't have attention layers + n_attn_layer = 0; + } else { + // Transformers and hybrid models (like Jamba) have attention layers + n_attn_layer = model.hparams.n_layer - std::count(n_head_kv_vec.begin(), n_head_kv_vec.end(), 0); + } + GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected"); + } size_t total_size_org = 0; size_t total_size_new = 0; From 3a414b0be242be52f8c186acb368510975eb0d15 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Tue, 28 May 2024 12:21:52 -0400 Subject: [PATCH 13/28] llama : sequence-length-aware batch splitting --- llama.cpp | 443 +++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 355 insertions(+), 88 deletions(-) diff --git a/llama.cpp b/llama.cpp index 4c9ecf018e67f..209d3063cb5ec 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2807,6 +2807,321 @@ struct llama_model { } }; +// very similar to llama_batch, +// but has more metadata about sequences +struct llama_ubatch { + bool equal_seqs; + + int32_t n_tokens; + int32_t n_seqs; + + llama_token * token; + float * embd; + llama_pos * pos; + int32_t * n_seq_id; + llama_seq_id ** seq_id; + int8_t * output; +}; + +struct llama_sbatch_seq { + int32_t n_seq_id; + llama_seq_id * seq_id; + size_t offset; + size_t length; + + // helper for smoother batch API transition -- can be deprecated in the future + llama_seq_id all_seq_id; // used if seq_id == NULL +}; + +// sequence-length-aware batch splitting +struct llama_sbatch { + // tokens left in this batch + size_t n_tokens; + + size_t n_embd; + + bool logits_all; // TODO: remove once lctx.logits_all is removed too + + // sorted indices into the batch + std::vector ids; + // batch indices of the output + 
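+    // (presumably used after decoding to map outputs from the sorted ubatch order
+    //  back to the order the tokens had in the original batch)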
std::vector out_ids; + std::vector seq; + const llama_batch * batch = nullptr; + + // buffers for the ubatch + std::vector ubatch_token; + std::vector ubatch_embd; + std::vector ubatch_pos; + std::vector ubatch_n_seq_id; + std::vector ubatch_seq_id; + std::vector ubatch_output; + + llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false) { + // clear empty sequences + // the previous ubatch is assumed to be gone, + // so nothing should refer to values in these sequences anymore. + for (size_t i = seq.size(); i-- > 0;) { + if (seq[i].length == 0) { + seq.pop_back(); + } else { + break; + } + } + ubatch_token.resize(!has_embd ? n_ubatch : 0); + ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0); + ubatch_pos.resize(n_ubatch); + ubatch_n_seq_id.resize(n_ubatch); + ubatch_seq_id.resize(n_ubatch); + ubatch_output.resize(n_ubatch); + llama_ubatch ubatch = { + true, + 0, + 0, + !has_embd ? ubatch_token.data() : nullptr, + has_embd ? ubatch_embd.data() : nullptr, + ubatch_pos.data(), + ubatch_n_seq_id.data(), + ubatch_seq_id.data(), + ubatch_output.data(), + }; + return ubatch; + } + + void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) { + GGML_ASSERT(batch != nullptr); + GGML_ASSERT(length <= seq.length); + if (ubatch.equal_seqs) { + // is the new sequence of a different size than expected? + if (ubatch.n_seqs > 0 && length != (size_t) ubatch.n_tokens / ubatch.n_seqs) { + ubatch.equal_seqs = false; + } + } + // NOTE: loops are separated for cache-friendliness + if (batch->token) { + for (size_t i = 0; i < length; ++i) { + ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]]; + } + } else { + ubatch.token = nullptr; + } + if (batch->embd) { + for (size_t i = 0; i < length; ++i) { + memcpy( + ubatch.embd + n_embd * (ubatch.n_tokens + i), + batch->embd + n_embd * ids[seq.offset + i], + n_embd * sizeof(float) + ); + } + } else { + ubatch.embd = nullptr; + } + // from here on, the else branches are deprecated; + // they are helpers for smoother batch API transition + if (batch->pos) { + for (size_t i = 0; i < length; ++i) { + ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]]; + } + } else { + for (size_t i = 0; i < length; ++i) { + llama_pos bi = ids[seq.offset + i]; + ubatch.pos[ubatch.n_tokens + i] = batch->all_pos_0 + (bi * batch->all_pos_1); + } + } + if (batch->n_seq_id) { + for (size_t i = 0; i < length; ++i) { + ubatch.n_seq_id[ubatch.n_tokens + i] = batch->n_seq_id[ids[seq.offset + i]]; + } + } else { + for (size_t i = 0; i < length; ++i) { + ubatch.n_seq_id[ubatch.n_tokens + i] = 1; + } + } + if (batch->seq_id) { + for (size_t i = 0; i < length; ++i) { + ubatch.seq_id[ubatch.n_tokens + i] = batch->seq_id[ids[seq.offset + i]]; + } + } else { + for (size_t i = 0; i < length; ++i) { + ubatch.seq_id[ubatch.n_tokens + i] = &seq.all_seq_id; + } + } + if (batch->logits) { + for (size_t i = 0; i < length; ++i) { + size_t id = ids[seq.offset + i]; + int8_t is_output = batch->logits[id]; + ubatch.output[ubatch.n_tokens + i] = is_output; + if (is_output) { out_ids.push_back(id); } + } + } else if (logits_all) { + for (size_t i = 0; i < length; ++i) { + ubatch.output[ubatch.n_tokens + i] = 1; + out_ids.push_back(ids[seq.offset + i]); + } + } else { + // only get last output + for (size_t i = 0; i < length; ++i) { + size_t id = ids[seq.offset + i]; + int8_t is_last = id == ids.size() - 1; + ubatch.output[ubatch.n_tokens + i] = is_last; + if (is_last) { out_ids.push_back(id); } + } + } + ubatch.n_tokens += length; + 
ubatch.n_seqs += seq.n_seq_id != 0; // don't count seq_ids for legacy splits + seq.offset += length; + seq.length -= length; + n_tokens -= length; + } + + // legacy split, unknown number of sequences of unequal lengths + llama_ubatch split_slice(size_t n_ubatch) { + n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch; + llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr); + ubatch.equal_seqs = false; + if (!seq.empty()) { + llama_sbatch_seq & s = seq[0]; + size_t length = s.length < n_ubatch ? s.length : n_ubatch; + GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits + // TODO: reduce copies + add_seq_to_ubatch(ubatch, s, length); + } + return ubatch; + } + + // make batches of equal-length sequences + llama_ubatch split_equal(size_t n_ubatch) { + n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch; + llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr); + if (!seq.empty()) { + size_t length = 0; + size_t n_tokens_in_ubatch = 0; + GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with legacy splits + // smallest first, because it's easier to split this way; + // starting from the end to pop in constant time. + for (size_t i = seq.size(); i-- > 0;) { + llama_sbatch_seq & s = seq[i]; + GGML_ASSERT(s.length > 0); + if (length == 0) { + length = s.length < n_ubatch ? s.length : n_ubatch; + } + add_seq_to_ubatch(ubatch, s, length); + n_tokens_in_ubatch += length; + // shared prompts can't be mixed with any of their sequences, + // so it's safer to compute them in their own ubatch + if (s.n_seq_id > 1) { break; } + // stop when there isn't enough space for another sequence + if (length + n_tokens_in_ubatch > n_ubatch) { break; } + } + } + return ubatch; + } + + // sequence-wise split + llama_ubatch split_seq(size_t n_ubatch) { + n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch; + llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr); + if (!seq.empty()) { + llama_sbatch_seq & s = seq[seq.size() - 1]; + size_t length = s.length < n_ubatch ? s.length : n_ubatch; + GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with legacy splits + add_seq_to_ubatch(ubatch, s, length); + } + return ubatch; + } + + void from_batch(const llama_batch & batch, const size_t n_embd, const bool legacy_split = false, const bool logits_all = false) { + GGML_ASSERT(batch.n_tokens >= 0); + this->batch = &batch; + this->n_embd = n_embd; + this->logits_all = logits_all; + + n_tokens = batch.n_tokens; + ids.resize(n_tokens); + out_ids.clear(); + // TODO: reserve out_ids and seq + + for (size_t i = 0; i < n_tokens; ++i) { + ids[i] = i; + } + if (legacy_split) { + seq.resize(1); + llama_sbatch_seq & s = seq[0]; + s.n_seq_id = 0; + s.seq_id = nullptr; + s.offset = 0; + s.length = n_tokens; + s.all_seq_id = batch.all_seq_id; + return; + } + std::sort(ids.begin(), ids.end(), + [batch](size_t a, size_t b) { + int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1; + int32_t n_seq_b = batch.n_seq_id ? 
batch.n_seq_id[b] : 1; + // sort by seq_id, then by pos + if (n_seq_a == n_seq_b) { + if (batch.seq_id) { + for (int32_t i = 0; i < n_seq_a; ++i) { + llama_seq_id seq_id_a = batch.seq_id[a][i]; + llama_seq_id seq_id_b = batch.seq_id[b][i]; + // smaller seq_ids go first + if (seq_id_a != seq_id_b) { + return seq_id_a < seq_id_b; + } + } + } + // when all else is equal, sort by pos + if (batch.pos) { + return batch.pos[a] < batch.pos[b]; + } + // no pos, sort by id (assuming batch.all_pos_1 is positive) + return a < b; + } + // shared prompts go first + return n_seq_a > n_seq_b; + } + ); + // init seq + llama_sbatch_seq * last_seq = nullptr; + + if (batch.n_seq_id != nullptr && batch.seq_id != nullptr) { + for (size_t i = 0; i < n_tokens; ++i) { + const size_t bi = ids[i]; + const size_t s_len = seq.size(); + const int32_t n_seqs = batch.n_seq_id[bi]; + llama_seq_id * seq_ids = batch.seq_id[bi]; + if (last_seq != nullptr) { + bool same = n_seqs == last_seq->n_seq_id; + for (int32_t j = 0; same && j < n_seqs; ++j) { + if (seq_ids[j] != last_seq->seq_id[j]) { + same = false; + } + } + if (same) { + last_seq->length += 1; + continue; + } + } + llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1, batch.all_seq_id}; + seq.push_back(new_seq); + last_seq = &seq[s_len]; + } + } else { + llama_sbatch_seq new_seq = {1, nullptr, 0, n_tokens, batch.all_seq_id}; + seq.push_back(new_seq); + } + // keep shared prompts first at the end, then sort by length descending. + std::sort(seq.begin(), seq.end(), + [](llama_sbatch_seq & a, llama_sbatch_seq & b) { + if (a.n_seq_id == b.n_seq_id) { + return a.length > b.length; + } + return a.n_seq_id < b.n_seq_id; + } + ); + } +}; + struct llama_context { llama_context(const llama_model & model) : model(model), t_start_us(model.t_start_us), t_load_us(model.t_load_us) {} ~llama_context() { @@ -2832,6 +3147,9 @@ struct llama_context { // key + value cache for self-attention, and/or recurrent state cache struct llama_cache cache; + // sequence-length-aware batch splitting + llama_sbatch sbatch; + std::mt19937 rng; bool has_evaluated_once = false; @@ -3126,7 +3444,7 @@ static bool llama_cache_init( // to the first cell of the slot. 
static bool llama_cache_find_slot( struct llama_cache & cache, - const struct llama_batch & batch) { + const struct llama_ubatch & batch) { const uint32_t kv_size = cache.kv.size; const uint32_t rs_size = cache.rs.size; const uint32_t n_tokens = batch.n_tokens; @@ -7533,7 +7851,7 @@ static struct ggml_tensor * llm_build_inp_embd( struct ggml_context * ctx, struct llama_context & lctx, const llama_hparams & hparams, - const llama_batch & batch, + const llama_ubatch & batch, struct ggml_tensor * tok_embd, const llm_build_cb & cb) { const int64_t n_embd = hparams.n_embd; @@ -8107,7 +8425,7 @@ struct llm_build_context { llama_context & lctx; const llama_hparams & hparams; const llama_cparams & cparams; - const llama_batch & batch; + const llama_ubatch & batch; const llama_kv_cache & kv_self; const llama_rs_cache & rs_self; @@ -8153,7 +8471,7 @@ struct llm_build_context { // TODO: consider making the entire interface noexcept llm_build_context( llama_context & lctx, - const llama_batch & batch, + const llama_ubatch & batch, const llm_build_cb & cb, bool worst_case) : model (lctx.model), @@ -12215,8 +12533,8 @@ struct llm_build_context { }; static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector & ids) { - llama_batch dummy; - dummy.n_tokens = 0; + llama_ubatch dummy = {}; + dummy.equal_seqs = true; llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { }; @@ -12232,8 +12550,8 @@ static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const } static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) { - llama_batch dummy; - dummy.n_tokens = 0; + llama_ubatch dummy = {}; + dummy.equal_seqs = true; llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { }; @@ -12250,7 +12568,7 @@ static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) { static struct ggml_cgraph * llama_build_graph( llama_context & lctx, - const llama_batch & batch, + const llama_ubatch & batch, bool worst_case) { const auto & model = lctx.model; @@ -12438,7 +12756,7 @@ static void llama_set_k_shift(llama_context & lctx) { } } -static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { +static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { // // set input data // @@ -12478,10 +12796,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { for (int i = 0; i < n_tokens; ++i) { data[i] = i; } - } else if (batch.logits) { + } else if (batch.output) { int32_t n_outputs = 0; for (int i = 0; i < n_tokens; ++i) { - if (batch.logits[i]) { + if (batch.output[i]) { data[n_outputs++] = i; } } @@ -12835,11 +13153,6 @@ static int llama_decode_internal( const auto n_ubatch = cparams.n_ubatch; - std::vector pos; - std::vector n_seq_id; - std::vector seq_id_arr; - std::vector> seq_id; - // count outputs if (batch_all.logits) { for (uint32_t i = 0; i < n_tokens_all; ++i) { @@ -12852,55 +13165,29 @@ static int llama_decode_internal( n_outputs = 1; } + lctx.sbatch.from_batch(batch_all, n_embd, /* legacy_split */ rs_self.size == 0, lctx.logits_all); + // reserve output buffer if (llama_output_reserve(lctx, n_outputs) < n_outputs) { LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs); return -2; }; - // set output mappings - if (batch_all.logits) { - int32_t i_logits = 0; - for (uint32_t i = 0; i < n_tokens_all; ++i) { - if (batch_all.logits[i]) { - lctx.output_ids[i] = i_logits++; - } - } - } else { - for (uint32_t 
i = 0; i < n_outputs; ++i) { - lctx.output_ids[i] = i; - } - } - - for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) { - const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token); - llama_batch u_batch = { - /* .n_tokens = */ (int32_t) n_tokens, - /* .token = */ batch_all.token ? batch_all.token + cur_token : nullptr, - /* .embd = */ batch_all.embd ? batch_all.embd + cur_token*n_embd : nullptr, - /* .pos = */ batch_all.pos ? batch_all.pos + cur_token : nullptr, - /* .n_seq_id = */ batch_all.n_seq_id ? batch_all.n_seq_id + cur_token : nullptr, - /* .seq_id = */ batch_all.seq_id ? batch_all.seq_id + cur_token : nullptr, - /* .logits = */ batch_all.logits ? batch_all.logits + cur_token : nullptr, - /* .all_pos_0 = */ batch_all.all_pos_0 + (llama_pos) cur_token*batch_all.all_pos_1, - /* .all_pos_1 = */ batch_all.all_pos_1, - /* .all_seq_id = */ batch_all.all_seq_id, - }; + while (lctx.sbatch.n_tokens > 0) { + // TODO: deprecate slice splits in favor of equal splits + llama_ubatch u_batch = (rs_self.size > 0) ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_slice(n_ubatch); + const uint32_t n_tokens = u_batch.n_tokens; // count the outputs in this u_batch { int32_t n_outputs_new = 0; - if (u_batch.logits) { - for (uint32_t i = 0; i < n_tokens; i++) { - n_outputs_new += u_batch.logits[i] != 0; - } - } else if (n_outputs == n_tokens_all) { + if (n_outputs == n_tokens_all) { n_outputs_new = n_tokens; } else { - // keep last output only - if (cur_token + n_tokens >= n_tokens_all) { - n_outputs_new = 1; + GGML_ASSERT(u_batch.output); + for (uint32_t i = 0; i < n_tokens; i++) { + n_outputs_new += u_batch.output[i] != 0; } } @@ -12911,32 +13198,6 @@ static int llama_decode_internal( int n_threads = n_tokens == 1 ? 
cparams.n_threads : cparams.n_threads_batch; GGML_ASSERT(n_threads > 0); - // helpers for smoother batch API transition - // after deprecating the llama_eval calls, these will be removed - if (u_batch.pos == nullptr) { - pos.resize(n_tokens); - for (uint32_t i = 0; i < n_tokens; i++) { - pos[i] = u_batch.all_pos_0 + i*u_batch.all_pos_1; - } - - u_batch.pos = pos.data(); - } - - if (u_batch.seq_id == nullptr) { - n_seq_id.resize(n_tokens); - seq_id.resize(n_tokens); - seq_id_arr.resize(n_tokens); - for (uint32_t i = 0; i < n_tokens; i++) { - n_seq_id[i] = 1; - seq_id[i].resize(1); - seq_id[i][0] = u_batch.all_seq_id; - seq_id_arr[i] = seq_id[i].data(); - } - - u_batch.n_seq_id = n_seq_id.data(); - u_batch.seq_id = seq_id_arr.data(); - } - // non-causal masks do not use the KV cache if (hparams.causal_attn) { llama_kv_cache_update(&lctx); @@ -12945,6 +13206,7 @@ static int llama_decode_internal( return 1; } + // TODO: move into llama_cache_find_slot if (kv_self.size > 0) { // a heuristic, to avoid attending the full cache if it is not yet utilized // after enough generations, the benefit from this heuristic disappears @@ -13108,6 +13370,12 @@ static int llama_decode_internal( #endif } + // set output mappings + GGML_ASSERT(lctx.sbatch.out_ids.size() == n_outputs); + for (size_t i = 0; i < n_outputs; ++i) { + lctx.output_ids[lctx.sbatch.out_ids[i]] = i; + } + // set to total number of outputs in the batch, for use in llama_get_logits_ith lctx.n_outputs = n_outputs; @@ -13398,10 +13666,11 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) { if (need_reserve) { // TODO: extract to a function // build worst-case graph + int n_seqs = 1; // TODO: worst-case number of sequences int n_tokens = (int)std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch); - int n_past = lctx.cparams.n_ctx - n_tokens; llama_token token = llama_token_bos(&lctx.model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph - ggml_cgraph * gf = llama_build_graph(lctx, llama_batch_get_one(&token, n_tokens, n_past, 0), true); + llama_ubatch ubatch = { true, n_tokens, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true); // initialize scheduler with the worst-case graph ggml_backend_sched_reset(lctx.sched); @@ -17345,10 +17614,11 @@ struct llama_context * llama_new_context_with_model( } // build worst-case graph + int n_seqs = 1; // TODO: worst-case number of sequences int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_ubatch); - int n_past = cparams.n_ctx - n_tokens; llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph - ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0), true); + llama_ubatch ubatch = { true, n_tokens, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + ggml_cgraph * gf = llama_build_graph(*ctx, ubatch, true); // initialize scheduler with the worst-case graph if (!ggml_backend_sched_reserve(ctx->sched, gf)) { @@ -18662,8 +18932,9 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, // Allocate the new cells for the slot if (cell_count) { - llama_batch batch = llama_batch_init(cell_count, 0, 1); + llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false); batch.n_tokens = cell_count; + batch.n_seqs = 1; for (uint32_t i = 0; i < cell_count; ++i) { llama_pos 
pos; memcpy(&pos, inp, sizeof(pos)); @@ -18674,7 +18945,6 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, batch.seq_id[i][0] = dest_seq_id; } if (!llama_cache_find_slot(cache, batch)) { - llama_batch_free(batch); LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); return 0; } @@ -18686,9 +18956,6 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]); GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id)); GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id)); - - // Cleanup - llama_batch_free(batch); } const uint32_t kv_size = kv_self.size; From 3587a9498773203f10f66814f67568797f1ce7a0 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sat, 1 Jun 2024 11:37:14 -0400 Subject: [PATCH 14/28] llama : use equal-sequence-length sub-batches for recurrent models * ggml : simplify SSM-related operators * llama : make recurrent state slot allocation contiguous * llama : adapt internal uses of batches to llama_ubatch --- ggml.c | 250 ++++++--------- ggml.h | 6 +- llama.cpp | 946 ++++++++++++++++++++++++++++++++++-------------------- 3 files changed, 699 insertions(+), 503 deletions(-) diff --git a/ggml.c b/ggml.c index 58ac9702694c7..7a3a5fa9468ff 100644 --- a/ggml.c +++ b/ggml.c @@ -7103,40 +7103,35 @@ struct ggml_tensor * ggml_ssm_conv( struct ggml_context * ctx, struct ggml_tensor * s, struct ggml_tensor * x, - struct ggml_tensor * c, - struct ggml_tensor * sq) { + struct ggml_tensor * c) { GGML_ASSERT(ggml_is_3d(s)); - GGML_ASSERT(ggml_is_matrix(x)); + GGML_ASSERT(ggml_is_3d(x)); GGML_ASSERT(ggml_is_matrix(c)); - GGML_ASSERT(ggml_is_vector(sq)); - GGML_ASSERT(sq->type == GGML_TYPE_I32); - const int64_t d_conv = c->ne[0]; - const int64_t d_inner = c->ne[1]; - const int64_t n_tokens = x->ne[1]; - const int64_t n_rs = s->ne[2]; + const int64_t d_conv = c->ne[0]; + const int64_t d_inner = c->ne[1]; + const int64_t n_t = x->ne[1]; // tokens per sequence + const int64_t n_s = s->ne[2]; - GGML_ASSERT( s->ne[0] == d_conv - 1); - GGML_ASSERT( s->ne[1] == d_inner); - GGML_ASSERT( x->ne[0] == d_inner); - GGML_ASSERT(sq->ne[0] == n_tokens); + GGML_ASSERT(s->ne[0] == d_conv - 1); + GGML_ASSERT(s->ne[1] == d_inner); + GGML_ASSERT(x->ne[0] == d_inner); + GGML_ASSERT(x->ne[2] == n_s); bool is_node = false; - if (s->grad || x->grad || c->grad || sq->grad) { + if (s->grad || x->grad || c->grad) { GGML_ASSERT(false); // TODO: implement is_node = true; } - // 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_rs} - struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_rs)); + struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_t, n_s); result->op = GGML_OP_SSM_CONV; result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; result->src[0] = s; result->src[1] = x; result->src[2] = c; - result->src[3] = sq; return result; } @@ -7150,40 +7145,43 @@ struct ggml_tensor * ggml_ssm_scan( struct ggml_tensor * dt, struct ggml_tensor * A, struct ggml_tensor * B, - struct ggml_tensor * C, - struct ggml_tensor * sq) { + struct ggml_tensor * C) { GGML_ASSERT(ggml_is_contiguous(s)); GGML_ASSERT(ggml_is_contiguous(x)); GGML_ASSERT(ggml_is_contiguous(dt)); GGML_ASSERT(ggml_is_contiguous(A)); - GGML_ASSERT(sq->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_is_matrix(A)); + GGML_ASSERT(ggml_is_3d(B)); + GGML_ASSERT(ggml_is_3d(s)); GGML_ASSERT(B->nb[0] == ggml_type_size(B->type)); GGML_ASSERT(C->nb[0] == ggml_type_size(C->type)); GGML_ASSERT(ggml_are_same_shape(x, dt)); + GGML_ASSERT(ggml_are_same_shape(B, C)); { - const int64_t d_state = s->ne[0]; - const int64_t d_inner = s->ne[1]; - const int64_t n_tokens = x->ne[1]; + const int64_t d_state = s->ne[0]; + const int64_t d_inner = s->ne[1]; + const int64_t n_seq_tokens = x->ne[1]; + const int64_t n_seqs = x->ne[2]; + GGML_ASSERT(s->ne[2] == n_seqs); GGML_ASSERT(x->ne[0] == d_inner); GGML_ASSERT(A->ne[0] == d_state); GGML_ASSERT(A->ne[1] == d_inner); GGML_ASSERT(B->ne[0] == d_state); - GGML_ASSERT(B->ne[1] == n_tokens); - GGML_ASSERT(C->ne[0] == d_state); - GGML_ASSERT(C->ne[1] == n_tokens); + GGML_ASSERT(B->ne[1] == n_seq_tokens); + GGML_ASSERT(B->ne[2] == n_seqs); } bool is_node = false; - if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad || sq->grad) { + if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad) { GGML_ASSERT(false); // TODO: implement is_node = true; } - // 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_rs} - struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s)); + // y + struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, x->ne[0], x->ne[1], x->ne[2]); result->op = GGML_OP_SSM_SCAN; result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; @@ -7193,7 +7191,6 @@ struct ggml_tensor * ggml_ssm_scan( result->src[3] = A; result->src[4] = B; result->src[5] = C; - result->src[6] = sq; return result; } @@ -16249,24 +16246,20 @@ static void ggml_compute_forward_ssm_conv_f32( const struct ggml_tensor * src0 = dst->src[0]; // conv_state const struct ggml_tensor * src1 = dst->src[1]; // x const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight - const struct ggml_tensor * src3 = dst->src[3]; // state_seq const int ith = params->ith; const int nth = params->nth; - const int nc = src2->ne[0]; // d_conv - const int nr = src0->ne[1]; // d_inner - const int n_t = src1->ne[1]; // n_tokens - const int n_rs = src0->ne[2]; // max number of sequences in the batch + const int nc = src2->ne[0]; // d_conv + const int nr = src0->ne[1]; // d_inner + const int n_t = src1->ne[1]; // tokens per sequence + const int n_s = src0->ne[2]; // number of sequences in the batch - GGML_ASSERT((nr*n_t) + (nc*nr*n_rs) == ggml_nelements(dst)); + GGML_ASSERT(ggml_are_same_shape(src1, dst)); GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src1->nb[0] == sizeof(float)); GGML_ASSERT(src2->nb[0] == sizeof(float)); - GGML_ASSERT(src3->nb[0] == sizeof(int32_t)); GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); - // for use with the destination state offset between sequences - GGML_ASSERT(src2->nb[2] == src2->ne[1]*src2->ne[0]*sizeof(float)); // rows per thread const int dr = (nr + nth - 1)/nth; @@ -16276,64 +16269,53 @@ static void ggml_compute_forward_ssm_conv_f32( const int ir1 = MIN(ir0 + dr, nr); const int ir = ir1 - ir0; - const int32_t * sq = src3->data; // {n_tokens} + // TODO: maybe require src0 to have d_conv columns instead of (d_conv - 1)? + // This would avoid having to copy into an intermediate buffer, but the state would be bigger. + float * s = (float *) params->wdata + (nc*dr + CACHE_LINE_SIZE_F32) * ith; - if (n_rs > 1) { - // multiple sequences means it's hard to know when it's the first time a state is read, - // so copy them all over to the destination, just to be sure. 
- for (int i3 = 0; i3 < n_rs; ++i3) { - float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); - float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + i3*(src2->nb[2]) + nr*n_t*sizeof(float)); - // can't use memcpy because of d_conv vs d_conv - 1 - for (int i1 = 0; i1 < ir; ++i1) { - for (int i0 = 0; i0 < nc - 1; ++i0) { - // copy s0 to last (d_conv - 1) columns of s - s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)]; - } + for (int i3 = 0; i3 < n_s; ++i3) { + float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s} + + // copy the state into working memory + // can't use memcpy because (d_conv) != (d_conv - 1) + for (int i1 = 0; i1 < ir; ++i1) { + for (int i0 = 0; i0 < nc - 1; ++i0) { + s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)]; } } - } - for (int i2 = 0; i2 < n_t; ++i2) { - int32_t sq_i = sq[i2]; - float * x = (float *) ((char *) dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens} - float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + sq_i*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_rs} - float * s0; // {d_conv - 1, d_inner, n_rs} - float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} - float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner} - int ne0s0; + for (int i2 = 0; i2 < n_t; ++i2) { + float * x = (float *) ((char *) dst->data + ir0*( dst->nb[0]) + i2*( dst->nb[1]) + i3*( dst->nb[2])); // {d_inner, n_t, n_s} + float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} + float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner} - GGML_ASSERT(0 <= sq_i && sq_i < n_rs); + // shift state left + memmove(s, s + 1, (nc*ir - 1) * sizeof(float)); - // avoid needing to copy the state for the first token - if (i2 == 0) { - s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq_i*(src0->nb[2])); // {d_conv - 1, d_inner, n_rs} - ne0s0 = src0->ne[0]; - } else { - // the source is the last (d_conv - 1) columns of the destination - s0 = s + 1; - ne0s0 = nc; - } + // d_inner + for (int i1 = 0; i1 < ir; ++i1) { + // insert x on the last column + s[(nc - 1) + i1*nc] = x0[i1]; + } - // d_inner - for (int i1 = 0; i1 < ir; ++i1) { - // shift state left - for (int i0 = 0; i0 < nc - 1; ++i0) { - s[i0 + i1*nc] = s0[i0 + i1*ne0s0]; + // it seems a little faster when this is separate from the state shift + for (int i1 = 0; i1 < ir; ++i1) { + // rowwise dot product + // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision + float sumf = 0.0f; + for (int i0 = 0; i0 < nc; ++i0) { + int i = i0 + i1*nc; + sumf += s[i] * c[i]; + } + x[i1] = sumf; } - // insert x on the last column - s[(nc - 1) + i1*nc] = x0[i1]; } - // it seems a little faster when this is separate from the state shift + // copy the state out of it for (int i1 = 0; i1 < ir; ++i1) { - // rowwise dot product - float sumf = 0.0f; - for (int i0 = 0; i0 < nc; ++i0) { - int i = i0 + i1*nc; - sumf += s[i] * c[i]; + for (int i0 = 0; i0 < nc - 1; ++i0) { + s0[i0 + i1*(nc - 1)] = s[1 + i0 + i1*nc]; } - x[i1] = sumf; } } } @@ -16368,30 +16350,24 @@ static void ggml_compute_forward_ssm_scan_f32( const struct ggml_tensor * src3 = dst->src[3]; // A const struct ggml_tensor * src4 = dst->src[4]; // B const struct ggml_tensor * src5 = dst->src[5]; // C - const struct ggml_tensor * src6 = dst->src[6]; // sq const int ith 
= params->ith; const int nth = params->nth; - const int64_t nc = src0->ne[0]; // d_state - const int64_t nr = src0->ne[1]; // d_inner - const int64_t n_t = src1->ne[1]; // number of tokens in the batch - const int64_t n_rs = src0->ne[2]; // max number of sequences in the batch + const int64_t nc = src0->ne[0]; // d_state + const int64_t nr = src0->ne[1]; // d_inner + const int64_t n_t = src1->ne[1]; // number of tokens per sequence + const int64_t n_s = src0->ne[2]; // number of sequences in the batch - GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); + GGML_ASSERT(ggml_nelements(src1) == ggml_nelements(dst)); GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src1->nb[0] == sizeof(float)); GGML_ASSERT(src2->nb[0] == sizeof(float)); GGML_ASSERT(src3->nb[0] == sizeof(float)); GGML_ASSERT(src4->nb[0] == sizeof(float)); GGML_ASSERT(src5->nb[0] == sizeof(float)); - GGML_ASSERT(src6->nb[0] == sizeof(int32_t)); - // required for the dot product between s and C, and when copying the states + // required for the dot product between s and C GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); - // required for per-sequence offsets for states - GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float)); - // required to get correct offset for state destination (i.e. src1->nb[2]) - GGML_ASSERT(src1->nb[2] == src1->ne[0]*src1->ne[1]*sizeof(float)); // rows per thread const int dr = (nr + nth - 1)/nth; @@ -16401,55 +16377,33 @@ static void ggml_compute_forward_ssm_scan_f32( const int ir1 = MIN(ir0 + dr, nr); const int ir = ir1 - ir0; - const int32_t * sq = src6->data; // {n_tokens} - - if (n_rs > 1) { - // it's hard to know if the source states have already been copied - // when there are multiple, so copy them already. - for (int i3 = 0; i3 < n_rs; ++i3) { - float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); - float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[2]); - memcpy(s, s0, nc*ir*sizeof(float)); - } - } - - for (int i2 = 0; i2 < n_t; ++i2) { - int32_t sq_i = sq[i2]; - float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} - float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + sq_i*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_rs} - float * s0; - float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} - float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens} - float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} - float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens} - float * C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens} - - GGML_ASSERT(0 <= sq_i && sq_i < n_rs); - - // avoid needing to copy the state for the first token - if (i2 == 0) { - s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq_i*(src0->nb[2])); // {d_state, d_inner, n_rs} - } else { - // otherwise the source is the same as the destination - s0 = s; - } - - // d_inner - for (int i1 = 0; i1 < ir; ++i1) { - // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78 - float dt_soft_plus = dt[i1] <= 20.0f ? 
log1pf(expf(dt[i1])) : dt[i1]; - float x_dt = x[i1] * dt_soft_plus; - float sumf = 0.0f; - // d_state - for (int i0 = 0; i0 < nc; ++i0) { - int i = i0 + i1*nc; - // state = prev_state * dA + dB * x - float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); - // y = rowwise_dotprod(state, C) - sumf += state * C[i0]; - s[i] = state; + for (int i3 = 0; i3 < n_s; ++i3) { + for (int i2 = 0; i2 < n_t; ++i2) { + float * y = (float *) ((char *) dst->data + ir0*( dst->nb[0]) + i2*( dst->nb[1]) + i3*( dst->nb[2])); // {d_inner, n_t, n_s} + float * s = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s} + float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} + float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s} + float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} + float * B = (float *) ((char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s} + float * C = (float *) ((char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s} + + // d_inner + for (int i1 = 0; i1 < ir; ++i1) { + // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78 + float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1]; + float x_dt = x[i1] * dt_soft_plus; + float sumf = 0.0f; + // d_state + for (int i0 = 0; i0 < nc; ++i0) { + int i = i0 + i1*nc; + // state = prev_state * dA + dB * x + float state = (s[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); + // y = rowwise_dotprod(state, C) + sumf += state * C[i0]; + s[i] = state; + } + y[i1] = sumf; } - y[i1] = sumf; } } } @@ -19614,7 +19568,13 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 } } break; + case GGML_OP_SSM_CONV: + { + const int64_t d_conv = node->src[2]->ne[0]; + const int64_t d_inner = node->src[0]->ne[1]; + cur += sizeof(float)*d_conv*(d_inner + n_tasks - 1); + } break; case GGML_OP_CROSS_ENTROPY_LOSS: { cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks); diff --git a/ggml.h b/ggml.h index 4e6bcb30fd931..bdf05a31139e5 100644 --- a/ggml.h +++ b/ggml.h @@ -1793,8 +1793,7 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * s, struct ggml_tensor * x, - struct ggml_tensor * c, - struct ggml_tensor * sq); + struct ggml_tensor * c); GGML_API struct ggml_tensor * ggml_ssm_scan( struct ggml_context * ctx, @@ -1803,8 +1802,7 @@ extern "C" { struct ggml_tensor * dt, struct ggml_tensor * A, struct ggml_tensor * B, - struct ggml_tensor * C, - struct ggml_tensor * sq); + struct ggml_tensor * C); // partition into non-overlapping windows with padding if needed // example: diff --git a/llama.cpp b/llama.cpp index 27374c18506c9..ca64b7e29df7a 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2114,6 +2114,24 @@ struct llama_layer { struct ggml_tensor * rope_short = nullptr; }; +// very similar to llama_batch, +// but has more metadata about sequences +struct llama_ubatch { + bool equal_seqs; + + uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs) + uint32_t n_seq_tokens; // tokens per sequence + uint32_t n_seqs; + + llama_token * token; // [n_tokens] + float * embd; // [n_embd, n_tokens] + llama_pos * pos; // [n_tokens] + // FIXME: make all uses of this use n_seqs + 
int32_t * n_seq_id; // [n_seqs] + llama_seq_id ** seq_id; // [n_seqs] + int8_t * output; // [n_tokens] +}; + struct llama_kv_cell { llama_pos pos = -1; llama_pos delta = 0; @@ -2223,17 +2241,15 @@ struct llama_rs_cell { } }; - struct llama_rs_seq_meta { // cell id of the latest state of this seq_id int32_t tail = -1; // number of cells for which this seq_id is the first // (useful to know if cells in this sequence should be pruned) int32_t n_cells = 0; - // changing the tail cell of a sequence can only be done at batch boundary, - // this guards against changing the cell when it shouldn't be; - // should be cleared when done finding a slot - bool in_ubatch = false; + // the last pos of this sequence if it is in the current ubatch, + // only set and used when finding a slot. + llama_pos ubatch_end_pos = -1; }; // ring-buffered tree of cached recurrent state data @@ -2261,6 +2277,10 @@ struct llama_rs_cache { // find tail cells faster std::vector seq_tails; // map seq_ids to cell ids + // freeable cell ids, computed when finding a slot + // useful to find the smallest range to defrag + std::vector freeable; + // per layer // NOTE: the naming of r and s is arbitrary std::vector r_l; // rolling/shift states @@ -2399,8 +2419,8 @@ struct llama_rs_cache { if (seq_node->next_cell != next) { // TODO: relax the error when multiple cells have the same pos if (debug) { - LLAMA_LOG_ERROR("%s: invalid next cell for cells[%u] (%d instead of %d)\n", - __func__, cell_id, seq_node->next_cell, next); + LLAMA_LOG_ERROR("%s: invalid next cell for seq_id %d in cells[%u] (%d instead of %d)\n", + __func__, seq_id, cell_id, seq_node->next_cell, next); } seq_node->next_cell = next; was_valid = false; @@ -2414,15 +2434,6 @@ struct llama_rs_cache { } seq.n_cells = n_cells; } - // in_batch should only be true when in the process of finding a slot - if (seq.in_ubatch != false) { - if (debug) { - LLAMA_LOG_ERROR("%s: in_ubatch was true while it should have been false for seq_id %d\n", - __func__, seq_id); - } - seq.in_ubatch = false; - was_valid = false; - } } // tail_rc for (uint32_t cell_id = 0; cell_id < size; ++cell_id) { @@ -2475,6 +2486,88 @@ struct llama_rs_cache { return was_valid; } + // each seq_id should have access to at least this many cells + // (to use when pruning (to avoid over-pruning)) + uint32_t min_cells_per_seq(const llama_ubatch & batch) const { + uint32_t seqs = n_seqs; + for (uint32_t i = 0; i < batch.n_seqs; ++i) { + llama_seq_id seq_id = batch.seq_id[i][0]; + const llama_rs_seq_meta & new_seq = seq_tails[seq_id]; + if (new_seq.tail < 0 || new_seq.n_cells == 0) { + seqs += 1; + } + } + return (size - n_shared_tail_cells) / (seqs > 0 ? seqs : 1); + } + + void freeable_for_batch(const llama_ubatch & batch, llama_pos checkpoint_interval) { + GGML_ASSERT(batch.equal_seqs); + int32_t min_cells = min_cells_per_seq(batch); + + // TODO: minimize work required to find freeable cells + // currently, this finds freeable cells by excluding non-freeable cells, + // because some conditions are more easily expressed this way. 
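// Illustrative sketch (not from this patch; tail_is_freeable_sketch and its
// parameter names are invented): it condenses the tail-cell rule used in the
// loop below into a standalone predicate. A sequence's tail cell is kept
// (not freeable) when the sequence is absent from the ubatch, or when enough
// new tokens arrive that a checkpoint should be preserved.
#include <cstdint>

// returns true if the tail cell of a sequence may be overwritten by the new slot
static bool tail_is_freeable_sketch(
        int64_t tail_pos,            // pos stored in the tail cell
        int64_t prev_pos,            // pos of the previous cell, or -1 if there is none
        int64_t end_pos,             // last pos of this sequence in the ubatch, or -1 if absent
        int64_t checkpoint_interval, // e.g. 8, as passed to freeable_for_batch further below
        int64_t min_cells) {         // cells this sequence is entitled to
    if (end_pos < 0) {
        return false; // the sequence is not in the ubatch: its tail must survive
    }
    if (min_cells <= 1) {
        return true;  // not enough room to keep checkpoints anyway
    }
    if (prev_pos < 0 || tail_pos + checkpoint_interval <= end_pos) {
        return false; // keep a checkpoint before (long) prompt processing
    }
    if (prev_pos + checkpoint_interval <= end_pos) {
        return false; // keep a checkpoint during text generation
    }
    return true;
}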
+ + freeable.assign(size, 1); + + for (llama_rs_seq_meta & seq : seq_tails) { + seq.ubatch_end_pos = -1; + } + + for (uint32_t i = 0; i < batch.n_seqs; ++i) { + int32_t n_seq_id = batch.n_seq_id[i]; + for (int32_t j = 0; j < n_seq_id; j++) { + llama_seq_id seq_id = batch.seq_id[i][j]; + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_tails.size()); + llama_rs_seq_meta & seq = seq_tails[seq_id]; + seq.ubatch_end_pos = batch.pos[i * batch.n_seq_tokens + batch.n_seq_tokens - 1]; + } + } + + for (llama_rs_seq_meta & seq : seq_tails) { + if (seq.tail >= 0 && freeable[seq.tail] != 0) { + llama_pos end_pos = seq.ubatch_end_pos; + // When is a tail cell not freeable? + if (end_pos < 0) { + // when any of its tails are not in the batch + freeable[seq.tail] = 0; + } else if (min_cells > 1) { + // TODO: fallback to this less often + llama_rs_cell & tail = cells[seq.tail]; + GGML_ASSERT(tail.pos < end_pos); + if (tail.prev < 0 || tail.pos + checkpoint_interval <= end_pos) { + // make a checkpoint before prompt processing + // TODO: should it always be done after instead? + freeable[seq.tail] = 0; + } else { + llama_rs_cell & prev = cells[tail.prev]; + if (prev.pos + checkpoint_interval <= end_pos) { + // make a checkpoint during text generation + freeable[seq.tail] = 0; + } + } + } + } + } + + for (uint32_t i = 0; i < size; ++i) { + llama_rs_cell & cell = cells[i]; + if (!cell.is_empty() && cell.tail_rc == 0) { + // TODO: reduce indirection here + llama_rs_seq_node & seq_node = cell.seq_nodes[0]; + llama_rs_seq_meta & seq = seq_tails[seq_node.seq_id]; + bool keep_tail = freeable[seq.tail] == 0; + // kept tails use an additional cell, so make them allow freeing a checkpoint + int32_t really_min_cells = keep_tail ? min_cells - 1 : min_cells; + // A checkpoint is kept if there's enough alloted space for this sequence + // or if it's the state right before the tail + if (seq.n_cells <= really_min_cells || (really_min_cells > 1 && seq_node.next_cell == seq.tail)) { + freeable[i] = 0; + } + } + } + } + // returns an iterator to the seq_node after the removed one, or the same which was passed if it wasn't removed. // Why an iterator? Because it allows using std::vector::erase. std::vector::iterator remove_seq_node_from_cell(llama_rs_cell & rs_cell, std::vector::iterator node_iter) { @@ -2496,22 +2589,30 @@ struct llama_rs_cache { prev_node->next_cell = node.next_cell; if (node.is_tail()) { // move the tail back to the previous cell + prev_cell.tail_rc += 1; if (prev_cell.seq_nodes.size() > 1) { if (rs_cell.tail_rc == rs_cell.seq_nodes.size()) { - if (prev_cell.tail_rc == 0) { + if (prev_cell.tail_rc == 1) { n_shared_tail_cells += 1; } - // o oo oo - // |/ -> o/ - // | | - // e.g. when removing the leaf with a single tail - if (rs_cell.tail_rc == 1 && prev_cell.tail_rc != prev_cell.seq_nodes.size()) { - n_seqs -= 1; + if (rs_cell.tail_rc == 1) { + if (prev_cell.tail_rc != prev_cell.seq_nodes.size()) { + // o oo oo + // |/ -> o/ + // | | + // e.g. when removing the leaf of a split tree + n_seqs -= 1; + } else { + // o + // o -> oo + // | | + // e.g. 
when merging back with a previous tail + n_shared_tail_cells -= 1; + } } } } - prev_cell.tail_rc += 1; } } if ((uint32_t) node.seq_id < seq_tails.size()) { @@ -2534,6 +2635,7 @@ struct llama_rs_cache { // will fully become a tail cell if (rs_cell.tail_rc > 0) { n_seqs += 1; + n_shared_tail_cells -= 1; } } if (node_iter == rs_cell.seq_nodes.begin()) { @@ -2583,14 +2685,107 @@ struct llama_rs_cache { return false; } - bool insert_seq_tail_to_cell_id(uint32_t i_cell, const llama_seq_id & id) { + bool swap_cells(uint32_t i_src, uint32_t i_dst) { + if (i_src < size && i_dst < size && i_src != i_dst) { + llama_rs_cell & src = cells[i_src]; + llama_rs_cell & dst = cells[i_dst]; + + for (llama_rs_seq_node & seq_node : src.seq_nodes) { + if (seq_node.next_cell >= 0) { + llama_rs_cell & next = cells[seq_node.next_cell]; + next.prev = i_dst; + if ((uint32_t) seq_node.next_cell == i_dst) { + seq_node.next_cell = i_src; + } + } else { + // this is a tail + seq_tails[seq_node.seq_id].tail = i_dst; + } + } + for (llama_rs_seq_node & seq_node : dst.seq_nodes) { + if (seq_node.next_cell >= 0) { + llama_rs_cell & next = cells[seq_node.next_cell]; + next.prev = i_src; + if ((uint32_t) seq_node.next_cell == i_src) { + seq_node.next_cell = i_dst; + } + } else { + // this is a tail + seq_tails[seq_node.seq_id].tail = i_src; + } + } + + if (src.prev == dst.prev) { + // avoid swapping them twice + if (src.prev >= 0) { + llama_rs_cell & prev = cells[src.prev]; + for (llama_rs_seq_node & seq_node : prev.seq_nodes) { + if ((uint32_t) seq_node.next_cell == i_src) { + seq_node.next_cell = i_dst; + } else if ((uint32_t) seq_node.next_cell == i_dst) { + seq_node.next_cell = i_src; + } + } + } + } else { + if (src.prev >= 0) { + llama_rs_cell & prev = cells[src.prev]; + for (llama_rs_seq_node & seq_node : prev.seq_nodes) { + if ((uint32_t) seq_node.next_cell == i_src) { + seq_node.next_cell = i_dst; + } + } + } + if (dst.prev >= 0) { + llama_rs_cell & prev = cells[dst.prev]; + for (llama_rs_seq_node & seq_node : prev.seq_nodes) { + if ((uint32_t) seq_node.next_cell == i_dst) { + seq_node.next_cell = i_src; + } + } + } + } + + std::swap(src.pos, dst.pos); + std::swap(src.src, dst.src); + std::swap(src.prev, dst.prev); + std::swap(src.tail_rc, dst.tail_rc); + std::swap(src.seq_nodes, dst.seq_nodes); + + return true; + } + return false; + } + + bool insert_seq_tail_to_cell_id(uint32_t i_cell, llama_seq_id id, llama_pos end_pos = -1) { if (i_cell < size && (size_t) id < seq_tails.size()) { llama_rs_cell & rs_cell = cells[i_cell]; auto & seq = seq_tails[id]; int32_t prev = rs_cell.prev; + if (end_pos >= 0) { + if (end_pos <= rs_cell.pos) { + // What should happen when the pos backtracks or skips a value? + // Clearing the state mid-batch would require special-casing which isn't done. + LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n", + __func__, end_pos, rs_cell.pos, id); + } + rs_cell.pos = end_pos; + } else { + // if no pos was specified, then the target cell should already have a valid one. 
+ GGML_ASSERT(!rs_cell.is_empty()); + } if ((uint32_t) seq.tail == i_cell) { // the cell is already the tail of this seq_id - return false; + if (rs_cell.tail_rc != rs_cell.seq_nodes.size()) { + GGML_ASSERT(end_pos >= 0); // make sure this is the first re-added seq_id + // remove non-tail seq_ids (branch off them) + for (size_t i = rs_cell.seq_nodes.size(); i-- > 0;) { + if (!rs_cell.seq_nodes[i].is_tail()) { + remove_seq_node_from_cell(rs_cell, rs_cell.seq_nodes.begin() + i); + } + } + } + return true; } if (rs_cell.is_empty()) { prev = seq.tail; @@ -2603,9 +2798,7 @@ struct llama_rs_cache { auto prev_node = std::find(prev_cell.seq_nodes.begin(), prev_cell.seq_nodes.end(), id); GGML_ASSERT(prev_node != prev_cell.seq_nodes.end()); // TODO: recursive insert instead of failing GGML_ASSERT(prev_node->next_cell == -1); // or else a chain is broken - if (rs_cell.pos < 0) { - GGML_ASSERT(rs_cell.is_empty()); - rs_cell.pos = prev_cell.pos + 1; + if (rs_cell.is_empty()) { rs_cell.src = prev_cell.src; } prev_node->next_cell = i_cell; @@ -2650,8 +2843,7 @@ struct llama_rs_cache { if (seq.tail < 0) { // from empty to unique n_seqs += 1; - // pos was not yet set - rs_cell.pos = 0; + // make sure it's cleared rs_cell.src = -1; } used += 1; @@ -2671,16 +2863,6 @@ struct llama_rs_cache { return false; } - // each seq_id should have access to at least this many cells - // (to use when pruning (to avoid over-pruning)) - size_t min_cells_per_seq(const llama_rs_seq_meta & new_seq) const { - uint32_t seqs = n_seqs; - if (new_seq.tail < 0 || new_seq.n_cells == 0) { - seqs += 1; - } - return (size - n_shared_tail_cells) / (seqs > 0 ? seqs : 1); - } - size_t total_size() const { size_t size = 0; for (struct ggml_tensor * r : r_l) { @@ -2883,22 +3065,6 @@ struct llama_model { } }; -// very similar to llama_batch, -// but has more metadata about sequences -struct llama_ubatch { - bool equal_seqs; - - int32_t n_tokens; - int32_t n_seqs; - - llama_token * token; - float * embd; - llama_pos * pos; - int32_t * n_seq_id; - llama_seq_id ** seq_id; - int8_t * output; -}; - struct llama_sbatch_seq { int32_t n_seq_id; llama_seq_id * seq_id; @@ -2954,6 +3120,7 @@ struct llama_sbatch { true, 0, 0, + 0, !has_embd ? ubatch_token.data() : nullptr, has_embd ? ubatch_embd.data() : nullptr, ubatch_pos.data(), @@ -2967,16 +3134,14 @@ struct llama_sbatch { void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) { GGML_ASSERT(batch != nullptr); GGML_ASSERT(length <= seq.length); - if (ubatch.equal_seqs) { - // is the new sequence of a different size than expected? 
- if (ubatch.n_seqs > 0 && length != (size_t) ubatch.n_tokens / ubatch.n_seqs) { - ubatch.equal_seqs = false; - } - } + // Can only add sequences of equal lengths to a batch, + // otherwise it isn't clear to which sequence a token belongs + GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs); + GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs); // NOTE: loops are separated for cache-friendliness if (batch->token) { for (size_t i = 0; i < length; ++i) { - ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]]; + ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]]; } } else { ubatch.token = nullptr; @@ -3004,22 +3169,32 @@ struct llama_sbatch { ubatch.pos[ubatch.n_tokens + i] = batch->all_pos_0 + (bi * batch->all_pos_1); } } - if (batch->n_seq_id) { - for (size_t i = 0; i < length; ++i) { - ubatch.n_seq_id[ubatch.n_tokens + i] = batch->n_seq_id[ids[seq.offset + i]]; + if (seq.n_seq_id > 0) { + ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id; + if (seq.seq_id) { + ubatch.seq_id[ubatch.n_seqs] = seq.seq_id; + } else { + GGML_ASSERT(seq.n_seq_id == 1); + ubatch.seq_id[ubatch.n_seqs] = &seq.all_seq_id; } } else { - for (size_t i = 0; i < length; ++i) { - ubatch.n_seq_id[ubatch.n_tokens + i] = 1; - } - } - if (batch->seq_id) { - for (size_t i = 0; i < length; ++i) { - ubatch.seq_id[ubatch.n_tokens + i] = batch->seq_id[ids[seq.offset + i]]; + if (batch->n_seq_id) { + for (size_t i = 0; i < length; ++i) { + ubatch.n_seq_id[ubatch.n_seqs + i] = batch->n_seq_id[ids[seq.offset + i]]; + } + } else { + for (size_t i = 0; i < length; ++i) { + ubatch.n_seq_id[ubatch.n_seqs + i] = 1; + } } - } else { - for (size_t i = 0; i < length; ++i) { - ubatch.seq_id[ubatch.n_tokens + i] = &seq.all_seq_id; + if (batch->seq_id) { + for (size_t i = 0; i < length; ++i) { + ubatch.seq_id[ubatch.n_seqs + i] = batch->seq_id[ids[seq.offset + i]]; + } + } else { + for (size_t i = 0; i < length; ++i) { + ubatch.seq_id[ubatch.n_seqs + i] = &seq.all_seq_id; + } } } if (batch->logits) { @@ -3043,11 +3218,15 @@ struct llama_sbatch { if (is_last) { out_ids.push_back(id); } } } + if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) { + ubatch.n_seq_tokens = seq.n_seq_id > 0 ? length : 1; + } ubatch.n_tokens += length; - ubatch.n_seqs += seq.n_seq_id != 0; // don't count seq_ids for legacy splits + ubatch.n_seqs += seq.n_seq_id > 0 ? 
1 : length; // virtual sequences for legacy splits seq.offset += length; seq.length -= length; n_tokens -= length; + GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs); } // legacy split, unknown number of sequences of unequal lengths @@ -3283,7 +3462,6 @@ struct llama_context { struct ggml_tensor * inp_cls; // I32 [n_batch] struct ggml_tensor * inp_s_copy; // I32 [n_rs] struct ggml_tensor * inp_s_mask; // F32 [1, n_rs] - struct ggml_tensor * inp_s_seq; // I32 [n_batch] // control vectors struct llama_control_vector cvec; @@ -3426,6 +3604,7 @@ static bool llama_cache_init( cache.rs.cells.resize(rs_size); cache.rs.seq_tails.clear(); cache.rs.seq_tails.resize(rs_size); + cache.rs.freeable.reserve(rs_size); #ifdef GGML_USE_CLBLAST offload = false; @@ -3524,11 +3703,13 @@ static bool llama_cache_find_slot( const uint32_t kv_size = cache.kv.size; const uint32_t rs_size = cache.rs.size; const uint32_t n_tokens = batch.n_tokens; + const uint32_t n_seqs = batch.n_seqs; + const uint32_t n_seq_tokens = batch.n_seq_tokens; // only check first, to allow failing gracefully if (rs_size > 0) { // everything should fit if all seq_ids are smaller than the max - for (uint32_t i = 0; i < n_tokens; ++i) { + for (uint32_t i = 0; i < n_seqs; ++i) { int32_t n_seq_id = batch.n_seq_id[i]; for (int32_t j = 0; j < n_seq_id; ++j) { llama_seq_id seq_id = batch.seq_id[i][j]; @@ -3541,6 +3722,23 @@ static bool llama_cache_find_slot( } } } + // TODO: configurable checkpoint interval + cache.rs.freeable_for_batch(batch, 8); + { + uint32_t freeable_rs_cell_count = 0; + for (uint32_t is_freeable : cache.rs.freeable) { + freeable_rs_cell_count += (uint32_t) (is_freeable != 0); + if (freeable_rs_cell_count >= n_seqs) { + // there's enough, no need to count them all + break; + } + } + if (n_seqs > freeable_rs_cell_count) { + // This should not happen + LLAMA_LOG_ERROR("%s: n_seqs=%d > freeable_rs_cell_count=%d\n", __func__, n_seqs, freeable_rs_cell_count); + return false; + } + } } if (kv_size > 0) { @@ -3591,172 +3789,146 @@ static bool llama_cache_find_slot( if (rs_size > 0) { // For recurrent state architectures (like Mamba), // each cache cell can store the state for a whole sequence. - // TODO: find a way to always make the rs slot contiguous + // A slot should be always be contiguous. 
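// Illustrative sketch (not from this patch; find_min_window_sketch is an
// invented name): it shows, in isolation, the search performed below. Given a
// 0/1 "freeable" mask over the recurrent state cells, find the smallest
// contiguous range containing at least n_seqs freeable cells; that range
// becomes the slot, and the non-freeable cells inside it are swapped out of
// the way afterwards.
#include <cstdint>
#include <utility>
#include <vector>

// returns {head, n} of the smallest range holding n_seqs freeable cells;
// falls back to the whole range if there are fewer than n_seqs freeable cells
static std::pair<uint32_t, uint32_t> find_min_window_sketch(const std::vector<uint8_t> & freeable, uint32_t n_seqs) {
    std::vector<uint32_t> ids; // indices of freeable cells, in order
    for (uint32_t i = 0; i < (uint32_t) freeable.size(); ++i) {
        if (freeable[i]) { ids.push_back(i); }
    }
    uint32_t best_head = 0;
    uint32_t best_n    = (uint32_t) freeable.size();
    if (n_seqs == 0) {
        return {best_head, 0};
    }
    for (uint32_t k = 0; k + n_seqs <= ids.size(); ++k) {
        const uint32_t head = ids[k];
        const uint32_t n    = ids[k + n_seqs - 1] - head + 1;
        if (n < best_n) {
            best_head = head;
            best_n    = n;
            if (n == n_seqs) { break; } // already as small as possible
        }
    }
    return {best_head, best_n};
}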
- llama_seq_id min_seq = cache.rs.size - 1; - llama_seq_id max_seq = 0; - uint32_t min_cell = cache.rs.size - 1; - uint32_t max_cell = 0; + uint32_t min_head = 0; + uint32_t min_n = cache.rs.size; + uint32_t min_free = 0; - for (uint32_t i = 0; i < n_tokens; ++i) { - int32_t target_cell = -1; // ensure all the sequences of a token get the same cell - int32_t n_seq_ids = batch.n_seq_id[i]; - for (int32_t j = 0; j < n_seq_ids; ++j) { - llama_seq_id seq_id = batch.seq_id[i][j]; - bool need_new_cell = false; - // Everything should fit assuming the biggest seq_id < rs_size - GGML_ASSERT((uint32_t) seq_id < rs_size); - llama_rs_seq_meta & seq = cache.rs.seq_tails[seq_id]; - if (seq_id > max_seq) { max_seq = seq_id; } - if (seq_id < min_seq) { min_seq = seq_id; } - - if (!seq.in_ubatch && target_cell >= 0) { - // never saw this seq_id before, - // but there's already a cell reserved for this token, use it - cache.rs.insert_seq_tail_to_cell_id(target_cell, seq_id); - } else if (seq.tail < 0) { - // this seq_id has no tail (and is empty) - need_new_cell = true; - } else { - llama_rs_cell & tail = cache.rs.cells[seq.tail]; - if (seq.in_ubatch) { - // this seq_id was already seen before in the batch - // assuming the tail cell already "has" this seq_id - tail.pos += 1; - target_cell = seq.tail; - } else { - // first time this sequence is seen, - // there's no reserved cell yet; - // if it's not the first sequence of the token, how could it even get here? - GGML_ASSERT(j == 0); - - bool has_same_seqs = tail.seq_nodes.size() == (size_t) n_seq_ids; - if (has_same_seqs) { - // the tail cell of a seq_id is assumed to already be part of the seq_id, - // hence the skip of the first seq_id - for (int32_t k = 1; k < n_seq_ids; ++k) { - if (batch.seq_id[i][k] != tail.seq_nodes[k].seq_id) { - has_same_seqs = false; - } + // compact the freeable cell list + // e.g. 0,1,0,0,1,1,0,1,0,1 -> 1,4,5,7,9 + // while also finding the smallest cell range for the slot + { + uint32_t next_free = 0; + for (size_t i = 0; i < cache.rs.freeable.size(); ++i) { + if (cache.rs.freeable[i]) { + cache.rs.freeable[next_free] = i; + next_free += 1; + + if (next_free >= n_seqs) { + uint32_t head = cache.rs.freeable[next_free - n_seqs]; + // i is the last seen freeable cell id + uint32_t n = i - head + 1; + // keep the first smallest big enough slot + if (n < min_n) { + min_free = next_free - n_seqs; + min_head = head; + min_n = n; + if (n == n_seqs) { + // it's the smallest it can be + break; } } - - // TODO: make the checkpoint interval configurable - if (!has_same_seqs || tail.prev < 0 || tail.pos - cache.rs.cells[tail.prev].pos >= 16) { - // a checkpoint should be saved - need_new_cell = true; - } else { - // re-use last tail - tail.pos += 1; - target_cell = seq.tail; - } } } + } + } - // reserve a cell for this seq_id - if (need_new_cell && target_cell < 0) { - const int32_t min_cells_per_seq = cache.rs.min_cells_per_seq(seq); + // sanity check + GGML_ASSERT(min_head + min_n <= cache.rs.size); - uint32_t cell_id = cache.rs.size; - bool looped_once = false; + // keep only the necessary range + cache.rs.freeable.resize(min_free + n_seqs); + cache.rs.freeable.erase(cache.rs.freeable.begin(), cache.rs.freeable.begin() + min_free); + GGML_ASSERT(cache.rs.freeable.size() == n_seqs); + GGML_ASSERT(min_n >= n_seqs); + cache.rs.freeable.resize(min_n); - while (true) { - if (cache.rs.head >= cache.rs.size) { - cache.rs.head = 0; - // avoid infinite loop - // NOTE: this should not fail; if it does, it's a bug. 
- GGML_ASSERT(!looped_once && "recurrent state cache seems full, but should not."); - looped_once = true; - } - cell_id = cache.rs.head; - llama_rs_cell & candidate = cache.rs.cells[cell_id]; - if (candidate.is_empty()) { break; } - if (candidate.tail_rc == 1 && seq.tail == (int32_t) cell_id) { - // the candidate is the old tail - if (candidate.seq_nodes.size() > 1) { - // prune out the other seq_ids, because they diverge - // TODO(maybe): hande this in insert_seq_tail_to_cell_id - // (hopefully doesn't happen too often) - for (auto node_iter = candidate.seq_nodes.begin(); node_iter != candidate.seq_nodes.end();) { - if (node_iter->seq_id == seq_id) { - node_iter = std::next(node_iter); - } else { - node_iter = cache.rs.remove_seq_node_from_cell(candidate, node_iter); - } + // expand the free list + // e.g. 2,4,5,8 -> 1,0,1,1,0,0,1 + for (uint32_t i = n_seqs; i-- > 0;) { + uint32_t dst = cache.rs.freeable[i] - min_head; + if (dst != i) { + cache.rs.freeable[i] = 0; + } + GGML_ASSERT(dst >= i); + cache.rs.freeable[dst] = 1; + } + + // coalesce the free cells together + // e.g. 1,0,1,1,0,0,1 -> 1,1,1,1,0,0,0 + // or 1,0,1,1,1,1 -> 1,1,1,1,1,0 + { + uint32_t top_free = min_n - 1; + for (uint32_t i = min_n; i-- > 1;) { + uint32_t is_free = cache.rs.freeable[i]; + if (!is_free) { + GGML_ASSERT(top_free > i); + cache.rs.swap_cells(min_head + i, min_head + top_free); + std::swap(cache.rs.freeable[i], cache.rs.freeable[top_free]); + // the previous one has to be free, + // otherwise it would already have been swapped. + top_free -= 1; + } + // stop early if all freeable cells have already been put at the beginning + if (top_free < n_seqs) { break; } + } + } + + // order the re-used cells identically to their batch order + // (and clear the non-reused cells) + { + for (uint32_t i = 0; i < n_seqs; ++i) { + // ignore the already-swapped cells + if (cache.rs.freeable[i]) { + llama_rs_cell & cell = cache.rs.cells[min_head + i]; + if (!cell.is_empty()) { + if (cell.tail_rc == 0) { + cache.rs.clear_cell(cell); + } else { + // TODO: does this always work correctly + // even if there are more than one seq_node in this cell? + + // Which seq_id of the batch is it? 
+ llama_seq_id seq_id = cell.seq_nodes[0].seq_id; + int32_t nth_seq_id = -1; + for (int32_t s = 0; (uint32_t) s < n_seqs; ++s) { + if (seq_id == batch.seq_id[s][0]) { + nth_seq_id = s; + break; } } - // re-use the tail cell to avoid not finding anything - candidate.pos += 1; - break; - } - if (candidate.tail_rc > 0) { - // skip tails of other sequences - cache.rs.head += 1; - continue; - } - if (candidate.seq_nodes.size() > 1) { - // shared prompts are not usually backtracked, so they can be pruned - cache.rs.clear_cell(candidate); - break; - } + GGML_ASSERT(nth_seq_id != -1); - // prune too-long sequences - llama_seq_id seq_id_to_prune = candidate.seq_nodes[0].seq_id; - if (seq_id_to_prune == seq_id) { - // TODO: selectively skip some cells to keep older states - cache.rs.clear_cell(candidate); - break; - } - GGML_ASSERT((size_t) seq_id_to_prune < cache.rs.seq_tails.size()); - auto & seq_to_prune = cache.rs.seq_tails[seq_id_to_prune]; - if (seq_to_prune.n_cells > min_cells_per_seq) { - cache.rs.clear_cell(candidate); - break; + cache.rs.swap_cells(min_head + i, min_head + nth_seq_id); + cache.rs.freeable[i] = 0; + std::swap(cache.rs.freeable[i], cache.rs.freeable[nth_seq_id]); + i -= 1; // check this cell again, now that it was swapped } - cache.rs.head += 1; } - if (cell_id < cache.rs.size) { - cache.rs.insert_seq_tail_to_cell_id(cell_id, seq_id); - target_cell = cell_id; - } - } - - if (seq.tail >= 0) { - if (min_cell > (uint32_t) seq.tail) { min_cell = seq.tail; } - if (max_cell < (uint32_t) seq.tail) { max_cell = seq.tail; } - seq.in_ubatch = true; - } - - // Assuming the tokens are in-order - if (batch.pos[i] != cache.rs.cells[seq.tail].pos) { - // What should happen when the pos backtracks or skips a value? - // Clearing the state mid-batch would require special-casing which isn't done. 
- LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n", - __func__, batch.pos[i], cache.rs.cells[cache.rs.head].pos - 1, seq_id); } } - cache.rs.head = target_cell + 1; } - for (llama_seq_id i = min_seq; i <= max_seq; ++i) { - // make sure it's cleared for next time - cache.rs.seq_tails[i].in_ubatch = false; + // reserve + { + for (uint32_t i = 0; i < n_seqs; ++i) { + uint32_t i_cell = min_head + i; + int32_t n_seq_id = batch.n_seq_id[i]; + llama_pos end_pos = batch.pos[(i * n_seq_tokens) + n_seq_tokens - 1]; + // set the pos with the first seq_id + cache.rs.insert_seq_tail_to_cell_id(i_cell, batch.seq_id[i][0], end_pos); + // insert the rest of the seq_ids by re-using the cell's pos + for (int j = 1; j < n_seq_id; ++j) { + cache.rs.insert_seq_tail_to_cell_id(i_cell, batch.seq_id[i][j]); + } + } } // allow getting the range of used cells, from head to head + n - cache.rs.head = min_cell; - cache.rs.n = max_cell - min_cell + 1; - - // sanity check - GGML_ASSERT(min_seq <= max_seq && min_cell <= max_cell); + cache.rs.head = min_head; + cache.rs.n = min_n; } if (kv_size > 0) { - for (uint32_t i = 0; i < n_tokens; i++) { - cache.kv.cells[cache.kv.head + i].pos = batch.pos[i]; + for (uint32_t s = 0; s < n_seqs; s++) { + for (uint32_t i = 0; i < n_seq_tokens; ++i) { + uint32_t k = s*n_seq_tokens + i; + cache.kv.cells[cache.kv.head + k].pos = batch.pos[k]; - for (int32_t j = 0; j < batch.n_seq_id[i]; j++) { - cache.kv.cells[cache.kv.head + i].seq_id.insert(batch.seq_id[i][j]); + for (int32_t j = 0; j < batch.n_seq_id[s]; j++) { + cache.kv.cells[cache.kv.head + k].seq_id.insert(batch.seq_id[s][j]); + } } } @@ -8492,16 +8664,15 @@ static struct ggml_tensor * llm_build_mamba( struct ggml_context * ctx, const llama_model & model, const llama_hparams & hparams, + const llama_ubatch & batch, const llama_rs_cache & rs, struct ggml_cgraph * graph, struct ggml_tensor * cur, struct ggml_tensor * state_copy, struct ggml_tensor * state_mask, - struct ggml_tensor * state_seq, struct ggml_tensor * w_dt_norm, struct ggml_tensor * w_b_norm, struct ggml_tensor * w_c_norm, - int32_t n_tokens, int32_t rs_head, int32_t n_rs, const llm_build_cb & cb, @@ -8510,14 +8681,23 @@ static struct ggml_tensor * llm_build_mamba( const int64_t d_inner = hparams.ssm_d_inner; const int64_t d_state = hparams.ssm_d_state; const int64_t dt_rank = hparams.ssm_dt_rank; + const int64_t n_seqs = batch.n_seqs; + + const int64_t n_seq_tokens = batch.n_seq_tokens; + + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(batch.equal_seqs); + GGML_ASSERT(batch.n_tokens == n_seq_tokens * n_seqs); - struct ggml_tensor * conv_states = ggml_reshape_2d(ctx, rs.r_l[il], hparams.n_embd_r(il), rs.size); - struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx, rs.s_l[il], hparams.n_embd_s(il), rs.size); + struct ggml_tensor * conv_states_all = rs.r_l[il]; + struct ggml_tensor * ssm_states_all = rs.s_l[il]; + + struct ggml_tensor * conv_states = ggml_reshape_2d(ctx, conv_states_all, hparams.n_embd_r(il), rs.size); + struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx, ssm_states_all, hparams.n_embd_s(il), rs.size); // copy states { - // TODO: use some sort of read-only head and n to pass smaller tensors to ggml_get_rows - // NOTE: assuming the copy destinations are ALL contained in the current batch + // NOTE: assuming the copy destinations are ALL contained between rs_head and rs_head + n_rs // this shrinks the tensors's ne[1] to n_rs conv_states = ggml_get_rows(ctx, conv_states, state_copy); ssm_states = ggml_get_rows(ctx, 
ssm_states, state_copy); @@ -8532,17 +8712,24 @@ static struct ggml_tensor * llm_build_mamba( conv_states = ggml_reshape_3d(ctx, conv_states, d_conv - 1, d_inner, n_rs); ssm_states = ggml_reshape_3d(ctx, ssm_states, d_state, d_inner, n_rs); - // {n_embd, 2*d_inner} * {n_embd, n_tokens} => {2*d_inner, n_tokens} + struct ggml_tensor * conv = ggml_view_3d(ctx, conv_states, d_conv - 1, d_inner, n_seqs, conv_states->nb[1], conv_states->nb[2], 0); + struct ggml_tensor * ssm = ggml_view_3d(ctx, ssm_states, d_state, d_inner, n_seqs, ssm_states->nb[1], ssm_states->nb[2], 0); + + // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} + cur = ggml_reshape_3d(ctx, cur, cur->ne[0], n_seq_tokens, n_seqs); + + // {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs} struct ggml_tensor * xz = ggml_mul_mat(ctx, model.layers[il].ssm_in, cur); // split the above in two - // => {d_inner, n_tokens} - struct ggml_tensor * x = ggml_view_2d(ctx, xz, d_inner, xz->ne[1], xz->nb[1], 0); - struct ggml_tensor * z = ggml_view_2d(ctx, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner); + // => {d_inner, n_seq_tokens, n_seqs} + struct ggml_tensor * x = ggml_view_3d(ctx, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], 0); + struct ggml_tensor * z = ggml_view_3d(ctx, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], d_inner*ggml_element_size(xz)); // conv { - // Custom operator which is needed only to ease simultaneous sequence processing. - // For a single sequence, the equivalent is to concatenate the columns of conv_states and x, + // Custom operator, which is needed because self-overlapping views aren't yet well supported by ggml. + // And also because this uses much less memory for large batches (4 times less when d_conv is 4). + // The equivalent is to concatenate the columns of conv_states and x, // then make a self-overlapping view of that over d_conv columns at each stride in the 3rd dimension, // then element-wise multiply that with the conv1d weigth, // then sum the elements of each row, @@ -8551,17 +8738,17 @@ static struct ggml_tensor * llm_build_mamba( // and then you're left with the resulting x tensor. // The new conv_states is the last (d_conv - 1) columns // of the last 3rd dimensional "layer" of the self-overlapping view. - // For simultaneous sequences, it's more complicated. - struct ggml_tensor * x_conv = ggml_ssm_conv(ctx, conv_states, x, model.layers[il].ssm_conv1d, state_seq); + // For simultaneous sequences, all sequences need to have the same length. 
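+            // Illustrative example for a single channel with d_conv == 4, assuming
+            // conv_state = [s0, s1, s2], x = [x0, x1, ...] and weights [w0, w1, w2, w3]
+            // (names made up for the example):
+            //   cat    = [s0, s1, s2, x0, x1, ...]
+            //   out[t] = w0*cat[t] + w1*cat[t+1] + w2*cat[t+2] + w3*cat[t+3]
+            // and the new conv_state is the last (d_conv - 1) columns of cat.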
+ x = ggml_ssm_conv(ctx, conv, x, model.layers[il].ssm_conv1d); - // store last (d_conv - 1) columns of the conv_state part of x_conv back into the KV cache - ggml_build_forward_expand(graph, - ggml_cpy(ctx, - ggml_view_2d(ctx, x_conv, d_conv - 1, d_inner*n_rs, d_conv*ggml_element_size(x_conv), (1+d_inner*n_tokens)*ggml_element_size(x_conv)), - ggml_view_1d(ctx, rs.r_l[il], (d_conv - 1)*(d_inner)*(n_rs), rs_head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv)))); + // ensure conv is updated before copying into the recurrent state cache + ggml_build_forward_expand(graph, x); - // extract x from x_conv - x = ggml_view_2d(ctx, x_conv, d_inner, n_tokens, d_inner*ggml_element_size(x_conv), 0); + ggml_build_forward_expand(graph, + ggml_cpy(ctx, conv_states, + ggml_view_1d(ctx, conv_states_all, + (d_conv - 1)*(d_inner)*(n_rs), + rs_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all)))); // bias x = ggml_add(ctx, x, model.layers[il].ssm_conv1d_b); @@ -8571,45 +8758,47 @@ static struct ggml_tensor * llm_build_mamba( // ssm { - // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tokens} => {dt_rank + 2*d_state, n_tokens} + // {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs} struct ggml_tensor * x_db = ggml_mul_mat(ctx, model.layers[il].ssm_x, x); // split - struct ggml_tensor * dt = ggml_view_2d(ctx, x_db, dt_rank, n_tokens, x_db->nb[1], 0); - struct ggml_tensor * B = ggml_view_2d(ctx, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*dt_rank); - struct ggml_tensor * C = ggml_view_2d(ctx, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state)); + struct ggml_tensor * dt = ggml_view_3d(ctx, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0); + struct ggml_tensor * B = ggml_view_3d(ctx, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank); + struct ggml_tensor * C = ggml_view_3d(ctx, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state)); if (w_dt_norm) { dt = llm_build_norm(ctx, dt, hparams, w_dt_norm, NULL, LLM_NORM_RMS, cb, il); } if (w_b_norm) { B = llm_build_norm(ctx, B, hparams, w_b_norm, NULL, LLM_NORM_RMS, cb, il); } if (w_c_norm) { C = llm_build_norm(ctx, C, hparams, w_b_norm, NULL, LLM_NORM_RMS, cb, il); } - // {dt_rank, d_inner} * {dt_rank, n_tokens} => {d_inner, n_tokens} + // {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs} dt = ggml_mul_mat(ctx, model.layers[il].ssm_dt, dt); dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b); // Custom operator to optimize the parallel associative scan // as described in the Annex D of the Mamba paper. - // => {d_inner, n_tokens} and {d_state, d_inner, n_rs} combined, - // because only a single tensor can be returned. 
- struct ggml_tensor * y_ssm_states = ggml_ssm_scan(ctx, ssm_states, x, dt, model.layers[il].ssm_a, B, C, state_seq); + // => {d_inner, n_seq_tokens, n_seqs} + struct ggml_tensor * y = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C); - // store last states (the second part of y_ssm_states) - ggml_build_forward_expand(graph, - ggml_cpy(ctx, - ggml_view_1d(ctx, y_ssm_states, d_state*d_inner*n_rs, d_inner*n_tokens*ggml_element_size(y_ssm_states)), - ggml_view_1d(ctx, rs.s_l[il], d_state*d_inner*n_rs, rs_head*d_state*d_inner*ggml_element_size(ssm_states)))); + // The ssm scan also changes the state, ensure it's done before copying to the recurrent state cache + ggml_build_forward_expand(graph, y); - struct ggml_tensor * y = ggml_view_2d(ctx, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0); + // store last states + ggml_build_forward_expand(graph, + ggml_cpy(ctx, ssm_states, + ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_rs, rs_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); - // TODO: skip computing output for unused tokens + // TODO: skip computing output earlier for unused tokens - // {d_inner, n_tokens} * {d_inner} => {d_inner, n_tokens} + // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs} y = ggml_add(ctx, y, ggml_mul(ctx, x, model.layers[il].ssm_d)); y = ggml_mul(ctx, y, ggml_silu(ctx, z)); - // {d_inner, n_embd} * {d_inner, n_tokens} => {n_embd, n_tokens} + // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} cur = ggml_mul_mat(ctx, model.layers[il].ssm_out, y); } + // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} + cur = ggml_reshape_2d(ctx, cur, cur->ne[0], n_seq_tokens * n_seqs); + return cur; } @@ -8642,6 +8831,8 @@ struct llm_build_context { const float norm_eps; const float norm_rms_eps; + const int32_t n_seqs; + const int32_t n_seq_tokens; const int32_t n_tokens; const int32_t n_kv; // size of KV cache to consider (n_kv <= kv_self.size) const int32_t n_rs; @@ -8692,6 +8883,8 @@ struct llm_build_context { beta_slow (cparams.yarn_beta_slow), norm_eps (hparams.f_norm_eps), norm_rms_eps (hparams.f_norm_rms_eps), + n_seqs (batch.n_seqs), + n_seq_tokens (batch.n_seq_tokens), n_tokens (batch.n_tokens), n_kv (worst_case ? kv_self.size : kv_self.n), n_rs (worst_case ? 
rs_self.size : rs_self.n), @@ -8726,7 +8919,6 @@ struct llm_build_context { lctx.inp_cls = nullptr; lctx.inp_s_copy = nullptr; lctx.inp_s_mask = nullptr; - lctx.inp_s_seq = nullptr; } void free() { @@ -8898,13 +9090,6 @@ struct llm_build_context { return lctx.inp_s_mask; } - struct ggml_tensor * build_inp_s_seq() { - lctx.inp_s_seq = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); - cb(lctx.inp_s_seq, "inp_s_seq", -1); - ggml_set_input(lctx.inp_s_seq); - return lctx.inp_s_seq; - } - struct ggml_cgraph * build_llama() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); @@ -12017,7 +12202,6 @@ struct llm_build_context { struct ggml_tensor * state_copy = build_inp_s_copy(); struct ggml_tensor * state_mask = build_inp_s_mask(); - struct ggml_tensor * state_seq = build_inp_s_seq(); for (int il = 0; il < n_layer; ++il) { // norm @@ -12026,9 +12210,9 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "attn_norm", il); - cur = llm_build_mamba(ctx0, model, hparams, rs_self, gf, cur, - state_copy, state_mask, state_seq, NULL, NULL, NULL, - n_tokens, rs_head, n_rs, cb, il); + cur = llm_build_mamba(ctx0, model, hparams, batch, rs_self, gf, cur, + state_copy, state_mask, NULL, NULL, NULL, + rs_head, n_rs, cb, il); if (il == n_layer - 1) { // skip computing output for unused tokens @@ -12074,7 +12258,6 @@ struct llm_build_context { struct ggml_tensor * state_copy = build_inp_s_copy(); struct ggml_tensor * state_mask = build_inp_s_mask(); - struct ggml_tensor * state_seq = build_inp_s_seq(); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); @@ -12089,10 +12272,9 @@ struct llm_build_context { if (n_head_kv == 0) { // Mamba - cur = llm_build_mamba(ctx0, model, hparams, rs_self, gf, cur, - state_copy, state_mask, state_seq, + cur = llm_build_mamba(ctx0, model, hparams, batch, rs_self, gf, cur, state_copy, state_mask, model.layers[il].ssm_dt_norm, model.layers[il].ssm_b_norm, model.layers[il].ssm_c_norm, - n_tokens, rs_head, n_rs, cb, il); + rs_head, n_rs, cb, il); } else { // Attention @@ -12152,6 +12334,7 @@ struct llm_build_context { model.layers[il].ffn_down_exps, n_expert, n_expert_used, LLM_FFN_SILU, false, + false, 0.0, cb, il); cb(cur, "ffn_moe_out", il); } @@ -13234,8 +13417,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { if (lctx.inp_KQ_mask) { // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache. if (cparams.causal_attn) { - const int64_t n_kv = kv_self.n; - const int64_t n_tokens = batch.n_tokens; + const int64_t n_kv = kv_self.n; + const int64_t n_tokens = batch.n_tokens; + const int64_t n_seq_tokens = batch.n_seq_tokens; + const int64_t n_seqs = batch.n_seqs; GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer)); @@ -13245,22 +13430,25 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { // of the correct sequence for each token of the batch. // It's assumed that if a token in the batch has multiple sequences, they are equivalent. 
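+                // The batch groups tokens by sequence (n_seqs runs of n_seq_tokens tokens each),
+                // so the mask entry for sequence s, token j of that sequence and KV cell i is:
+                //   data[s*(n_kv*n_seq_tokens) + j*n_kv + i]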
for (int h = 0; h < 1; ++h) { - for (int j = 0; j < n_tokens; ++j) { - const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j][0]; - - for (int i = 0; i < n_kv; ++i) { - float f; - if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { - f = -INFINITY; - } else { - if (hparams.use_alibi) { - f = -fabs(kv_self.cells[i].pos - pos); + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = batch.seq_id[s][0]; + + for (int j = 0; j < n_seq_tokens; ++j) { + const llama_pos pos = batch.pos[s*n_seq_tokens + j]; + + for (int i = 0; i < n_kv; ++i) { + float f; + if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { + f = -INFINITY; } else { - f = 0.0f; + if (hparams.use_alibi) { + f = -fabs(kv_self.cells[i].pos - pos); + } else { + f = 0.0f; + } } + data[h*(n_kv*n_seq_tokens*n_seqs) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; } - data[h*(n_kv*n_tokens) + j*n_kv + i] = f; } } @@ -13271,8 +13459,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { } } } else { + const int64_t n_tokens = batch.n_tokens; + const int64_t n_seq_tokens = batch.n_seq_tokens; + const int64_t n_seqs = batch.n_seqs; // when using kv cache, the mask needs to match the kv cache size - const int64_t n_tokens = batch.n_tokens; const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens; GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer)); @@ -13280,27 +13470,35 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { float * data = (float *) lctx.inp_KQ_mask->data; for (int h = 0; h < 1; ++h) { - for (int j = 0; j < n_tokens; ++j) { - const llama_seq_id seq_id = batch.seq_id[j][0]; - - for (int i = 0; i < n_tokens; ++i) { - float f = -INFINITY; - for (int s = 0; s < batch.n_seq_id[i]; ++s) { - if (batch.seq_id[i][s] == seq_id) { - if (hparams.use_alibi) { - f = -fabs(batch.pos[i] - batch.pos[j]); - } else { - f = 0.0f; + for (int s1 = 0; s1 < n_seqs; ++s1) { + const llama_seq_id seq_id = batch.seq_id[s1][0]; + + for (int j = 0; j < n_seq_tokens; ++j) { + const int32_t tj = s1*n_seq_tokens + j; + + for (int s0 = 0; s0 < n_seqs; ++s0) { + for (int i = 0; i < n_seq_tokens; ++i) { + const int32_t ti = s0*n_seq_tokens + i; + float f = -INFINITY; + + for (int s = 0; s < batch.n_seq_id[s0]; ++s) { + if (batch.seq_id[s0][s] == seq_id) { + if (hparams.use_alibi) { + f = -fabs(batch.pos[ti] - batch.pos[tj]); + } else { + f = 0.0f; + } + break; + } } - break; + + data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f; } } - data[h*(n_tokens*n_tokens) + j*n_stride + i] = f; - } - - for (int i = n_tokens; i < n_stride; ++i) { - data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY; + for (int i = n_tokens; i < n_stride; ++i) { + data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY; + } } } } @@ -13308,7 +13506,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { } if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) { - const int64_t n_tokens = batch.n_tokens; + const int64_t n_tokens = batch.n_tokens; + const int64_t n_seq_tokens = batch.n_seq_tokens; + const int64_t n_seqs = batch.n_seqs; GGML_ASSERT(lctx.inp_mean); GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer)); @@ -13317,12 +13517,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { memset(lctx.inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(lctx.inp_mean)); std::vector sum(n_tokens, 0); - for (int i = 0; i < n_tokens; ++i) { - const 
llama_seq_id seq_id = batch.seq_id[i][0]; + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = batch.seq_id[s][0]; + + // TODO: adapt limits to n_seqs when batch.equal_seqs is true GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN"); - sum[seq_id] += 1; + sum[seq_id] += batch.n_seq_tokens; } std::vector div(n_tokens, 0.0f); @@ -13333,14 +13535,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { } } - for (int i = 0; i < n_tokens; ++i) { - const llama_seq_id seq_id = batch.seq_id[i][0]; - data[seq_id*n_tokens + i] = div[seq_id]; + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = batch.seq_id[s][0]; + + for (int i = 0; i < n_seq_tokens; ++i) { + data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id]; + } } } if (cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) { - const int64_t n_tokens = batch.n_tokens; + const int64_t n_tokens = batch.n_tokens; + const int64_t n_seq_tokens = batch.n_seq_tokens; + const int64_t n_seqs = batch.n_seqs; GGML_ASSERT(lctx.inp_cls); GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer)); @@ -13348,14 +13555,18 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { uint32_t * data = (uint32_t *) lctx.inp_cls->data; memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls)); - for (int i = 0; i < n_tokens; ++i) { - const llama_seq_id seq_id = batch.seq_id[i][0]; - const llama_pos pos = batch.pos[i]; + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = batch.seq_id[s][0]; + // TODO: adapt limits to n_seqs when batch.equal_seqs is true GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS"); - if (pos == 0) { - data[seq_id] = i; + for (int i = 0; i < n_seq_tokens; ++i) { + const llama_pos pos = batch.pos[s*n_seq_tokens + i]; + + if (pos == 0) { + data[seq_id] = s*n_seq_tokens + i; + } } } } @@ -13372,7 +13583,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { uint32_t cell_id = i + rs_self.head; llama_rs_cell & rs_cell = lctx.cache.rs.cells[cell_id]; - data[i] = (float) rs_cell.src >= 0; + data[i] = (float) (rs_cell.src >= 0); // only clear once if (rs_cell.src < 0) { @@ -13404,29 +13615,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { } } } - - // For Mamba (and other recurrent architectures), - // update the correct state(s)/sequence(s) for each token of the batch. - // Each row contains relative cell ids of the sequences for the associated token. - // Like with the KQ_mask, if a token in the batch has multiple sequences, - // they are assumed to be equivalent (not here, but in ggml_ssm_scan and ggml_ssm_conv). 
- if (lctx.inp_s_seq) { - const int64_t n_tokens = batch.n_tokens; - - GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_seq->buffer)); - int32_t * data = (int32_t *) lctx.inp_s_seq->data; - - for (int i = 0; i < n_tokens; ++i) { - const llama_seq_id seq_id = batch.seq_id[i][0]; - GGML_ASSERT((uint32_t) seq_id < rs_self.seq_tails.size()); - const auto & seq = rs_self.seq_tails[seq_id]; - // ensure the relative cell id will be positive but not too big - GGML_ASSERT((uint32_t) seq.tail >= rs_self.head); - GGML_ASSERT((uint32_t) seq.tail < rs_self.head + rs_self.n); - - data[i] = seq.tail - rs_self.head; - } - } } } @@ -13598,7 +13786,7 @@ static int llama_decode_internal( } else { GGML_ASSERT(u_batch.output); for (uint32_t i = 0; i < n_tokens; i++) { - n_outputs_new += u_batch.output[i] != 0; + n_outputs_new += (int32_t) (u_batch.output[i] != 0); } } @@ -14077,10 +14265,10 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) { if (need_reserve) { // TODO: extract to a function // build worst-case graph - int n_seqs = 1; // TODO: worst-case number of sequences - int n_tokens = (int)std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch); + uint32_t n_seqs = 1; // TODO: worst-case number of sequences + uint32_t n_tokens = std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch); llama_token token = llama_token_bos(&lctx.model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph - llama_ubatch ubatch = { true, n_tokens, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true); // initialize scheduler with the worst-case graph @@ -18026,10 +18214,10 @@ struct llama_context * llama_new_context_with_model( } // build worst-case graph - int n_seqs = 1; // TODO: worst-case number of sequences - int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_ubatch); + uint32_t n_seqs = 1; // TODO: worst-case number of sequences + uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph - llama_ubatch ubatch = { true, n_tokens, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; ggml_cgraph * gf = llama_build_graph(*ctx, ubatch, true); // initialize scheduler with the worst-case graph @@ -19347,6 +19535,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, if (cell_count) { llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false); batch.n_tokens = cell_count; + batch.n_seq_tokens = cell_count; batch.n_seqs = 1; for (uint32_t i = 0; i < cell_count; ++i) { llama_pos pos; @@ -19354,9 +19543,9 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, inp += sizeof(pos); batch.pos[i] = pos; - batch.n_seq_id[i] = 1; - batch.seq_id[i][0] = dest_seq_id; } + batch.n_seq_id[0] = 1; + batch.seq_id[0] = &dest_seq_id; if (!llama_cache_find_slot(cache, batch)) { LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); return 0; @@ -19680,9 +19869,54 @@ void llama_synchronize(struct llama_context * ctx) { ctx->t_compute_start_us = 0; } +// make the outputs have the same order they 
had in the user-provided batch +static void llama_reorder_outputs(struct llama_context * ctx) { + std::vector & out_ids = ctx->sbatch.out_ids; + if (!out_ids.empty()) { + std::vector logits_tmp; + std::vector embd_tmp; + uint32_t n_vocab = ctx->model.hparams.n_vocab; + uint32_t n_embd = ctx->model.hparams.n_embd; + int32_t n_outputs = ctx->n_outputs; + GGML_ASSERT((size_t) n_outputs == out_ids.size()); + // insertion sort (from https://en.wikipedia.org/wiki/Insertion_sort, but using memmove) + for (int32_t i = 1; i < n_outputs; ++i) { + int32_t j = i; + size_t out_id_tmp = out_ids[i]; + while (j > 0 && out_ids[j - 1] > out_id_tmp) { j -= 1; } + if (i - j == 0) { continue; } + memmove(out_ids.data() + j + 1, out_ids.data() + j, (i - j)*sizeof(out_ids[0])); + out_ids[j] = out_id_tmp; + if (ctx->logits_size > 0) { + // only allocate once something needs to be moved + if (logits_tmp.empty()) { logits_tmp.resize(n_vocab); } + memcpy(logits_tmp.data(), ctx->logits + i*n_vocab, n_vocab*sizeof(float)); + memmove(ctx->logits + (j + 1)*n_vocab, ctx->logits + j*n_vocab, (i - j)*n_vocab*sizeof(float)); + memcpy(ctx->logits + j*n_vocab, logits_tmp.data(), n_vocab*sizeof(float)); + } + if (ctx->embd_size > 0) { + // only allocate once something needs to be moved + if (embd_tmp.empty()) { embd_tmp.resize(n_embd); } + memcpy(embd_tmp.data(), ctx->embd + i*n_embd, n_embd*sizeof(float)); + memmove(ctx->embd + (j + 1)*n_embd, ctx->embd + j*n_embd, (i - j)*n_embd*sizeof(float)); + memcpy(ctx->embd + j*n_embd, embd_tmp.data(), n_embd*sizeof(float)); + } + } + std::fill(ctx->output_ids.begin(), ctx->output_ids.end(), -1); + for (int32_t i = 0; i < n_outputs; ++i) { + ctx->output_ids[out_ids[i]] = i; + } + out_ids.clear(); + } +} + float * llama_get_logits(struct llama_context * ctx) { llama_synchronize(ctx); + // reorder logits for backward compatibility + // TODO: maybe deprecate this + llama_reorder_outputs(ctx); + return ctx->logits; } @@ -19727,6 +19961,10 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) { float * llama_get_embeddings(struct llama_context * ctx) { llama_synchronize(ctx); + // reorder embeddings for backward compatibility + // TODO: maybe deprecate this + llama_reorder_outputs(ctx); + return ctx->embd; } From 72eea49224e5b90263de08f8cddc6010353841eb Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sat, 1 Jun 2024 12:24:19 -0400 Subject: [PATCH 15/28] llama : fix batch split output count for embeddings --- llama.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index 6878bc8936046..7c6afa7d1fbe8 100644 --- a/llama.cpp +++ b/llama.cpp @@ -13730,7 +13730,9 @@ static int llama_decode_internal( n_outputs = 1; } - lctx.sbatch.from_batch(batch_all, n_embd, /* legacy_split */ rs_self.size == 0, lctx.logits_all); + lctx.sbatch.from_batch(batch_all, n_embd, + /* legacy_split */ rs_self.size == 0, + /* logits_all */ n_outputs == n_tokens_all); // reserve output buffer if (llama_output_reserve(lctx, n_outputs) < n_outputs) { @@ -13740,6 +13742,7 @@ static int llama_decode_internal( while (lctx.sbatch.n_tokens > 0) { // TODO: deprecate slice splits in favor of equal splits + // For now, only use equal splits for recurrent or hybrid model architectures llama_ubatch u_batch = (rs_self.size > 0) ? 
lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_slice(n_ubatch); const uint32_t n_tokens = u_batch.n_tokens; From 18d1c140471da9443db9e0b67f61ccf540e113c0 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sat, 1 Jun 2024 15:01:34 -0400 Subject: [PATCH 16/28] llama : minimize swaps when reordering logits This reduces overhead when running hellaswag on thousands of sequences with very small 100k params Mamba models. --- llama.cpp | 50 ++++++++++++++++++++++++++++++-------------------- 1 file changed, 30 insertions(+), 20 deletions(-) diff --git a/llama.cpp b/llama.cpp index 7c6afa7d1fbe8..d44dfe7b20933 100644 --- a/llama.cpp +++ b/llama.cpp @@ -19828,33 +19828,43 @@ void llama_synchronize(struct llama_context * ctx) { static void llama_reorder_outputs(struct llama_context * ctx) { std::vector & out_ids = ctx->sbatch.out_ids; if (!out_ids.empty()) { - std::vector logits_tmp; - std::vector embd_tmp; uint32_t n_vocab = ctx->model.hparams.n_vocab; uint32_t n_embd = ctx->model.hparams.n_embd; int32_t n_outputs = ctx->n_outputs; GGML_ASSERT((size_t) n_outputs == out_ids.size()); - // insertion sort (from https://en.wikipedia.org/wiki/Insertion_sort, but using memmove) - for (int32_t i = 1; i < n_outputs; ++i) { - int32_t j = i; - size_t out_id_tmp = out_ids[i]; - while (j > 0 && out_ids[j - 1] > out_id_tmp) { j -= 1; } - if (i - j == 0) { continue; } - memmove(out_ids.data() + j + 1, out_ids.data() + j, (i - j)*sizeof(out_ids[0])); - out_ids[j] = out_id_tmp; + { + bool is_already_sorted = true; + for (int32_t i = 0; i < n_outputs - 1; ++i) { + if (out_ids[i] > out_ids[i + 1]) { + is_already_sorted = false; + break; + } + } + if (is_already_sorted) { + out_ids.clear(); + return; + } + } + // TODO: is there something more efficient which also minimizes swaps? + // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort) + for (int32_t i = 0; i < n_outputs - 1; ++i) { + int32_t j_min = i; + for (int32_t j = i + 1; j < n_outputs; ++j) { + if (out_ids[j] < out_ids[j_min]) { + j_min = j; + } + } + if (j_min == i) { continue; } + std::swap(out_ids[i], out_ids[j_min]); if (ctx->logits_size > 0) { - // only allocate once something needs to be moved - if (logits_tmp.empty()) { logits_tmp.resize(n_vocab); } - memcpy(logits_tmp.data(), ctx->logits + i*n_vocab, n_vocab*sizeof(float)); - memmove(ctx->logits + (j + 1)*n_vocab, ctx->logits + j*n_vocab, (i - j)*n_vocab*sizeof(float)); - memcpy(ctx->logits + j*n_vocab, logits_tmp.data(), n_vocab*sizeof(float)); + for (uint32_t k = 0; k < n_vocab; k++) { + std::swap(ctx->logits[i*n_vocab + k], ctx->logits[j_min*n_vocab + k]); + } } if (ctx->embd_size > 0) { - // only allocate once something needs to be moved - if (embd_tmp.empty()) { embd_tmp.resize(n_embd); } - memcpy(embd_tmp.data(), ctx->embd + i*n_embd, n_embd*sizeof(float)); - memmove(ctx->embd + (j + 1)*n_embd, ctx->embd + j*n_embd, (i - j)*n_embd*sizeof(float)); - memcpy(ctx->embd + j*n_embd, embd_tmp.data(), n_embd*sizeof(float)); + for (uint32_t k = 0; k < n_embd; k++) { + std::swap(ctx->embd[i*n_embd + k], ctx->embd[j_min*n_embd + k]); + } } } std::fill(ctx->output_ids.begin(), ctx->output_ids.end(), -1); From 61200ef29fc0e76f264ada583b77e9228120779f Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sat, 1 Jun 2024 16:41:22 -0400 Subject: [PATCH 17/28] llama : fix edge case finding batch seq_id of split recurrent cell This otherwise was a problem when running the HellaSwag benchmark with small batch sizes, making it crash. 
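In other words, the first seq_node of a shared (split) cell is not necessarily a tail, so its seq_id may not appear in the current batch at all. A minimal standalone sketch of the corrected lookup, with simplified stand-ins for llama_rs_seq_node and llama_rs_cell (illustrative names and values, not the exact types from llama.cpp):

    #include <cassert>
    #include <vector>

    // simplified stand-ins for llama_rs_seq_node / llama_rs_cell (illustration only)
    struct seq_node { int seq_id; int next_cell; bool is_tail() const { return next_cell < 0; } };
    struct rs_cell  { std::vector<seq_node> seq_nodes; };

    // the seq_id used to match against the batch must come from the first *tail* node,
    // not from seq_nodes[0], which may still point to a next cell
    static int first_tail_seq_id(const rs_cell & cell) {
        for (const seq_node & node : cell.seq_nodes) {
            if (node.is_tail()) { return node.seq_id; }
        }
        return -1; // no tail in this cell
    }

    int main() {
        // shared prompt cell: seq 0 already has a next cell (not a tail), seq 1 is still the tail
        rs_cell cell = { { {0, 7}, {1, -1} } };
        assert(first_tail_seq_id(cell) == 1);
        return 0;
    }

The fix in the hunk below does the same thing directly on cell.seq_nodes before matching against batch.seq_id.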
--- llama.cpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/llama.cpp b/llama.cpp index d44dfe7b20933..62d66c2bc2831 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3879,11 +3879,17 @@ static bool llama_cache_find_slot( if (cell.tail_rc == 0) { cache.rs.clear_cell(cell); } else { - // TODO: does this always work correctly - // even if there are more than one seq_node in this cell? + // Find the seq_id of the first tail of this cell + llama_seq_id seq_id = -1; + for (llama_rs_seq_node & seq_node : cell.seq_nodes) { + if (seq_node.is_tail()) { + seq_id = seq_node.seq_id; + break; + } + } + GGML_ASSERT(seq_id != -1); // Which seq_id of the batch is it? - llama_seq_id seq_id = cell.seq_nodes[0].seq_id; int32_t nth_seq_id = -1; for (int32_t s = 0; (uint32_t) s < n_seqs; ++s) { if (seq_id == batch.seq_id[s][0]) { From eb589d5e3664b784aef5326aa14dd21889eb1948 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sat, 1 Jun 2024 23:05:13 -0400 Subject: [PATCH 18/28] llama : avoid copies for simple batch splits --- llama.cpp | 81 +++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 52 insertions(+), 29 deletions(-) diff --git a/llama.cpp b/llama.cpp index 62d66c2bc2831..ce96d7b5503d2 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3143,19 +3143,29 @@ struct llama_sbatch { GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs); // NOTE: loops are separated for cache-friendliness if (batch->token) { - for (size_t i = 0; i < length; ++i) { - ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]]; + if (ubatch.equal_seqs) { + for (size_t i = 0; i < length; ++i) { + ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]]; + } + } else { + // simple split + ubatch.token = batch->token + seq.offset; } } else { ubatch.token = nullptr; } if (batch->embd) { - for (size_t i = 0; i < length; ++i) { - memcpy( - ubatch.embd + n_embd * (ubatch.n_tokens + i), - batch->embd + n_embd * ids[seq.offset + i], - n_embd * sizeof(float) - ); + if (ubatch.equal_seqs) { + for (size_t i = 0; i < length; ++i) { + memcpy( + ubatch.embd + n_embd * (ubatch.n_tokens + i), + batch->embd + n_embd * ids[seq.offset + i], + n_embd * sizeof(float) + ); + } + } else { + // simple split + ubatch.embd = batch->embd + seq.offset; } } else { ubatch.embd = nullptr; @@ -3163,8 +3173,13 @@ struct llama_sbatch { // from here on, the else branches are deprecated; // they are helpers for smoother batch API transition if (batch->pos) { - for (size_t i = 0; i < length; ++i) { - ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]]; + if (ubatch.equal_seqs) { + for (size_t i = 0; i < length; ++i) { + ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]]; + } + } else { + // simple split + ubatch.pos = batch->pos + seq.offset; } } else { for (size_t i = 0; i < length; ++i) { @@ -3172,7 +3187,7 @@ struct llama_sbatch { ubatch.pos[ubatch.n_tokens + i] = batch->all_pos_0 + (bi * batch->all_pos_1); } } - if (seq.n_seq_id > 0) { + if (ubatch.equal_seqs) { ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id; if (seq.seq_id) { ubatch.seq_id[ubatch.n_seqs] = seq.seq_id; @@ -3181,9 +3196,10 @@ struct llama_sbatch { ubatch.seq_id[ubatch.n_seqs] = &seq.all_seq_id; } } else { + // simple split if (batch->n_seq_id) { for (size_t i = 0; i < length; ++i) { - ubatch.n_seq_id[ubatch.n_seqs + i] = batch->n_seq_id[ids[seq.offset + i]]; + ubatch.n_seq_id = batch->n_seq_id + seq.offset; } } else { for (size_t i = 0; i < length; ++i) { @@ -3192,7 +3208,7 @@ struct llama_sbatch { 
} if (batch->seq_id) { for (size_t i = 0; i < length; ++i) { - ubatch.seq_id[ubatch.n_seqs + i] = batch->seq_id[ids[seq.offset + i]]; + ubatch.seq_id = batch->seq_id + seq.offset; } } else { for (size_t i = 0; i < length; ++i) { @@ -3201,11 +3217,19 @@ struct llama_sbatch { } } if (batch->logits) { - for (size_t i = 0; i < length; ++i) { - size_t id = ids[seq.offset + i]; - int8_t is_output = batch->logits[id]; - ubatch.output[ubatch.n_tokens + i] = is_output; - if (is_output) { out_ids.push_back(id); } + if (ubatch.equal_seqs) { + for (size_t i = 0; i < length; ++i) { + size_t id = ids[seq.offset + i]; + int8_t is_output = batch->logits[id]; + ubatch.output[ubatch.n_tokens + i] = is_output; + if (is_output) { out_ids.push_back(id); } + } + } else { + // simple split + ubatch.output = batch->logits + seq.offset; + for (size_t i = 0; i < length; ++i) { + if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); } + } } } else if (logits_all) { for (size_t i = 0; i < length; ++i) { @@ -3222,18 +3246,18 @@ struct llama_sbatch { } } if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) { - ubatch.n_seq_tokens = seq.n_seq_id > 0 ? length : 1; + ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1; } ubatch.n_tokens += length; - ubatch.n_seqs += seq.n_seq_id > 0 ? 1 : length; // virtual sequences for legacy splits + ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits seq.offset += length; seq.length -= length; n_tokens -= length; GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs); } - // legacy split, unknown number of sequences of unequal lengths - llama_ubatch split_slice(size_t n_ubatch) { + // simple split, unknown number of sequences of unequal lengths + llama_ubatch split_simple(size_t n_ubatch) { n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch; llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr); ubatch.equal_seqs = false; @@ -3241,7 +3265,6 @@ struct llama_sbatch { llama_sbatch_seq & s = seq[0]; size_t length = s.length < n_ubatch ? s.length : n_ubatch; GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits - // TODO: reduce copies add_seq_to_ubatch(ubatch, s, length); } return ubatch; @@ -3254,7 +3277,7 @@ struct llama_sbatch { if (!seq.empty()) { size_t length = 0; size_t n_tokens_in_ubatch = 0; - GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with legacy splits + GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits // smallest first, because it's easier to split this way; // starting from the end to pop in constant time. for (size_t i = seq.size(); i-- > 0;) { @@ -3282,13 +3305,13 @@ struct llama_sbatch { if (!seq.empty()) { llama_sbatch_seq & s = seq[seq.size() - 1]; size_t length = s.length < n_ubatch ? 
s.length : n_ubatch; - GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with legacy splits + GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits add_seq_to_ubatch(ubatch, s, length); } return ubatch; } - void from_batch(const llama_batch & batch, const size_t n_embd, const bool legacy_split = false, const bool logits_all = false) { + void from_batch(const llama_batch & batch, const size_t n_embd, const bool simple_split = false, const bool logits_all = false) { GGML_ASSERT(batch.n_tokens >= 0); this->batch = &batch; this->n_embd = n_embd; @@ -3302,7 +3325,7 @@ struct llama_sbatch { for (size_t i = 0; i < n_tokens; ++i) { ids[i] = i; } - if (legacy_split) { + if (simple_split) { seq.resize(1); llama_sbatch_seq & s = seq[0]; s.n_seq_id = 0; @@ -13737,7 +13760,7 @@ static int llama_decode_internal( } lctx.sbatch.from_batch(batch_all, n_embd, - /* legacy_split */ rs_self.size == 0, + /* simple_split */ rs_self.size == 0, /* logits_all */ n_outputs == n_tokens_all); // reserve output buffer @@ -13749,7 +13772,7 @@ static int llama_decode_internal( while (lctx.sbatch.n_tokens > 0) { // TODO: deprecate slice splits in favor of equal splits // For now, only use equal splits for recurrent or hybrid model architectures - llama_ubatch u_batch = (rs_self.size > 0) ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_slice(n_ubatch); + llama_ubatch u_batch = (rs_self.size > 0) ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_simple(n_ubatch); const uint32_t n_tokens = u_batch.n_tokens; // count the outputs in this u_batch From 8fb57ac0fbf21d09abd21f3c167ee2cec8bb7094 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sun, 2 Jun 2024 22:49:24 -0400 Subject: [PATCH 19/28] llama : use im2col and mul_mat to perform convolution for Mamba This removes the need for ggml_ssm_conv!!! But performance seems slighly worse on my system, especially for prompt processing. Maybe ggml_mul_mat isn't optimized for small row sizes? More performance testing is necessary until GGML_OP_SSM_CONV is removed. * ggml : make ggml_ssm_scan not modify its source tensors * llama : fix shared recurrent tail cell count for small ubatch sizes Otherwise it was impossible to run the 'parallel' example with '-ub 1' with a Mamba or Jamba model. --- ggml.c | 121 +++++++++++++++++++++--------------------------------- ggml.h | 3 +- llama.cpp | 83 +++++++++++++++++++++++++------------ 3 files changed, 103 insertions(+), 104 deletions(-) diff --git a/ggml.c b/ggml.c index 426501015bbe5..253b3fa416e93 100644 --- a/ggml.c +++ b/ggml.c @@ -7124,26 +7124,24 @@ struct ggml_tensor * ggml_flash_attn_back( struct ggml_tensor * ggml_ssm_conv( struct ggml_context * ctx, - struct ggml_tensor * s, - struct ggml_tensor * x, + struct ggml_tensor * sx, struct ggml_tensor * c) { - GGML_ASSERT(ggml_is_3d(s)); - GGML_ASSERT(ggml_is_3d(x)); + GGML_ASSERT(ggml_is_3d(sx)); GGML_ASSERT(ggml_is_matrix(c)); const int64_t d_conv = c->ne[0]; const int64_t d_inner = c->ne[1]; - const int64_t n_t = x->ne[1]; // tokens per sequence - const int64_t n_s = s->ne[2]; + const int64_t n_t = sx->ne[0] - d_conv + 1; // tokens per sequence + const int64_t n_s = sx->ne[2]; - GGML_ASSERT(s->ne[0] == d_conv - 1); - GGML_ASSERT(s->ne[1] == d_inner); - GGML_ASSERT(x->ne[0] == d_inner); - GGML_ASSERT(x->ne[2] == n_s); + // TODO: maybe support other strides than 1? 
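+    // sx is expected to hold the previous conv state followed by the new tokens
+    // (as built by the caller with ggml_concat), so:
+    //   sx: {d_conv - 1 + n_t, d_inner, n_s}
+    //   c : {d_conv,           d_inner}
+    // and each output token t is the per-row dot product of c with the window sx[t : t + d_conv]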
+ GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t); + GGML_ASSERT(sx->ne[1] == d_inner); + GGML_ASSERT(n_t >= 0); bool is_node = false; - if (s->grad || x->grad || c->grad) { + if (sx->grad || c->grad) { GGML_ASSERT(false); // TODO: implement is_node = true; } @@ -7152,9 +7150,8 @@ struct ggml_tensor * ggml_ssm_conv( result->op = GGML_OP_SSM_CONV; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = s; - result->src[1] = x; - result->src[2] = c; + result->src[0] = sx; + result->src[1] = c; return result; } @@ -7203,8 +7200,8 @@ struct ggml_tensor * ggml_ssm_scan( is_node = true; } - // y - struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, x->ne[0], x->ne[1], x->ne[2]); + // concatenated y + ssm_states + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s)); result->op = GGML_OP_SSM_SCAN; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; @@ -16252,22 +16249,21 @@ static void ggml_compute_forward_ssm_conv_f32( return; } - const struct ggml_tensor * src0 = dst->src[0]; // conv_state - const struct ggml_tensor * src1 = dst->src[1]; // x - const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight + const struct ggml_tensor * src0 = dst->src[0]; // conv_x + const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight const int ith = params->ith; const int nth = params->nth; - const int nc = src2->ne[0]; // d_conv + const int nc = src1->ne[0]; // d_conv + const int ncs = src0->ne[0]; // d_conv - 1 + n_t const int nr = src0->ne[1]; // d_inner - const int n_t = src1->ne[1]; // tokens per sequence - const int n_s = src0->ne[2]; // number of sequences in the batch + const int n_t = dst->ne[1]; // tokens per sequence + const int n_s = dst->ne[2]; // number of sequences in the batch - GGML_ASSERT(ggml_are_same_shape(src1, dst)); + GGML_ASSERT( dst->ne[0] == nr); GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src1->nb[0] == sizeof(float)); - GGML_ASSERT(src2->nb[0] == sizeof(float)); GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); // rows per thread @@ -16278,54 +16274,28 @@ static void ggml_compute_forward_ssm_conv_f32( const int ir1 = MIN(ir0 + dr, nr); const int ir = ir1 - ir0; - // TODO: maybe require src0 to have d_conv columns instead of (d_conv - 1)? - // This would avoid having to copy into an intermediate buffer, but the state would be bigger. 
- float * s = (float *) params->wdata + (nc*dr + CACHE_LINE_SIZE_F32) * ith; - for (int i3 = 0; i3 < n_s; ++i3) { - float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s} - - // copy the state into working memory - // can't use memcpy because (d_conv) != (d_conv - 1) - for (int i1 = 0; i1 < ir; ++i1) { - for (int i0 = 0; i0 < nc - 1; ++i0) { - s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)]; - } - } - for (int i2 = 0; i2 < n_t; ++i2) { - float * x = (float *) ((char *) dst->data + ir0*( dst->nb[0]) + i2*( dst->nb[1]) + i3*( dst->nb[2])); // {d_inner, n_t, n_s} - float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} - float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner} - - // shift state left - memmove(s, s + 1, (nc*ir - 1) * sizeof(float)); + // {d_conv - 1 + n_t, d_inner, n_seqs} + // sliding window + const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s} + const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner} + float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s} + // TODO: transpose the output for smaller strides for big batches? // d_inner - for (int i1 = 0; i1 < ir; ++i1) { - // insert x on the last column - s[(nc - 1) + i1*nc] = x0[i1]; - } - - // it seems a little faster when this is separate from the state shift for (int i1 = 0; i1 < ir; ++i1) { // rowwise dot product // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision float sumf = 0.0f; + + // d_conv for (int i0 = 0; i0 < nc; ++i0) { - int i = i0 + i1*nc; - sumf += s[i] * c[i]; + sumf += s[i0 + i1*ncs] * c[i0 + i1*nc]; } x[i1] = sumf; } } - - // copy the state out of it - for (int i1 = 0; i1 < ir; ++i1) { - for (int i0 = 0; i0 < nc - 1; ++i0) { - s0[i0 + i1*(nc - 1)] = s[1 + i0 + i1*nc]; - } - } } } @@ -16368,7 +16338,7 @@ static void ggml_compute_forward_ssm_scan_f32( const int64_t n_t = src1->ne[1]; // number of tokens per sequence const int64_t n_s = src0->ne[2]; // number of sequences in the batch - GGML_ASSERT(ggml_nelements(src1) == ggml_nelements(dst)); + GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src1->nb[0] == sizeof(float)); GGML_ASSERT(src2->nb[0] == sizeof(float)); @@ -16377,6 +16347,10 @@ static void ggml_compute_forward_ssm_scan_f32( GGML_ASSERT(src5->nb[0] == sizeof(float)); // required for the dot product between s and C GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); + // required for per-sequence offsets for states + GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float)); + // required to get correct offset for state destination (i.e. 
src1->nb[3]) + GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float)); // rows per thread const int dr = (nr + nth - 1)/nth; @@ -16388,13 +16362,17 @@ static void ggml_compute_forward_ssm_scan_f32( for (int i3 = 0; i3 < n_s; ++i3) { for (int i2 = 0; i2 < n_t; ++i2) { - float * y = (float *) ((char *) dst->data + ir0*( dst->nb[0]) + i2*( dst->nb[1]) + i3*( dst->nb[2])); // {d_inner, n_t, n_s} - float * s = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s} - float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} - float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s} - float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} - float * B = (float *) ((char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s} - float * C = (float *) ((char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s} + const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s} + const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} + const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s} + const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} + const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s} + const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s} + float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} + float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s} + + // use the output as the source for the next token-wise iterations + if (i2 > 0) { s0 = s; } // d_inner for (int i1 = 0; i1 < ir; ++i1) { @@ -16406,7 +16384,7 @@ static void ggml_compute_forward_ssm_scan_f32( for (int i0 = 0; i0 < nc; ++i0) { int i = i0 + i1*nc; // state = prev_state * dA + dB * x - float state = (s[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); + float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); // y = rowwise_dotprod(state, C) sumf += state * C[i0]; s[i] = state; @@ -19577,13 +19555,6 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 } } break; - case GGML_OP_SSM_CONV: - { - const int64_t d_conv = node->src[2]->ne[0]; - const int64_t d_inner = node->src[0]->ne[1]; - - cur += sizeof(float)*d_conv*(d_inner + n_tasks - 1); - } break; case GGML_OP_CROSS_ENTROPY_LOSS: { cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks); diff --git a/ggml.h b/ggml.h index 9df601e2cd826..c772febf0aafa 100644 --- a/ggml.h +++ b/ggml.h @@ -1803,8 +1803,7 @@ extern "C" { GGML_API struct ggml_tensor * ggml_ssm_conv( struct ggml_context * ctx, - struct ggml_tensor * s, - struct ggml_tensor * x, + struct ggml_tensor * sx, struct ggml_tensor * c); GGML_API struct ggml_tensor * ggml_ssm_scan( diff --git a/llama.cpp b/llama.cpp index ce96d7b5503d2..ecdcf3a4e7096 100644 
--- a/llama.cpp +++ b/llama.cpp @@ -2827,11 +2827,13 @@ struct llama_rs_cache { n_shared_tail_cells += 1; n_seqs -= 1; } - } else if (rs_cell.is_empty()) { - // from shared to unique - n_seqs += 1; - if (prev_cell.tail_rc == 1) { - // it was the last tail of the previous cell + } else { + if (rs_cell.is_empty()) { + // from shared to unique + n_seqs += 1; + } + if (prev_cell.tail_rc == 1 && rs_cell.seq_nodes.size() == rs_cell.tail_rc) { + // from last shared to fully tail n_shared_tail_cells -= 1; } } @@ -8683,6 +8685,18 @@ static struct ggml_tensor * llm_build_mamba( conv_states = ggml_reshape_3d(ctx, conv_states, d_conv - 1, d_inner, n_rs); ssm_states = ggml_reshape_3d(ctx, ssm_states, d_state, d_inner, n_rs); + // copy states which won't be changed further (between n_seqs and n_rs) + ggml_build_forward_expand(graph, + ggml_cpy(ctx, + ggml_view_1d(ctx, conv_states, (d_conv - 1)*d_inner*(n_rs - n_seqs), n_seqs*(conv_states->nb[2])), + ggml_view_1d(ctx, conv_states_all, (d_conv - 1)*d_inner*(n_rs - n_seqs), (rs_head + n_seqs)*(d_conv - 1)*d_inner*ggml_element_size(conv_states_all)))); + + ggml_build_forward_expand(graph, + ggml_cpy(ctx, + ggml_view_1d(ctx, ssm_states, d_state*d_inner*(n_rs - n_seqs), n_seqs*(ssm_states->nb[2])), + ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*(n_rs - n_seqs), (rs_head + n_seqs)*d_state*d_inner*ggml_element_size(ssm_states_all)))); + + // the part of the states that will be used and modified struct ggml_tensor * conv = ggml_view_3d(ctx, conv_states, d_conv - 1, d_inner, n_seqs, conv_states->nb[1], conv_states->nb[2], 0); struct ggml_tensor * ssm = ggml_view_3d(ctx, ssm_states, d_state, d_inner, n_seqs, ssm_states->nb[1], ssm_states->nb[2], 0); @@ -8698,28 +8712,43 @@ static struct ggml_tensor * llm_build_mamba( // conv { - // Custom operator, which is needed because self-overlapping views aren't yet well supported by ggml. - // And also because this uses much less memory for large batches (4 times less when d_conv is 4). - // The equivalent is to concatenate the columns of conv_states and x, - // then make a self-overlapping view of that over d_conv columns at each stride in the 3rd dimension, - // then element-wise multiply that with the conv1d weigth, + // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs} + struct ggml_tensor * conv_x = ggml_concat(ctx, conv, ggml_cont(ctx, ggml_transpose(ctx, x)), 0); + + // copy last (d_conv - 1) columns back into the state cache + struct ggml_tensor * last_conv = ggml_view_3d(ctx, conv_x, d_conv - 1, d_inner, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0])); + + ggml_build_forward_expand(graph, + ggml_cpy(ctx, last_conv, + ggml_view_1d(ctx, conv_states_all, + (d_conv - 1)*(d_inner)*(n_seqs), + rs_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all)))); + + // 1D convolution + // The equivalent is to make a self-overlapping view of conv_x + // over d_conv columns at each stride in the 3rd dimension, + // then element-wise multiply that with the conv1d weight, // then sum the elements of each row, // (the last two steps are a dot product over rows (also doable with mul_mat)) // then permute away the ne[0] dimension, // and then you're left with the resulting x tensor. - // The new conv_states is the last (d_conv - 1) columns - // of the last 3rd dimensional "layer" of the self-overlapping view. // For simultaneous sequences, all sequences need to have the same length. 
- x = ggml_ssm_conv(ctx, conv, x, model.layers[il].ssm_conv1d); - // ensure conv is updated before copying into the recurrent state cache - ggml_build_forward_expand(graph, x); + // For some reason, im2col expects a F16 kernel, but doesn't even read from it. + // TODO: make im2col accept F32 kernels to directly pass ssm_conv1d to it. + // => { d_conv * d_inner, n_seq_tokens, n_seqs} + x = ggml_im2col(ctx, + ggml_new_tensor_2d(ctx, GGML_TYPE_F16, d_conv, d_inner), + conv_x, 1, 0, 0, 0, 1, 0, false, GGML_TYPE_F32); - ggml_build_forward_expand(graph, - ggml_cpy(ctx, conv_states, - ggml_view_1d(ctx, conv_states_all, - (d_conv - 1)*(d_inner)*(n_rs), - rs_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all)))); + x = ggml_reshape_4d(ctx, x, d_conv, 1, d_inner, n_seq_tokens * n_seqs); + + // => {1, 1, d_inner, n_seq_tokens * n_seqs} + x = ggml_mul_mat(ctx, ggml_reshape_3d(ctx, model.layers[il].ssm_conv1d, d_conv, 1, d_inner), x); + x = ggml_reshape_3d(ctx, x, d_inner, n_seq_tokens, n_seqs); + + // Alternatively, this does the same as the above + // x = ggml_ssm_conv(ctx, conv_x, model.layers[il].ssm_conv1d); // bias x = ggml_add(ctx, x, model.layers[il].ssm_conv1d_b); @@ -8746,16 +8775,16 @@ static struct ggml_tensor * llm_build_mamba( // Custom operator to optimize the parallel associative scan // as described in the Annex D of the Mamba paper. - // => {d_inner, n_seq_tokens, n_seqs} - struct ggml_tensor * y = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C); - - // The ssm scan also changes the state, ensure it's done before copying to the recurrent state cache - ggml_build_forward_expand(graph, y); + // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} + struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C); // store last states ggml_build_forward_expand(graph, - ggml_cpy(ctx, ssm_states, - ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_rs, rs_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); + ggml_cpy(ctx, + ggml_view_1d(ctx, y_ssm, d_state*d_inner*n_seqs, x->nb[3]), + ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_seqs, rs_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); + + struct ggml_tensor * y = ggml_view_3d(ctx, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[1], x->nb[2], 0); // TODO: skip computing output earlier for unused tokens From 17f6c1ef3bdb8332393ea8da008023134291b0c3 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 3 Jun 2024 00:41:15 -0400 Subject: [PATCH 20/28] llama : fix .base() compilation error on Windows --- llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index ecdcf3a4e7096..d4736473169a1 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2574,7 +2574,7 @@ struct llama_rs_cache { std::vector::iterator remove_seq_node_from_cell(llama_rs_cell & rs_cell, std::vector::iterator node_iter) { GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size()); // The iterator needs to point inside the correct vector - GGML_ASSERT(node_iter.base() >= rs_cell.seq_nodes.data() && node_iter.base() < rs_cell.seq_nodes.data() + rs_cell.seq_nodes.size()); + GGML_ASSERT(&(*node_iter) >= rs_cell.seq_nodes.data() && &(*node_iter) < rs_cell.seq_nodes.data() + rs_cell.seq_nodes.size()); if (node_iter != rs_cell.seq_nodes.end()) { // update the tree llama_rs_seq_node node = *node_iter; From fee3c1d740c0e027c81e2f2f3fb48d619857175f Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 3 Jun 2024 13:49:56 -0400 
Subject: [PATCH 21/28] llama : allow doing the equivalent of SSM_CONV with SUM_ROWS and MUL * ggml : allow GGML_OP_CONCAT to work on non-contiguous tensors The implementation already supported it, and this makes Mamba's conv step slightly faster. --- ggml.c | 5 ----- llama.cpp | 20 ++++++++++++++++---- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/ggml.c b/ggml.c index 253b3fa416e93..1a37ff2f070be 100644 --- a/ggml.c +++ b/ggml.c @@ -10992,11 +10992,6 @@ static void ggml_compute_forward_concat_f32( GGML_TENSOR_BINARY_OP_LOCALS - // TODO: support for transposed / permuted tensors - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(nb00 == sizeof(float)); - GGML_ASSERT(nb10 == sizeof(float)); - const int32_t dim = ggml_get_op_params_i32(dst, 0); GGML_ASSERT(dim >= 0 && dim < 4); diff --git a/llama.cpp b/llama.cpp index d4736473169a1..36b824d566b90 100644 --- a/llama.cpp +++ b/llama.cpp @@ -8713,7 +8713,7 @@ static struct ggml_tensor * llm_build_mamba( // conv { // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs} - struct ggml_tensor * conv_x = ggml_concat(ctx, conv, ggml_cont(ctx, ggml_transpose(ctx, x)), 0); + struct ggml_tensor * conv_x = ggml_concat(ctx, conv, ggml_transpose(ctx, x), 0); // copy last (d_conv - 1) columns back into the state cache struct ggml_tensor * last_conv = ggml_view_3d(ctx, conv_x, d_conv - 1, d_inner, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0])); @@ -8734,6 +8734,8 @@ static struct ggml_tensor * llm_build_mamba( // and then you're left with the resulting x tensor. // For simultaneous sequences, all sequences need to have the same length. + // TODO: remove unused implementations +#if 0 // For some reason, im2col expects a F16 kernel, but doesn't even read from it. // TODO: make im2col accept F32 kernels to directly pass ssm_conv1d to it. // => { d_conv * d_inner, n_seq_tokens, n_seqs} @@ -8741,14 +8743,24 @@ static struct ggml_tensor * llm_build_mamba( ggml_new_tensor_2d(ctx, GGML_TYPE_F16, d_conv, d_inner), conv_x, 1, 0, 0, 0, 1, 0, false, GGML_TYPE_F32); + #if 0 + // TODO: CUDA, SYCL, and Vulkan don't (yet) support broadcasting the ne[3] dimension on MUL_MAT x = ggml_reshape_4d(ctx, x, d_conv, 1, d_inner, n_seq_tokens * n_seqs); // => {1, 1, d_inner, n_seq_tokens * n_seqs} x = ggml_mul_mat(ctx, ggml_reshape_3d(ctx, model.layers[il].ssm_conv1d, d_conv, 1, d_inner), x); - x = ggml_reshape_3d(ctx, x, d_inner, n_seq_tokens, n_seqs); + #else + x = ggml_reshape_4d(ctx, x, d_conv, d_inner, n_seq_tokens, n_seqs); - // Alternatively, this does the same as the above - // x = ggml_ssm_conv(ctx, conv_x, model.layers[il].ssm_conv1d); + // NOTE: it seems this is very slighly more performant than MUL_MAT on CPU for small row sizes + // => {1, d_inner, n_seq_tokens, n_seqs} + x = ggml_sum_rows(ctx, ggml_mul(ctx, x, model.layers[il].ssm_conv1d)); + #endif + x = ggml_reshape_3d(ctx, x, d_inner, n_seq_tokens, n_seqs); +#else + // Alternatively, this does the same as the above, but faster + x = ggml_ssm_conv(ctx, conv_x, model.layers[il].ssm_conv1d); +#endif // bias x = ggml_add(ctx, x, model.layers[il].ssm_conv1d_b); From 372482dffeecc25b8eec24ad672ec66bd9baa55c Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sat, 8 Jun 2024 17:58:40 -0400 Subject: [PATCH 22/28] llama : rename llama_cache to llama_past This can be changed back later if the name change is wrong. I was renaming the functions anyway to generalize kv-cache-related functions to hybrid and recurrent model architectures. 
I think llama_past is a better name than llama_cache for a combined kv cache and recurrent state cache, because the states it contains pretty much always come before the newly-added ones for any particular sequence. Also 'llama_past_clear' sounds more obvious in what it does than 'llama_kv_cache_clear'. The future is what the models generate. (For embeddings, the kv cache isn't really used anyway) Still, I'm open to better suggestions. --- llama.cpp | 104 +++++++++++++++++++++++++++--------------------------- 1 file changed, 52 insertions(+), 52 deletions(-) diff --git a/llama.cpp b/llama.cpp index 4a2fb3a92f452..4b84313cf8d37 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2877,7 +2877,7 @@ struct llama_rs_cache { } }; -struct llama_cache { +struct llama_past { // key + value cache for self attention llama_kv_cache kv; @@ -2896,7 +2896,7 @@ struct llama_cache { return size; } - ~llama_cache() { + ~llama_past() { for (struct ggml_context * ctx : ctxs) { ggml_free(ctx); } @@ -3426,7 +3426,7 @@ struct llama_context { const llama_model & model; // key + value cache for self-attention, and/or recurrent state cache - struct llama_cache cache; + struct llama_past cache; // sequence-length-aware batch splitting llama_sbatch sbatch; @@ -3604,8 +3604,8 @@ static size_t llama_get_device_memory(const llama_model & model, int device) { // kv and rs cache helpers // -static bool llama_cache_init( - struct llama_cache & cache, +static bool llama_past_init( + struct llama_past & cache, const llama_context * ctx, ggml_type type_k, ggml_type type_v, @@ -3713,11 +3713,11 @@ static bool llama_cache_init( // no buffer was needed, so this is fine return true; } - LLAMA_LOG_ERROR("%s: failed to allocate buffer for kv cache\n", __func__); + LLAMA_LOG_ERROR("%s: failed to allocate buffer for past cache\n", __func__); return false; } ggml_backend_buffer_clear(buf, 0); - LLAMA_LOG_INFO("%s: %10s cache buf size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); + LLAMA_LOG_INFO("%s: %10s past cache size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); cache.bufs.push_back(buf); } @@ -3728,8 +3728,8 @@ static bool llama_cache_init( // updates the cache head // Note: On success, it's important that cache.head points // to the first cell of the slot. 
-static bool llama_cache_find_slot( - struct llama_cache & cache, +static bool llama_past_find_slot( + struct llama_past & cache, const struct llama_ubatch & batch) { const uint32_t kv_size = cache.kv.size; const uint32_t rs_size = cache.rs.size; @@ -4001,7 +4001,7 @@ static uint32_t llama_rs_cache_cell_max(const struct llama_rs_cache & cache) { return 0; } -static void llama_cache_clear(struct llama_cache & cache) { +static void llama_past_clear(struct llama_past & cache) { if (cache.kv.size > 0) { for (uint32_t i = 0; i < cache.kv.size; ++i) { llama_kv_cell & kv_cell = cache.kv.cells[i]; @@ -4035,8 +4035,8 @@ static void llama_cache_clear(struct llama_cache & cache) { } } -static llama_pos llama_cache_seq_rm( - struct llama_cache & cache, +static llama_pos llama_past_seq_rm( + struct llama_past & cache, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { @@ -4134,8 +4134,8 @@ static llama_pos llama_cache_seq_rm( return n_past; } -static llama_pos llama_cache_seq_cp( - struct llama_cache & cache, +static llama_pos llama_past_seq_cp( + struct llama_past & cache, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, @@ -4199,7 +4199,7 @@ static llama_pos llama_cache_seq_cp( return n_past; } -static void llama_cache_seq_keep(struct llama_cache & cache, llama_seq_id seq_id) { +static void llama_past_seq_keep(struct llama_past & cache, llama_seq_id seq_id) { if (cache.rs.size > 0) { uint32_t new_head = cache.rs.size; @@ -4249,8 +4249,8 @@ static void llama_cache_seq_keep(struct llama_cache & cache, llama_seq_id seq_id } } -static void llama_cache_seq_add( - struct llama_cache & cache, +static void llama_past_seq_add( + struct llama_past & cache, llama_seq_id seq_id, llama_pos p0, llama_pos p1, @@ -4317,8 +4317,8 @@ static void llama_cache_seq_add( } } -static void llama_cache_seq_div( - struct llama_cache & cache, +static void llama_past_seq_div( + struct llama_past & cache, llama_seq_id seq_id, llama_pos p0, llama_pos p1, @@ -4358,7 +4358,7 @@ static void llama_cache_seq_div( } } -static llama_pos llama_cache_seq_pos_max(struct llama_cache & cache, llama_seq_id seq_id) { +static llama_pos llama_past_seq_pos_max(struct llama_past & cache, llama_seq_id seq_id) { llama_pos result = -1; if (cache.rs.size > 0) { @@ -13911,7 +13911,7 @@ static int llama_decode_internal( if (hparams.causal_attn) { llama_kv_cache_update(&lctx); - if (!llama_cache_find_slot(lctx.cache, u_batch)) { + if (!llama_past_find_slot(lctx.cache, u_batch)) { return 1; } @@ -17981,7 +17981,7 @@ struct llama_context * llama_new_context_with_model( } ctx->backends.push_back(ctx->backend_cpu); - if (!llama_cache_init(ctx->cache, ctx, type_k, type_v, cparams.offload_kqv)) { + if (!llama_past_init(ctx->cache, ctx, type_k, type_v, cparams.offload_kqv)) { LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); llama_free(ctx); return nullptr; @@ -18515,85 +18515,85 @@ int32_t llama_get_rs_cache_used_cells(const struct llama_context * ctx) { return ctx->cache.rs.used; } -void llama_cache_clear(struct llama_context * ctx) { - llama_cache_clear(ctx->cache); +void llama_past_clear(struct llama_context * ctx) { + llama_past_clear(ctx->cache); } // deprecated void llama_kv_cache_clear(struct llama_context * ctx) { - llama_cache_clear(ctx); + llama_past_clear(ctx); } -llama_pos llama_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { +llama_pos llama_past_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { if (seq_id < 0 || 
(uint32_t) seq_id >= llama_n_seq_max(ctx)) { return 0; } - return llama_cache_seq_rm(ctx->cache, seq_id, p0, p1); + return llama_past_seq_rm(ctx->cache, seq_id, p0, p1); } // deprecated bool llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - llama_pos n_past = llama_cache_seq_rm(ctx, seq_id, p0, p1); + llama_pos n_past = llama_past_seq_rm(ctx, seq_id, p0, p1); return n_past >= p0; } -llama_pos llama_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { +llama_pos llama_past_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { uint32_t n_seq_max = llama_n_seq_max(ctx); if (seq_id_src < 0 || seq_id_dst < 0 || (uint32_t) seq_id_src >= n_seq_max || (uint32_t) seq_id_dst >= n_seq_max) { return 0; } if (seq_id_src == seq_id_dst) { - return llama_cache_seq_pos_max(ctx->cache, seq_id_dst) + 1; + return llama_past_seq_pos_max(ctx->cache, seq_id_dst) + 1; } - return llama_cache_seq_cp(ctx->cache, seq_id_src, seq_id_dst, p0, p1); + return llama_past_seq_cp(ctx->cache, seq_id_src, seq_id_dst, p0, p1); } // deprecated void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { - llama_cache_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1); + llama_past_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1); } -void llama_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) { +void llama_past_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) { if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } - llama_cache_seq_keep(ctx->cache, seq_id); + llama_past_seq_keep(ctx->cache, seq_id); } // deprecated void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) { - llama_cache_seq_keep(ctx, seq_id); + llama_past_seq_keep(ctx, seq_id); } -void llama_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { +void llama_past_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } if (delta == 0) { return; } - llama_cache_seq_add(ctx->cache, seq_id, p0, p1, delta); + llama_past_seq_add(ctx->cache, seq_id, p0, p1, delta); } // deprecated void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { - llama_cache_seq_add(ctx, seq_id, p0, p1, delta); + llama_past_seq_add(ctx, seq_id, p0, p1, delta); } -void llama_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { +void llama_past_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } if (d == 1) { return; } - llama_cache_seq_div(ctx->cache, seq_id, p0, p1, d); + llama_past_seq_div(ctx->cache, seq_id, p0, p1, d); } // deprecated void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { - llama_cache_seq_div(ctx, seq_id, p0, p1, d); + llama_past_seq_div(ctx, seq_id, p0, p1, d); } -llama_pos llama_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) { +llama_pos llama_past_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) { if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return -1; } - return 
llama_cache_seq_pos_max(ctx->cache, seq_id); + return llama_past_seq_pos_max(ctx->cache, seq_id); } // deprecated llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) { - llama_pos max_pos = llama_cache_seq_pos_max(ctx, seq_id); + llama_pos max_pos = llama_past_seq_pos_max(ctx, seq_id); return max_pos < 0 ? 0 : max_pos; } @@ -19345,7 +19345,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, GGML_ASSERT(cache.rs.size == 0); // not implemented // Wipe the slot - llama_cache_seq_rm(cache, dest_seq_id, -1, -1); + llama_past_seq_rm(cache, dest_seq_id, -1, -1); const uint8_t * inp = src; @@ -19402,7 +19402,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, } batch.n_seq_id[0] = 1; batch.seq_id[0] = &dest_seq_id; - if (!llama_cache_find_slot(cache, batch)) { + if (!llama_past_find_slot(cache, batch)) { LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); return 0; } @@ -19427,7 +19427,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, inp += sizeof(k_type_i_ref); const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type; if (k_type_i != k_type_i_ref) { - llama_cache_seq_rm(cache, dest_seq_id, -1, -1); + llama_past_seq_rm(cache, dest_seq_id, -1, -1); LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); return 0; } @@ -19438,7 +19438,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, inp += sizeof(k_size_row_ref); const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa); if (k_size_row != k_size_row_ref) { - llama_cache_seq_rm(cache, dest_seq_id, -1, -1); + llama_past_seq_rm(cache, dest_seq_id, -1, -1); LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, k_size_row_ref, il); return 0; } @@ -19459,7 +19459,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, inp += sizeof(v_type_i_ref); const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; if (v_type_i != v_type_i_ref) { - llama_cache_seq_rm(cache, dest_seq_id, -1, -1); + llama_past_seq_rm(cache, dest_seq_id, -1, -1); LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); return 0; } @@ -19470,7 +19470,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, inp += sizeof(v_size_row_ref); const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa); if (v_size_row != v_size_row_ref) { - llama_cache_seq_rm(cache, dest_seq_id, -1, -1); + llama_past_seq_rm(cache, dest_seq_id, -1, -1); LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, v_size_row_ref, il); return 0; } @@ -19490,7 +19490,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, inp += sizeof(v_type_i_ref); const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; if (v_type_i != v_type_i_ref) { - llama_cache_seq_rm(cache, dest_seq_id, -1, -1); + llama_past_seq_rm(cache, dest_seq_id, -1, -1); LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); return 0; } @@ -19501,7 +19501,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, inp += sizeof(v_size_el_ref); const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); if (v_size_el != v_size_el_ref) { - llama_cache_seq_rm(cache, dest_seq_id, -1, 
-1); + llama_past_seq_rm(cache, dest_seq_id, -1, -1); LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, v_size_el_ref, il); return 0; } From 43d8d4bf9e88df10203f7d8d4a1107b84bebbcfd Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 10 Jun 2024 14:44:42 -0400 Subject: [PATCH 23/28] examples : replace llama_kv_cache_seq_* with llama_past_seq_* --- common/common.cpp | 2 +- examples/batched-bench/batched-bench.cpp | 4 +- examples/batched.swift/Sources/main.swift | 2 +- examples/batched/batched.cpp | 2 +- examples/embedding/embedding.cpp | 2 +- examples/gritlm/gritlm.cpp | 4 +- examples/imatrix/imatrix.cpp | 2 +- examples/infill/infill.cpp | 4 +- examples/llama-bench/llama-bench.cpp | 4 +- .../llama/src/main/cpp/llama-android.cpp | 8 +-- .../llama.cpp.swift/LibLlama.swift | 8 +-- examples/lookahead/lookahead.cpp | 13 ++--- examples/lookup/lookup.cpp | 3 +- examples/main/main.cpp | 21 +++++--- examples/parallel/parallel.cpp | 10 ++-- examples/passkey/passkey.cpp | 28 +++++------ examples/perplexity/perplexity.cpp | 12 ++--- examples/retrieval/retrieval.cpp | 2 +- examples/save-load-state/save-load-state.cpp | 2 +- examples/server/server.cpp | 49 ++++++++++--------- examples/speculative/speculative.cpp | 28 ++++++----- llama.cpp | 3 +- llama.h | 28 +++++------ 23 files changed, 127 insertions(+), 114 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 1591790e6df4c..d04e047410778 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -2366,7 +2366,7 @@ std::tuple llama_init_from_gpt_par std::vector tmp = { llama_token_bos(model), llama_token_eos(model), }; llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); - llama_kv_cache_clear(lctx); + llama_past_clear(lctx); llama_synchronize(lctx); llama_reset_timings(lctx); } diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 718f0a61a1878..114dd811ee3f9 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -153,7 +153,7 @@ int main(int argc, char ** argv) { const auto t_pp_start = ggml_time_us(); - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); if (!decode_helper(ctx, batch, ctx_params.n_batch)) { LOG_TEE("%s: llama_decode() failed\n", __func__); @@ -162,7 +162,7 @@ int main(int argc, char ** argv) { if (is_pp_shared) { for (int32_t i = 1; i < pl; ++i) { - llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); + llama_past_seq_cp(ctx, 0, i, -1, -1); } } diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift index dbbd06da58183..443a03d575ea4 100644 --- a/examples/batched.swift/Sources/main.swift +++ b/examples/batched.swift/Sources/main.swift @@ -98,7 +98,7 @@ if llama_decode(context, batch) != 0 { } for i in 1 ..< n_parallel { - llama_kv_cache_seq_cp(context, 0, Int32(i), 0, batch.n_tokens) + llama_past_seq_cp(context, 0, Int32(i), -1, -1) } if n_parallel > 1 { diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 62d9b144d3340..888cf9e8e8c34 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -112,7 +112,7 @@ int main(int argc, char ** argv) { // assign the system KV cache to all parallel sequences // this way, the parallel sequences will "reuse" the prompt tokens without having to copy them for (int32_t i = 1; i < n_parallel; ++i) { - llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); + llama_past_seq_cp(ctx, 0, i, -1, -1); } if 
(n_parallel > 1) { diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 244751e003d9e..9a7c32d6b8ca2 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -25,7 +25,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector & toke static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) { // clear previous kv_cache values (irrelevant for embeddings) - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); // run model fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index 2135157916c97..dd389ac004383 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -43,7 +43,7 @@ static std::vector> encode(llama_context * ctx, const std::ve } // clear previous kv_cache values (irrelevant for embeddings) - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); llama_set_causal_attn(ctx, false); // run model @@ -97,7 +97,7 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo const llama_model * mdl = llama_get_model(ctx); llama_token eos_token = llama_token_eos(mdl); - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); llama_set_causal_attn(ctx, true); llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1); diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index e18f495630616..c81590a3f8c88 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -455,7 +455,7 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) { const auto t_start = std::chrono::high_resolution_clock::now(); // clear the KV cache - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index 0e4ec79c693fa..0a74b93abd698 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -380,8 +380,8 @@ int main(int argc, char ** argv) { LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", n_past, n_left, n_ctx, params.n_keep, n_discard); - llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); - llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); + llama_past_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); + llama_past_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); n_past -= n_discard; diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 5c31548a6c25c..d48eb245daa80 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -1360,7 +1360,7 @@ int main(int argc, char ** argv) { test t(inst, lmodel, ctx); - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); // warmup run if (t.n_prompt > 0) { @@ -1372,7 +1372,7 @@ int main(int argc, char ** argv) { } for (int i = 0; i < params.reps; i++) { - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); uint64_t t_start = get_time_ns(); diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp index 874158ef0f98f..57ee5a650893a 100644 --- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp +++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp @@ -194,7 +194,7 @@ 
Java_android_llama_cpp_LLamaAndroid_bench_1model( } batch->logits[batch->n_tokens - 1] = true; - llama_kv_cache_clear(context); + llama_past_clear(context); const auto t_pp_start = ggml_time_us(); if (llama_decode(context, *batch) != 0) { @@ -206,7 +206,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model( LOGi("Benchmark text generation (tg)"); - llama_kv_cache_clear(context); + llama_past_clear(context); const auto t_tg_start = ggml_time_us(); for (i = 0; i < tg; i++) { @@ -223,7 +223,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model( const auto t_tg_end = ggml_time_us(); - llama_kv_cache_clear(context); + llama_past_clear(context); const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0; const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0; @@ -439,5 +439,5 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop( extern "C" JNIEXPORT void JNICALL Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) { - llama_kv_cache_clear(reinterpret_cast(context)); + llama_past_clear(reinterpret_cast(context)); } diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift index 737f882fb2d2e..50fcaa12d6145 100644 --- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift +++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift @@ -214,7 +214,7 @@ actor LlamaContext { } batch.logits[Int(batch.n_tokens) - 1] = 1 // true - llama_kv_cache_clear(context) + llama_past_clear(context) let t_pp_start = ggml_time_us() @@ -227,7 +227,7 @@ actor LlamaContext { // bench text generation - llama_kv_cache_clear(context) + llama_past_clear(context) let t_tg_start = ggml_time_us() @@ -246,7 +246,7 @@ actor LlamaContext { let t_tg_end = ggml_time_us() - llama_kv_cache_clear(context) + llama_past_clear(context) let t_pp = Double(t_pp_end - t_pp_start) / 1000000.0 let t_tg = Double(t_tg_end - t_tg_start) / 1000000.0 @@ -296,7 +296,7 @@ actor LlamaContext { func clear() { tokens_list.removeAll() temporary_invalid_cchars.removeAll() - llama_kv_cache_clear(context) + llama_past_clear(context) } private func tokenize(text: String, add_bos: Bool) -> [llama_token] { diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index fb20ad93f9c1d..7f6e42e8d2810 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -96,7 +96,7 @@ int main(int argc, char ** argv) { llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0)); for (int s = 1; s < W + G + 1; ++s) { - llama_kv_cache_seq_cp(ctx, 0, s, -1, -1); + llama_past_seq_cp(ctx, 0, s, -1, -1); } const auto t_enc_end = ggml_time_us(); @@ -438,17 +438,18 @@ int main(int argc, char ** argv) { // KV cache management // if no verification token matched, we simply remove all cells from this batch -> no fragmentation - llama_kv_cache_seq_rm(ctx, -1, n_past, -1); + // FIXME: recurrent and hybrid models + llama_past_seq_rm(ctx, -1, n_past, -1); if (seq_id_best != 0) { // if a verification token matched, we keep the best sequence and remove the rest // this leads to some KV cache fragmentation - llama_kv_cache_seq_keep(ctx, seq_id_best); - llama_kv_cache_seq_cp (ctx, seq_id_best, 0, -1, -1); - llama_kv_cache_seq_rm (ctx, seq_id_best, -1, -1); + llama_past_seq_keep(ctx, seq_id_best); + llama_past_seq_cp (ctx, seq_id_best, 0, -1, -1); + llama_past_seq_rm (ctx, seq_id_best, -1, -1); for (int s = 1; s < W + G + 1; ++s) { - llama_kv_cache_seq_cp(ctx, 0, s, -1, -1); + llama_past_seq_cp(ctx, 0, s, -1, -1); } } } diff 
--git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index 80ecd925d5962..db861d6ad99f0 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -195,7 +195,8 @@ int main(int argc, char ** argv){ // KV cache management // clean the cache of draft tokens that weren't accepted - llama_kv_cache_seq_rm(ctx, 0, n_past, -1); + // FIXME: recurrent and hybrid models + llama_past_seq_rm(ctx, 0, n_past, -1); llama_batch_clear(batch_tgt); llama_batch_add(batch_tgt, draft[0], n_past, { 0 }, true); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index b97b7b7937f02..446fe035c3d25 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -299,6 +299,10 @@ int main(int argc, char ** argv) { } n_matching_session_tokens++; } + + // remove any "future" tokens that we might have inherited from the previous session + n_matching_session_tokens = llama_past_seq_rm(ctx, -1, n_matching_session_tokens, -1); + if (params.prompt.empty() && n_matching_session_tokens == embd_inp.size()) { LOG_TEE("%s: using full prompt from session file\n", __func__); } else if (n_matching_session_tokens >= embd_inp.size()) { @@ -310,9 +314,6 @@ int main(int argc, char ** argv) { LOG_TEE("%s: session file matches %zu / %zu tokens of prompt\n", __func__, n_matching_session_tokens, embd_inp.size()); } - - // remove any "future" tokens that we might have inherited from the previous session - llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1); } LOGLN( @@ -325,6 +326,8 @@ int main(int argc, char ** argv) { LOGLN("recalculate the cached logits (do): session_tokens.resize( %zu )", embd_inp.size() - 1); session_tokens.resize(embd_inp.size() - 1); + } else { + session_tokens.resize(n_matching_session_tokens); } // number of tokens to keep when resetting context @@ -535,8 +538,8 @@ int main(int argc, char ** argv) { LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", n_past, n_left, n_ctx, params.n_keep, n_discard); - llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard); - llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard); + llama_past_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard); + llama_past_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard); n_past -= n_discard; @@ -563,9 +566,9 @@ int main(int argc, char ** argv) { LOG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n); LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd); - llama_kv_cache_seq_add(ctx, 0, ga_i, n_past, ib*bd); - llama_kv_cache_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n); - llama_kv_cache_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd); + llama_past_seq_add(ctx, 0, ga_i, n_past, ib*bd); + llama_past_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n); + llama_past_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd); n_past -= bd; @@ -579,6 +582,8 @@ int main(int argc, char ** argv) { if (n_session_consumed < (int) session_tokens.size()) { size_t i = 0; for ( ; i < embd.size(); i++) { + // TODO: are the session tokens guaranteed to all be matching here? + // Should n_matching_session_tokens be re-used instead? 
if (embd[i] != session_tokens[n_session_consumed]) { session_tokens.resize(n_session_consumed); break; diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 7faeaec975ae3..f684788043450 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -200,7 +200,7 @@ int main(int argc, char ** argv) { // assign the system KV cache to all parallel sequences for (int32_t i = 1; i <= n_clients; ++i) { - llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); + llama_past_seq_cp(ctx, 0, i, -1, -1); } LOG_TEE("\n"); @@ -232,9 +232,9 @@ int main(int argc, char ** argv) { if (batch.n_tokens == 0) { // all sequences have ended - clear the entire KV cache for (int i = 1; i <= n_clients; ++i) { - llama_kv_cache_seq_rm(ctx, i, -1, -1); + llama_past_seq_rm(ctx, i, -1, -1); // but keep the system prompt - llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); + llama_past_seq_cp(ctx, 0, i, -1, -1); } LOG_TEE("%s: clearing the KV cache\n", __func__); @@ -371,8 +371,8 @@ int main(int argc, char ** argv) { } // delete only the generated part of the sequence, i.e. keep the system prompt in the cache - llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1); - llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1); + llama_past_seq_rm(ctx, client.id + 1, -1, -1); + llama_past_seq_cp(ctx, 0, client.id + 1, -1, -1); const auto t_main_end = ggml_time_us(); diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index d03215cd1e0a9..c6564c5cfd4c7 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -126,11 +126,11 @@ int main(int argc, char ** argv) { const int ib = i/n_batch - 1; const int bd = n_batch_grp*(n_grp - 1); - llama_kv_cache_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd); - llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp); - llama_kv_cache_update (ctx); + llama_past_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd); + llama_past_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp); + llama_kv_cache_update(ctx); - n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; + n_past = llama_past_seq_pos_max(ctx, 0) + 1; } llama_batch_clear(batch); @@ -160,12 +160,12 @@ int main(int argc, char ** argv) { LOG_TEE("%s: shifting KV cache with %d\n", __func__, n_discard); - llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); - llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); - //llama_kv_cache_defrag (ctx); - llama_kv_cache_update (ctx); + llama_past_seq_rm (ctx, 0, n_keep , n_keep + n_discard); + llama_past_seq_add (ctx, 0, n_keep + n_discard, n_ctx, -n_discard); + //llama_kv_cache_defrag(ctx); + llama_kv_cache_update(ctx); - n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; + n_past = llama_past_seq_pos_max(ctx, 0) + 1; llama_batch_clear(batch); @@ -191,12 +191,12 @@ int main(int argc, char ** argv) { if (n_discard > 0) { LOG_TEE("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard); - llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); - llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); - //llama_kv_cache_defrag (ctx); - llama_kv_cache_update (ctx); + llama_past_seq_rm (ctx, 0, n_keep , n_keep + n_discard); + llama_past_seq_add (ctx, 0, n_keep + n_discard, n_ctx, -n_discard); + //llama_kv_cache_defrag(ctx); + llama_kv_cache_update(ctx); - n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; + n_past = llama_past_seq_pos_max(ctx, 0) + 1; } } diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp 
index 0bd78c21a86a1..ad03b3bb5552b 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -400,7 +400,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & const auto t_start = std::chrono::high_resolution_clock::now(); // clear the KV cache - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; @@ -575,7 +575,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par const auto t_start = std::chrono::high_resolution_clock::now(); // clear the KV cache - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; @@ -944,7 +944,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { return; } - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); // decode all tasks [i0, i1) if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { @@ -1221,7 +1221,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) { return; } - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); // decode all tasks [i0, i1) if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { @@ -1594,7 +1594,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params return; } - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); // decode all tasks [i0, i1) if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { @@ -1780,7 +1780,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) { } // clear the KV cache - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index 55b7b2f70ae2a..bd7d06d371c2e 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -81,7 +81,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector & toke static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) { // clear previous kv_cache values (irrelevant for embeddings) - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); // run model fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 00c2277ac2827..974dc3c3ed5f5 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -192,7 +192,7 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy); // erase whole kv - llama_kv_cache_clear(ctx3); + llama_past_clear(ctx3); fprintf(stderr, "%s : kv cache cleared\n", __func__); // restore kv into seq 1 diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 6ffaa8d9fe637..a04c47bae21e0 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1107,7 +1107,7 @@ struct server_context { LOG_VERBOSE("clearing KV cache", {}); // clear the entire KV cache - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); clean_kv_cache = false; } @@ -1151,7 +1151,7 @@ struct server_context { // assign the system KV cache to all parallel sequences for (int32_t i = 1; i <= params.n_parallel; ++i) { - llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); + llama_past_seq_cp(ctx, 0, i, -1, 
-1); } } @@ -1824,7 +1824,7 @@ struct server_context { // Erase token cache const size_t n_erased = slot->cache_tokens.size(); - llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1); + llama_past_seq_rm(ctx, slot->id + 1, -1, -1); slot->cache_tokens.clear(); server_task_result result; @@ -1939,8 +1939,8 @@ struct server_context { {"n_cache_tokens", slot.cache_tokens.size()} }); - llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard); - llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard); + llama_past_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard); + llama_past_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard); if (slot.params.cache_prompt) { for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) { @@ -2155,23 +2155,28 @@ struct server_context { } // keep only the common part - int p0 = (int) system_tokens.size() + slot.n_past; - if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) { - // could not partially delete (likely using a non-Transformer model) - llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1); - - p0 = (int) system_tokens.size(); - if (p0 != 0) { - // copy over the system prompt when there is one - llama_kv_cache_seq_cp(ctx, 0, slot.id + 1, -1, -1); + llama_pos p0 = (llama_pos) system_tokens.size() + slot.n_past; + + // for recurrent and hybrid models, sometimes it goes back further than asked + llama_pos new_p0 = llama_past_seq_rm(ctx, slot.id + 1, p0, -1); + + if (new_p0 < p0) { + GGML_ASSERT(new_p0 >= (llama_pos) system_tokens.size()); + + slot.n_past -= p0 - new_p0; + if (slot.ga_i > 0) { + // TODO: test with an hybrid model (e.g. Jamba) + slot.n_past_se -= p0 - new_p0; } - // there is no common part left (except for the system prompt) - slot.n_past = 0; - slot.n_past_se = 0; - slot.ga_i = 0; - // TODO: is the system prompt ever in the sampling context? 
+ // TODO: find a way to avoid rolling back the sampling context twice llama_sampling_reset(slot.ctx_sampling); + // push the prompt into the sampling context (do not apply grammar) + for (int i = 0; i < slot.n_past; ++i) { + llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false); + } + + p0 = new_p0; } // remove the non-common part from the cache @@ -2273,9 +2278,9 @@ struct server_context { LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n); LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd); - llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd); - llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n); - llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd); + llama_past_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd); + llama_past_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n); + llama_past_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd); slot.n_past_se -= bd; diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 0939a1a6a7a38..3a1ef06a5e6b4 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -394,14 +394,15 @@ int main(int argc, char ** argv) { { LOG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft); - llama_kv_cache_seq_keep(ctx_dft, s_keep); - llama_kv_cache_seq_cp (ctx_dft, s_keep, 0, -1, -1); - llama_kv_cache_seq_keep(ctx_dft, 0); - - llama_kv_cache_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1); - llama_kv_cache_seq_keep(ctx_tgt, s_keep); - llama_kv_cache_seq_cp (ctx_tgt, s_keep, 0, -1, -1); - llama_kv_cache_seq_keep(ctx_tgt, 0); + llama_past_seq_keep(ctx_dft, s_keep); + llama_past_seq_cp (ctx_dft, s_keep, 0, -1, -1); + llama_past_seq_keep(ctx_dft, 0); + + // FIXME: recurrent and hybrid models + llama_past_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1); + llama_past_seq_keep(ctx_tgt, s_keep); + llama_past_seq_cp (ctx_tgt, s_keep, 0, -1, -1); + llama_past_seq_keep(ctx_tgt, 0); } for (int s = 0; s < n_seq_dft; ++s) { @@ -418,7 +419,8 @@ int main(int argc, char ** argv) { llama_batch_clear(batch_dft); llama_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true); - llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1); + // FIXME: recurrent and hybrid models + llama_past_seq_rm(ctx_dft, 0, n_past_dft, -1); // LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str()); llama_decode(ctx_dft, batch_dft); @@ -474,8 +476,8 @@ int main(int argc, char ** argv) { if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) { LOG("splitting seq %3d into %3d\n", s, n_seq_cur); - llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1); - llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1); + llama_past_seq_rm(ctx_dft, n_seq_cur, -1, -1); + llama_past_seq_cp(ctx_dft, s, n_seq_cur, -1, -1); // all previous tokens from this branch are now also part of the new branch for (int t = 0; t < batch_tgt.n_tokens; ++t) { @@ -553,9 +555,9 @@ int main(int argc, char ** argv) { // evaluate the target model on the drafted tokens { - llama_kv_cache_seq_keep(ctx_tgt, 0); + 
llama_past_seq_keep(ctx_tgt, 0); for (int s = 1; s < n_seq_dft; ++s) { - llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1); + llama_past_seq_cp(ctx_tgt, 0, s, -1, -1); } // LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str()); diff --git a/llama.cpp b/llama.cpp index 4b84313cf8d37..2233161d8f938 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2126,7 +2126,6 @@ struct llama_ubatch { llama_token * token; // [n_tokens] float * embd; // [n_embd, n_tokens] llama_pos * pos; // [n_tokens] - // FIXME: make all uses of this use n_seqs int32_t * n_seq_id; // [n_seqs] llama_seq_id ** seq_id; // [n_seqs] int8_t * output; // [n_tokens] @@ -18992,7 +18991,7 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) { __func__, kv_head, kv_size, kv_self.size); } - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); if (kv_buf_size) { const size_t pre_kv_buf_size = inp - src; diff --git a/llama.h b/llama.h index 0d9d522569632..4ecfc5f3e0a91 100644 --- a/llama.h +++ b/llama.h @@ -583,11 +583,11 @@ extern "C" { LLAMA_API int32_t llama_get_rs_cache_used_cells(const struct llama_context * ctx); // Clear the KV cache and recurrent states - both cell info is erased and KV data is zeroed - LLAMA_API void llama_cache_clear( + LLAMA_API void llama_past_clear( struct llama_context * ctx); LLAMA_API DEPRECATED(void llama_kv_cache_clear( struct llama_context * ctx), - "use llama_cache_clear instead"); + "use llama_past_clear instead"); // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) // seq_id < 0 : match any sequence @@ -595,7 +595,7 @@ extern "C" { // p1 < 0 : [p0, inf) // Returns n_past (one more than the largest remaining pos in the seq_id) // which is only meaningful to handle for partial removals. - LLAMA_API llama_pos llama_cache_seq_rm( + LLAMA_API llama_pos llama_past_seq_rm( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, @@ -605,7 +605,7 @@ extern "C" { llama_seq_id seq_id, llama_pos p0, llama_pos p1), - "use llama_cache_seq_rm instead, and handle its return value for partial removals"); + "use llama_past_seq_rm instead, and handle its return value for partial removals"); // Copy all tokens that belong to the specified sequence to another sequence // Note that this does not allocate extra KV or RS cache memory - it simply assigns the tokens to the new sequence @@ -613,7 +613,7 @@ extern "C" { // p1 < 0 : [p0, inf) // Returns n_past (one more than the largest remaining pos in the destination seq_id) // which is only meaningful to handle when partially copying. 
- LLAMA_API llama_pos llama_cache_seq_cp( + LLAMA_API llama_pos llama_past_seq_cp( struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, @@ -625,16 +625,16 @@ extern "C" { llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1), - "use llama_cache_seq_cp instead, and handle its return value for partial copies"); + "use llama_past_seq_cp instead, and handle its return value for partial copies"); // Removes all tokens that do not belong to the specified sequence - LLAMA_API void llama_cache_seq_keep( + LLAMA_API void llama_past_seq_keep( struct llama_context * ctx, llama_seq_id seq_id); LLAMA_API DEPRECATED(void llama_kv_cache_seq_keep( struct llama_context * ctx, llama_seq_id seq_id), - "use llama_cache_seq_keep instead"); + "use llama_past_seq_keep instead"); // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) // If the KV cache is RoPEd, the KV data is updated accordingly: @@ -642,7 +642,7 @@ extern "C" { // - explicitly with llama_kv_cache_update() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API void llama_cache_seq_add( + LLAMA_API void llama_past_seq_add( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, @@ -654,7 +654,7 @@ extern "C" { llama_pos p0, llama_pos p1, llama_pos delta), - "use llama_cache_seq_add instead"); + "use llama_past_seq_add instead"); // Integer division of the positions by factor of `d > 1` // If the KV cache is RoPEd, the KV data is updated accordingly: @@ -662,7 +662,7 @@ extern "C" { // - explicitly with llama_kv_cache_update() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API void llama_cache_seq_div( + LLAMA_API void llama_past_seq_div( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, @@ -674,16 +674,16 @@ extern "C" { llama_pos p0, llama_pos p1, int d), - "use llama_cache_seq_div instead"); + "use llama_past_seq_div instead"); // Returns the largest position present in the KV and/or RS cache for the specified sequence - LLAMA_API llama_pos llama_cache_seq_pos_max( + LLAMA_API llama_pos llama_past_seq_pos_max( struct llama_context * ctx, llama_seq_id seq_id); LLAMA_API DEPRECATED(llama_pos llama_kv_cache_seq_pos_max( struct llama_context * ctx, llama_seq_id seq_id), - "use llama_cache_seq_pos_max instead, which now returns -1 instead of 0 when the seq_id has no cells"); + "use llama_past_seq_pos_max instead, which now returns -1 instead of 0 when the seq_id has no cells"); // Defragment the KV cache // This will be applied: From 33425a7e1ed366082a2dbf64f2485531471515e0 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 12 Jun 2024 12:57:02 -0400 Subject: [PATCH 24/28] mamba : fix non-contiguous usage of ggml_silu --- llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index 2233161d8f938..37190bf1c48b0 100644 --- a/llama.cpp +++ b/llama.cpp @@ -8867,7 +8867,7 @@ static struct ggml_tensor * llm_build_mamba( // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs} y = ggml_add(ctx, y, ggml_mul(ctx, x, model.layers[il].ssm_d)); - y = ggml_mul(ctx, y, ggml_silu(ctx, z)); + y = ggml_mul(ctx, y, ggml_silu(ctx, ggml_cont(ctx, z))); // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} cur = ggml_mul_mat(ctx, model.layers[il].ssm_out, y); From fcb889cf7fb6588a6565f4cc6373be3f53ff25ca Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sun, 1 Sep 2024 20:31:30 -0400 Subject: [PATCH 25/28] llama : session saving and 
reloading for hybrid models --- include/llama.h | 4 +- src/llama.cpp | 519 ++++++++++++++++++++++++++++++++++++------------ 2 files changed, 390 insertions(+), 133 deletions(-) diff --git a/include/llama.h b/include/llama.h index 59f38936fbed7..6f6e73c901091 100644 --- a/include/llama.h +++ b/include/llama.h @@ -38,10 +38,10 @@ #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq' #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN -#define LLAMA_SESSION_VERSION 8 +#define LLAMA_SESSION_VERSION 9 #define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ -#define LLAMA_STATE_SEQ_VERSION 2 +#define LLAMA_STATE_SEQ_VERSION 3 #ifdef __cplusplus extern "C" { diff --git a/src/llama.cpp b/src/llama.cpp index 213a27cc8e2db..0f55196cf8edb 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -19839,8 +19839,28 @@ struct llama_data_write { } } + void write_rs_cache_meta(const llama_rs_cache & rs_self, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) { + + for (const auto & range : cell_ranges) { + for (uint32_t i = range.first; i < range.second; ++i) { + const auto & cell = rs_self.cells[i]; + const llama_pos pos = cell.pos; + const uint32_t n_seq_id = seq_id == -1 ? cell.seq_nodes.size() : 0; + + write(&pos, sizeof(pos)); + write(&n_seq_id, sizeof(n_seq_id)); + + if (n_seq_id) { + for (auto seq_node : cell.seq_nodes) { + write(&seq_node.seq_id, sizeof(seq_node.seq_id)); + } + } + } + } + } + void write_kv_cache_data(const struct llama_context * ctx, const std::vector> & cell_ranges) { - const struct llama_kv_cache & kv_self = ctx->kv_self; + const struct llama_kv_cache & kv_self = ctx->cache.kv; const struct llama_hparams & hparams = ctx->model.hparams; const uint32_t v_trans = kv_self.v_trans ? 1 : 0; @@ -19849,12 +19869,10 @@ struct llama_data_write { write(&v_trans, sizeof(v_trans)); write(&n_layer, sizeof(n_layer)); - std::vector tmp_buf; - // Iterate and write all the keys first, each row is a cell // Get whole range at a time for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); // Write key type const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type; @@ -19874,7 +19892,7 @@ struct llama_data_write { if (!kv_self.v_trans) { for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); // Write value type const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; @@ -19895,7 +19913,7 @@ struct llama_data_write { // When v is transposed, we also need the element size and get the element ranges from each row const uint32_t kv_size = kv_self.size; for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); // Write value type const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; @@ -19922,43 +19940,151 @@ struct llama_data_write { } } - void write_kv_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1) { - const struct llama_kv_cache & kv_self = ctx->kv_self; - std::vector> cell_ranges; // ranges, from inclusive, to exclusive - uint32_t cell_count = 0; + void write_rs_cache_data(const struct llama_context * ctx, const std::vector> & cell_ranges) { + const struct llama_rs_cache & rs_self = ctx->cache.rs; + const struct llama_hparams & hparams = ctx->model.hparams; - // Count the number of cells with the specified seq_id 
- // Find all the ranges of cells with this seq id (or all, when -1) - uint32_t cell_range_begin = kv_self.size; - for (uint32_t i = 0; i < kv_self.size; ++i) { - const auto & cell = kv_self.cells[i]; - if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) { - ++cell_count; - if (cell_range_begin == kv_self.size) { - cell_range_begin = i; - } - } else { - if (cell_range_begin != kv_self.size) { - cell_ranges.emplace_back(cell_range_begin, i); - cell_range_begin = kv_self.size; + const uint32_t n_layer = hparams.n_layer; + + write(&n_layer, sizeof(n_layer)); + + // Iterate and write all recurrent states, each row is a cell + // Get whole range at a time + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_r = hparams.n_embd_r(il); + + // Write type + const int32_t r_type_i = (int32_t)rs_self.r_l[il]->type; + write(&r_type_i, sizeof(r_type_i)); + + // Write row size + const uint64_t r_size_row = ggml_row_size(rs_self.r_l[il]->type, n_embd_r); + write(&r_size_row, sizeof(r_size_row)); + + // Read each range of cells of r_size length each and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t buf_size = range_size * r_size_row; + write_tensor_data(rs_self.r_l[il], range.first * r_size_row, buf_size); + } + } + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_s = hparams.n_embd_s(il); + + // Write type + const int32_t s_type_i = (int32_t)rs_self.s_l[il]->type; + write(&s_type_i, sizeof(s_type_i)); + + // Write row size + const uint64_t s_size_row = ggml_row_size(rs_self.s_l[il]->type, n_embd_s); + write(&s_size_row, sizeof(s_size_row)); + + // Read each range of cells of s_size length each and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t buf_size = range_size * s_size_row; + write_tensor_data(rs_self.s_l[il], range.first * s_size_row, buf_size); + } + } + } + + void write_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1) { + const struct llama_kv_cache & kv_self = ctx->cache.kv; + const struct llama_rs_cache & rs_self = ctx->cache.rs; + std::vector> kv_cell_ranges; // ranges, from inclusive, to exclusive + std::vector> rs_cell_ranges; // ranges, from inclusive, to exclusive + uint32_t kv_cell_count = 0; + uint32_t rs_cell_count = 0; + // Transformer KV cache + { + // Count the number of cells with the specified seq_id + // Find all the ranges of cells with this seq id (or all, when -1) + uint32_t cell_range_begin = kv_self.size; + for (uint32_t i = 0; i < kv_self.size; ++i) { + const auto & cell = kv_self.cells[i]; + if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) { + ++kv_cell_count; + if (cell_range_begin == kv_self.size) { + cell_range_begin = i; + } + } else { + if (cell_range_begin != kv_self.size) { + kv_cell_ranges.emplace_back(cell_range_begin, i); + cell_range_begin = kv_self.size; + } } } + if (cell_range_begin != kv_self.size) { + kv_cell_ranges.emplace_back(cell_range_begin, kv_self.size); + } + + // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count + uint32_t cell_count_check = 0; + for (const auto & range : kv_cell_ranges) { + cell_count_check += range.second - range.first; + } + GGML_ASSERT(kv_cell_count == cell_count_check); } - if (cell_range_begin != kv_self.size) { - cell_ranges.emplace_back(cell_range_begin, kv_self.size); + // Recurrent state cache + if (seq_id == -1) { + // Find all the ranges of cells + uint32_t 
cell_range_begin = rs_self.size; + for (uint32_t i = 0; i < rs_self.size; ++i) { + const auto & cell = rs_self.cells[i]; + if (!cell.is_empty()) { + ++rs_cell_count; + if (cell_range_begin == rs_self.size) { + cell_range_begin = i; + } + } else { + if (cell_range_begin != rs_self.size) { + rs_cell_ranges.emplace_back(cell_range_begin, i); + cell_range_begin = rs_self.size; + } + } + } + if (cell_range_begin != rs_self.size) { + rs_cell_ranges.emplace_back(cell_range_begin, rs_self.size); + } + + } else { + // Find the cell ranges of the specified seq_id + if ((size_t) seq_id < rs_self.seq_tails.size()) { + int32_t tail_cell_id = rs_self.seq_tails[seq_id].tail; + if (tail_cell_id >= 0) { + ++rs_cell_count; + rs_cell_ranges.emplace_back(tail_cell_id, tail_cell_id + 1); + } + } } - // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count - uint32_t cell_count_check = 0; - for (const auto & range : cell_ranges) { - cell_count_check += range.second - range.first; + { + // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count + uint32_t cell_count_check = 0; + for (const auto & range : rs_cell_ranges) { + cell_count_check += range.second - range.first; + } + GGML_ASSERT(rs_cell_count == cell_count_check); } - GGML_ASSERT(cell_count == cell_count_check); - write(&cell_count, sizeof(cell_count)); + write(&kv_cell_count, sizeof(kv_cell_count)); + write(&rs_cell_count, sizeof(rs_cell_count)); - write_kv_cache_meta(kv_self, cell_ranges, seq_id); - write_kv_cache_data(ctx, cell_ranges); + if (seq_id == -1) { + // write metadata for both when the whole cache needs to be saved + write_kv_cache_meta(kv_self, kv_cell_ranges, seq_id); + write_rs_cache_meta(rs_self, rs_cell_ranges, seq_id); + } else if (kv_cell_count > 0) { + write_kv_cache_meta(kv_self, kv_cell_ranges, seq_id); + } else { + write_rs_cache_meta(rs_self, rs_cell_ranges, seq_id); + } + if (kv_cell_count > 0) { + write_kv_cache_data(ctx, kv_cell_ranges); + } + if (rs_cell_count > 0) { + write_rs_cache_data(ctx, rs_cell_ranges); + } } }; @@ -20050,108 +20176,98 @@ struct llama_data_read { } } - bool read_kv_cache_meta(struct llama_context * ctx, uint32_t cell_count, llama_seq_id dest_seq_id = -1) { - struct llama_kv_cache & kv_self = ctx->kv_self; + bool read_kv_cache_meta(struct llama_context * ctx, uint32_t cell_count) { + if (cell_count == 0) { return true; } + struct llama_past & cache = ctx->cache; + struct llama_kv_cache & kv_self = cache.kv; + + // whole KV cache restore + + if (cell_count > kv_self.size) { + LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__); + return false; + } - if (dest_seq_id != -1) { - // single sequence + for (uint32_t i = 0; i < cell_count; ++i) { + llama_kv_cell & cell = kv_self.cells[i]; - llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + llama_pos pos; + uint32_t n_seq_id; - llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false); - batch.n_tokens = cell_count; - batch.n_seq_tokens = cell_count; - batch.n_seqs = 1; + read_to(&pos, sizeof(pos)); + read_to(&n_seq_id, sizeof(n_seq_id)); - for (uint32_t i = 0; i < cell_count; ++i) { - llama_pos pos; - uint32_t n_seq_id; + cell.pos = pos; - read_to(&pos, sizeof(pos)); - read_to(&n_seq_id, sizeof(n_seq_id)); + for (uint32_t j = 0; j < n_seq_id; ++j) { + llama_seq_id seq_id; + read_to(&seq_id, sizeof(seq_id)); - if (n_seq_id != 0) { - LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { + 
LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx)); return false; } - batch.pos[i] = pos; - } - batch.n_seq_id[0] = 1; - batch.seq_id[0] = &dest_seq_id; - if (!llama_kv_cache_find_slot(kv_self, batch)) { - LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); - return false; + cell.seq_id.insert(seq_id); } + } - // DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values) - // Assume that this is one contiguous block of cells - GGML_ASSERT(kv_self.head + cell_count <= kv_self.size); - GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]); - GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]); - GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id)); - GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id)); - } else { - // whole KV cache restore + kv_self.head = 0; + kv_self.used = cell_count; - if (cell_count > kv_self.size) { - LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__); - return false; - } + return true; + } - llama_kv_cache_clear(kv_self); + bool read_rs_cache_meta(struct llama_context * ctx, uint32_t cell_count) { + if (cell_count == 0) { return true; } + struct llama_past & cache = ctx->cache; + struct llama_rs_cache & rs_self = cache.rs; - for (uint32_t i = 0; i < cell_count; ++i) { - llama_kv_cell & cell = kv_self.cells[i]; + // whole RS cache restore - llama_pos pos; - uint32_t n_seq_id; + if (cell_count > rs_self.size) { + LLAMA_LOG_ERROR("%s: not enough cells in rs cache\n", __func__); + return false; + } - read_to(&pos, sizeof(pos)); - read_to(&n_seq_id, sizeof(n_seq_id)); + for (uint32_t i = 0; i < cell_count; ++i) { + llama_rs_cell & cell = rs_self.cells[i]; - cell.pos = pos; + llama_pos pos; + uint32_t n_seq_id; - for (uint32_t j = 0; j < n_seq_id; ++j) { - llama_seq_id seq_id; - read_to(&seq_id, sizeof(seq_id)); + read_to(&pos, sizeof(pos)); + read_to(&n_seq_id, sizeof(n_seq_id)); - if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { - LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx)); - return false; - } + cell.pos = pos; + cell.src = i; - cell.seq_id.insert(seq_id); + for (uint32_t j = 0; j < n_seq_id; ++j) { + llama_seq_id seq_id; + read_to(&seq_id, sizeof(seq_id)); - if (kv_self.recurrent) { - int32_t & tail = kv_self.cells[seq_id].tail; - if (tail != -1) { - LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail); - return false; - } - tail = i; - } + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { + LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx)); + return false; } - } - kv_self.head = 0; - kv_self.used = cell_count; - } + cell.insert_node(seq_id); - if (kv_self.recurrent) { - for (uint32_t i = 0; i < cell_count; ++i) { - uint32_t cell_id = kv_self.head + i; - // make sure the recurrent states will keep their restored state - kv_self.cells[cell_id].src = cell_id; } } + rs_self.head = 0; + rs_self.used = cell_count; + + rs_self.rebuild(/* debug */ false); + return true; } bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count) { + if (cell_count == 0) { return true; } const struct llama_hparams & hparams = ctx->model.hparams; - struct llama_kv_cache & kv_self = ctx->kv_self; + struct llama_kv_cache & kv_self = 
         uint32_t v_trans;
         uint32_t n_layer;
         read_to(&v_trans, sizeof(v_trans));
@@ -20172,7 +20288,7 @@ struct llama_data_read {
 
         // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+            const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
 
             // Read type of key
             int32_t k_type_i_ref;
@@ -20192,15 +20308,13 @@ struct llama_data_read {
                 return false;
             }
 
-            if (cell_count) {
-                // Read and set the keys for the whole cell range
-                ggml_backend_tensor_set(kv_self.k_l[il], read(cell_count * k_size_row), kv_self.head * k_size_row, cell_count * k_size_row);
-            }
+            // Read and set the keys for the whole cell range
+            ggml_backend_tensor_set(kv_self.k_l[il], read(cell_count * k_size_row), kv_self.head * k_size_row, cell_count * k_size_row);
         }
 
         if (!kv_self.v_trans) {
             for (uint32_t il = 0; il < n_layer; ++il) {
-                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
                 // Read type of value
                 int32_t v_type_i_ref;
@@ -20220,15 +20334,13 @@ struct llama_data_read {
                     return false;
                 }
 
-                if (cell_count) {
-                    // Read and set the values for the whole cell range
-                    ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_row), kv_self.head * v_size_row, cell_count * v_size_row);
-                }
+                // Read and set the values for the whole cell range
+                ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_row), kv_self.head * v_size_row, cell_count * v_size_row);
             }
         } else {
             // For each layer, read the values for each cell (transposed)
             for (uint32_t il = 0; il < n_layer; ++il) {
-                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
                 // Read type of value
                 int32_t v_type_i_ref;
@@ -20256,29 +20368,174 @@ struct llama_data_read {
                     return false;
                 }
 
-                if (cell_count) {
-                    // For each row in the transposed matrix, read the values for the whole cell range
-                    for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                        const size_t dst_offset = (kv_self.head + j * kv_self.size) * v_size_el;
-                        ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
-                    }
+                // For each row in the transposed matrix, read the values for the whole cell range
+                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                    const size_t dst_offset = (kv_self.head + j * kv_self.size) * v_size_el;
+                    ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
                 }
             }
         }
 
         return true;
     }
 
-    void read_kv_cache(struct llama_context * ctx, llama_seq_id seq_id = -1) {
-        uint32_t cell_count;
-        read_to(&cell_count, sizeof(cell_count));
+    bool read_rs_cache_data(struct llama_context * ctx, uint32_t cell_count) {
+        if (cell_count == 0) { return true; }
+        const struct llama_hparams & hparams = ctx->model.hparams;
+        struct llama_rs_cache & rs_self = ctx->cache.rs;
+        uint32_t n_layer;
+        read_to(&n_layer, sizeof(n_layer));
+
+        if (n_layer != hparams.n_layer) {
+            LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
+            return false;
+        }
+        if (cell_count > rs_self.size) {
+            LLAMA_LOG_ERROR("%s: not enough cells in rs cache to restore state (%u > %u)\n", __func__, cell_count, rs_self.size);
+            return false;
+        }
+
+        // For each layer, one row is one cell, read as one contiguous block
+        for (uint32_t il = 0; il < n_layer; ++il) {
+            const uint32_t n_embd_r = hparams.n_embd_r(il);
+
+            // Read type of key
+            int32_t r_type_i_ref;
+            read_to(&r_type_i_ref, sizeof(r_type_i_ref));
+            const int32_t r_type_i = (int32_t)rs_self.r_l[il]->type;
+            if (r_type_i != r_type_i_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, r_type_i, r_type_i_ref, il);
+                return false;
+            }
+
+            // Read row size of key
+            uint64_t r_size_row_ref;
+            read_to(&r_size_row_ref, sizeof(r_size_row_ref));
+            const size_t r_size_row = ggml_row_size(rs_self.r_l[il]->type, n_embd_r);
+            if (r_size_row != r_size_row_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, r_size_row, (size_t) r_size_row_ref, il);
+                return false;
+            }
+
+            // Read and set the keys for the whole cell range
+            ggml_backend_tensor_set(rs_self.r_l[il], read(cell_count * r_size_row), rs_self.head * r_size_row, cell_count * r_size_row);
+        }
+        for (uint32_t il = 0; il < n_layer; ++il) {
+            const uint32_t n_embd_s = hparams.n_embd_s(il);
+
+            // Read type of key
+            int32_t s_type_i_ref;
+            read_to(&s_type_i_ref, sizeof(s_type_i_ref));
+            const int32_t s_type_i = (int32_t)rs_self.s_l[il]->type;
+            if (s_type_i != s_type_i_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
+                return false;
+            }
+
+            // Read row size of key
+            uint64_t s_size_row_ref;
+            read_to(&s_size_row_ref, sizeof(s_size_row_ref));
+            const size_t s_size_row = ggml_row_size(rs_self.s_l[il]->type, n_embd_s);
+            if (s_size_row != s_size_row_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, s_size_row, (size_t) s_size_row_ref, il);
+                return false;
+            }
+
+            // Read and set the keys for the whole cell range
+            ggml_backend_tensor_set(rs_self.s_l[il], read(cell_count * s_size_row), rs_self.head * s_size_row, cell_count * s_size_row);
+        }
+
+        return true;
+    }
+
+    bool read_cache_seq_meta(struct llama_context * ctx, uint32_t cell_count, llama_seq_id seq_id = -1) {
+
+        if (seq_id < 0 || seq_id >= llama_n_seq_max(ctx)) {
+            LLAMA_LOG_ERROR("%s: seq_id out of range [0, %d): %d\n", __func__, llama_n_seq_max(ctx), seq_id);
+            return false;
+        }
+
+        // single sequence
+
+        llama_past & cache = ctx->cache;
+        llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
+        batch.n_tokens = cell_count;
+        batch.n_seq_tokens = cell_count;
+        batch.n_seqs = 1;
+
+        for (uint32_t i = 0; i < cell_count; ++i) {
+            llama_pos pos;
+            uint32_t n_seq_id;
+
+            read_to(&pos, sizeof(pos));
+            read_to(&n_seq_id, sizeof(n_seq_id));
+
+            if (n_seq_id != 0) {
+                LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
+                return false;
+            }
+
+            batch.pos[i] = pos;
+        }
+        batch.n_seq_id[0] = 1;
+        batch.seq_id[0] = &seq_id;
+        if (!llama_past_find_slot(cache, batch)) {
+            LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
+            return false;
+        }
+
+        if (cache.kv.size > 0) {
+            // DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
+            // Assume that this is one contiguous block of cells
+            GGML_ASSERT(cache.kv.head + cell_count <= cache.kv.size);
+            GGML_ASSERT(cache.kv.cells[cache.kv.head].pos == batch.pos[0]);
+            GGML_ASSERT(cache.kv.cells[cache.kv.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
+            GGML_ASSERT(cache.kv.cells[cache.kv.head].has_seq_id(seq_id));
+            GGML_ASSERT(cache.kv.cells[cache.kv.head + cell_count - 1].has_seq_id(seq_id));
+        }
+        if (cache.rs.size > 0) {
+            GGML_ASSERT(cache.rs.head + cache.rs.n <= cache.rs.size);
+            GGML_ASSERT(cache.rs.n == 1);
+            GGML_ASSERT(cache.rs.cells[cache.rs.head + cache.rs.n - 1].pos == batch.pos[cell_count - 1]);
+            GGML_ASSERT(cache.rs.cells[cache.rs.head].has_seq_id(seq_id));
+            GGML_ASSERT(cache.rs.cells[cache.rs.head + cache.rs.n - 1].has_seq_id(seq_id));
+            // Prevent cells from being cleared
+            for (uint32_t i = cache.rs.head; i < cache.rs.head + cache.rs.n; ++i) {
+                cache.rs.cells[i].src = i;
+            }
+        }
+
+        return true;
+    }
+
+    void read_cache(struct llama_context * ctx, llama_seq_id seq_id = -1) {
+        uint32_t kv_cell_count;
+        read_to(&kv_cell_count, sizeof(kv_cell_count));
+        uint32_t rs_cell_count;
+        read_to(&rs_cell_count, sizeof(rs_cell_count));
+
+        bool res = true;
+
+        if (seq_id == -1) {
+            llama_past_clear(ctx);
+            res = read_kv_cache_meta(ctx, kv_cell_count) && read_rs_cache_meta(ctx, rs_cell_count);
+        } else {
+            llama_past_seq_rm(ctx, seq_id, -1, -1);
+            // Only a single recurrent cell at most,
+            // because otherwise the cells can be shuffled when a slot is allocated
+            if (rs_cell_count > 1) {
+                LLAMA_LOG_ERROR("%s: too many recurrent state cells for single-sequence session\n", __func__);
+                res = false;
+            }
+            res = res && read_cache_seq_meta(ctx, std::max(kv_cell_count, rs_cell_count), seq_id);
+        }
-        bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count);
+        res = res && read_kv_cache_data(ctx, kv_cell_count) && read_rs_cache_data(ctx, rs_cell_count);
 
         if (!res) {
             if (seq_id == -1) {
-                llama_kv_cache_clear(ctx);
+                llama_past_clear(ctx);
             } else {
-                llama_kv_cache_seq_rm(ctx, seq_id, -1, -1);
+                llama_past_seq_rm(ctx, seq_id, -1, -1);
             }
             throw std::runtime_error("failed to restore kv cache");
         }
@@ -20433,7 +20690,7 @@ static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_da
     data_ctx.write_logits(ctx);
     data_ctx.write_embeddings(ctx);
 
-    data_ctx.write_kv_cache(ctx);
+    data_ctx.write_cache(ctx);
 
     return data_ctx.get_size_written();
 }
@@ -20473,7 +20730,7 @@ static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_da
     data_ctx.read_logits(ctx);
     data_ctx.read_embeddings(ctx);
 
-    data_ctx.read_kv_cache(ctx);
+    data_ctx.read_cache(ctx);
 
     return data_ctx.get_size_read();
 }
@@ -20569,7 +20826,7 @@ bool llama_state_save_file(struct llama_context * ctx, const char * path_session
 static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx, llama_seq_id seq_id) {
     llama_synchronize(ctx);
 
-    data_ctx.write_kv_cache(ctx, seq_id);
+    data_ctx.write_cache(ctx, seq_id);
 
     return data_ctx.get_size_written();
 }
@@ -20592,7 +20849,7 @@ size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_
 static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx, llama_seq_id dest_seq_id) {
     llama_synchronize(ctx);
 
-    data_ctx.read_kv_cache(ctx, dest_seq_id);
+    data_ctx.read_cache(ctx, dest_seq_id);
 
     return data_ctx.get_size_read();
 }

From 9d3f44dad426acc26d35e3b6cf1462d3a3f43113 Mon Sep 17 00:00:00 2001
From: Francis Couture-Harpin
Date: Sun, 1 Sep 2024 21:46:27 -0400
Subject: [PATCH 26/28] convert_hf : fix Jamba conversion

---
 convert_hf_to_gguf.py | 19 ++-----------------
 1 file changed, 2 insertions(+), 17 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 00059bd01afca..e9bb4b20bd6d3 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -2910,7 +2910,6 @@ def set_gguf_parameters(self):
             n_kv_head if (i - attn_offset) % attn_period == 0 else 0
             for i in range(attn_offset, self.block_count)
         ]
-        self.gguf_writer.add_name(self.dir_model.name)
         self.gguf_writer.add_block_count(self.block_count)
         self.gguf_writer.add_context_length(self.find_hparam(["max_position_embeddings", "n_ctx"]))
         self.gguf_writer.add_embedding_length(d_model)
@@ -2979,8 +2978,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 
         yield new_name, data_torch
 
-    def write_tensors(self):
-        super().write_tensors()
+    def prepare_tensors(self):
+        super().prepare_tensors()
 
         if self._experts is not None:
             # flatten `list[dict[str, Tensor]]` into `list[str]`
@@ -2988,20 +2987,6 @@ def write_tensors(self):
             if len(experts) > 0:
                 raise ValueError(f"Unprocessed experts: {experts}")
 
-    # same as Mamba
-    def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
-        del n_dims # unused
-
-        return bid is not None and new_name in (
-            self.format_tensor_name(n, bid, ".weight" if name.endswith(".weight") else "") for n in [
-                gguf.MODEL_TENSOR.SSM_CONV1D,
-                gguf.MODEL_TENSOR.SSM_X,
-                gguf.MODEL_TENSOR.SSM_DT,
-                gguf.MODEL_TENSOR.SSM_A,
-                gguf.MODEL_TENSOR.SSM_D,
-            ]
-        )
-
 
 @Model.register("CohereForCausalLM")
 class CommandR2Model(Model):

From 5f62db790b8e548eb7db0f69a9fadb7f809f6c96 Mon Sep 17 00:00:00 2001
From: Francis Couture-Harpin
Date: Sun, 1 Sep 2024 21:50:27 -0400
Subject: [PATCH 27/28] llama : fix mixed signedness comparison

---
 src/llama.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/llama.cpp b/src/llama.cpp
index 842de9118876c..cf7dccb384f2b 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -20963,7 +20963,7 @@ struct llama_data_read {
 
     bool read_cache_seq_meta(struct llama_context * ctx, uint32_t cell_count, llama_seq_id seq_id = -1) {
 
-        if (seq_id < 0 || seq_id >= llama_n_seq_max(ctx)) {
+        if (seq_id < 0 || seq_id >= (llama_seq_id) llama_n_seq_max(ctx)) {
             LLAMA_LOG_ERROR("%s: seq_id out of range [0, %d): %d\n", __func__, llama_n_seq_max(ctx), seq_id);
             return false;
         }

From 375de5b1f8c07b5bfdef7f00b738eb176f8431ba Mon Sep 17 00:00:00 2001
From: Francis Couture-Harpin
Date: Sun, 1 Sep 2024 21:59:24 -0400
Subject: [PATCH 28/28] llama : use unused n_embd_k_gqa in k_shift

This also slightly reduces the diff from the master branch
---
 src/llama.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/llama.cpp b/src/llama.cpp
index cf7dccb384f2b..043f3d7ec7853 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -10806,7 +10806,7 @@ struct llm_build_context {
                 ggml_view_3d(ctx0, kv_self.k_l[il],
                     n_embd_head_k, n_head_kv, n_ctx,
                     ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
-                    ggml_row_size(kv_self.k_l[il]->type, hparams.n_embd_k_gqa(il)),
+                    ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
                     0),
                 lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                 ext_factor, attn_factor, beta_fast, beta_slow);