Skip to content

Commit

Permalink
gemma2: add sliding window mask (ggerganov#8227)
Browse files Browse the repository at this point in the history
* gemma2: add sliding window mask

* fix data_swa uninitialized

* better naming

* add co-author

Co-authored-by: Arlo Phoenix <[email protected]>

* replace list with single tensor

* update

* llama : minor styling

* convert : add sanity check for query_pre_attn_scalar

* fix small typo in README

---------

Co-authored-by: Arlo Phoenix <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
  • Loading branch information
3 people authored Jul 1, 2024
1 parent 0ddeff1 commit 49122a8
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 32 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ Unless otherwise noted these projects are open-source with permissive licensing:
**Tools:**

- [akx/ggify](https://github.com/akx/ggify) – download PyTorch models from HuggingFace Hub and convert them to GGML
[crashr/gppm](https://github.com/crashr/gppm) – launch llama.cpp instances utilizing NVIDIA Tesla P40 or P100 GPUs with reduced idle power consumption
- [crashr/gppm](https://github.com/crashr/gppm) – launch llama.cpp instances utilizing NVIDIA Tesla P40 or P100 GPUs with reduced idle power consumption

---

Expand Down
6 changes: 6 additions & 0 deletions convert-hf-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2369,6 +2369,12 @@ def set_gguf_parameters(self):
self.gguf_writer.add_final_logit_softcapping(
self.hparams["final_logit_softcapping"]
)
self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])

# sanity check
attn_scalar = self.hparams["query_pre_attn_scalar"]
if attn_scalar != hparams["hidden_size"] / hparams["num_attention_heads"]:
raise ValueError("query_pre_attn_scalar must be equal to n_embd / n_head")

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unusem
Expand Down
1 change: 1 addition & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class Attention:
Q_LORA_RANK = "{arch}.attention.q_lora_rank"
KV_LORA_RANK = "{arch}.attention.kv_lora_rank"
REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
SLIDING_WINDOW = "{arch}.attention.sliding_window"

class Rope:
DIMENSION_COUNT = "{arch}.rope.dimension_count"
Expand Down
3 changes: 3 additions & 0 deletions gguf-py/gguf/gguf_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,9 @@ def add_kv_lora_rank(self, length: int) -> None:
def add_relative_attn_buckets_count(self, value: int) -> None:
self.add_uint32(Keys.Attention.REL_BUCKETS_COUNT.format(arch=self.arch), value)

def add_sliding_window(self, value: int) -> None:
self.add_uint32(Keys.Attention.SLIDING_WINDOW.format(arch=self.arch), value)

def add_pooling_type(self, value: PoolingType) -> None:
self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)

Expand Down
99 changes: 68 additions & 31 deletions src/llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,7 @@ enum llm_kv {
LLM_KV_ATTENTION_Q_LORA_RANK,
LLM_KV_ATTENTION_KV_LORA_RANK,
LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
LLM_KV_ATTENTION_SLIDING_WINDOW,

LLM_KV_ROPE_DIMENSION_COUNT,
LLM_KV_ROPE_FREQ_BASE,
Expand Down Expand Up @@ -409,6 +410,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" },
{ LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" },
{ LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
{ LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },

{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
{ LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
Expand Down Expand Up @@ -2085,6 +2087,7 @@ struct llama_hparams {
uint32_t n_head_kv;
uint32_t n_layer;
uint32_t n_rot;
uint32_t n_swa = 0; // sliding window attention (SWA)
uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
uint32_t n_ff;
Expand Down Expand Up @@ -2139,6 +2142,7 @@ struct llama_hparams {
if (this->n_head_kv != other.n_head_kv) return true;
if (this->n_layer != other.n_layer) return true;
if (this->n_rot != other.n_rot) return true;
if (this->n_swa != other.n_swa) return true;
if (this->n_embd_head_k != other.n_embd_head_k) return true;
if (this->n_embd_head_v != other.n_embd_head_v) return true;
if (this->n_ff != other.n_ff) return true;
Expand Down Expand Up @@ -2649,17 +2653,18 @@ struct llama_context {
void * abort_callback_data = nullptr;

// input tensors
struct ggml_tensor * inp_tokens; // I32 [n_batch]
struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
struct ggml_tensor * inp_pos; // I32 [n_batch]
struct ggml_tensor * inp_out_ids; // I32 [n_outputs]
struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch]
struct ggml_tensor * inp_K_shift; // I32 [kv_size]
struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
struct ggml_tensor * inp_cls; // I32 [n_batch]
struct ggml_tensor * inp_s_copy; // I32 [kv_size]
struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch]
struct ggml_tensor * inp_tokens; // I32 [n_batch]
struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
struct ggml_tensor * inp_pos; // I32 [n_batch]
struct ggml_tensor * inp_out_ids; // I32 [n_outputs]
struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch]
struct ggml_tensor * inp_KQ_mask_swa; // F32 [kv_size, n_batch]
struct ggml_tensor * inp_K_shift; // I32 [kv_size]
struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
struct ggml_tensor * inp_cls; // I32 [n_batch]
struct ggml_tensor * inp_s_copy; // I32 [kv_size]
struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch]

// control vectors
struct llama_control_vector cvec;
Expand Down Expand Up @@ -4709,6 +4714,8 @@ static void llm_load_hparams(
} break;
case LLM_ARCH_GEMMA2:
{
hparams.n_swa = 4096; // default value of gemma 2
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
Expand Down Expand Up @@ -5419,6 +5426,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
LLAMA_LOG_INFO("%s: n_head_kv = %u\n", __func__, hparams.n_head_kv);
LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer);
LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot);
LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa);
LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k);
LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v);
LLAMA_LOG_INFO("%s: n_gqa = %u\n", __func__, hparams.n_gqa());
Expand Down Expand Up @@ -7775,17 +7783,18 @@ struct llm_build_context {

ctx0 = ggml_init(params);

lctx.inp_tokens = nullptr;
lctx.inp_embd = nullptr;
lctx.inp_pos = nullptr;
lctx.inp_out_ids = nullptr;
lctx.inp_KQ_mask = nullptr;
lctx.inp_K_shift = nullptr;
lctx.inp_mean = nullptr;
lctx.inp_cls = nullptr;
lctx.inp_s_copy = nullptr;
lctx.inp_s_mask = nullptr;
lctx.inp_s_seq = nullptr;
lctx.inp_tokens = nullptr;
lctx.inp_embd = nullptr;
lctx.inp_pos = nullptr;
lctx.inp_out_ids = nullptr;
lctx.inp_KQ_mask = nullptr;
lctx.inp_KQ_mask_swa = nullptr;
lctx.inp_K_shift = nullptr;
lctx.inp_mean = nullptr;
lctx.inp_cls = nullptr;
lctx.inp_s_copy = nullptr;
lctx.inp_s_mask = nullptr;
lctx.inp_s_seq = nullptr;
}

void free() {
Expand All @@ -7804,7 +7813,6 @@ struct llm_build_context {
cb(lctx.inp_K_shift, "K_shift", -1);
ggml_set_input(lctx.inp_K_shift);


for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * rope_factors = build_rope_factors(il);
struct ggml_tensor * tmp =
Expand Down Expand Up @@ -7939,16 +7947,27 @@ struct llm_build_context {
}

struct ggml_tensor * build_inp_KQ_mask(bool causal = true) {
if (causal) {
lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
} else {
lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
}
lctx.inp_KQ_mask = causal
? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
: ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
cb(lctx.inp_KQ_mask, "KQ_mask", -1);
ggml_set_input(lctx.inp_KQ_mask);

return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask;
}

struct ggml_tensor * build_inp_KQ_mask_swa(bool causal = true) {
GGML_ASSERT(hparams.n_swa > 0);

lctx.inp_KQ_mask_swa = causal
? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
: ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
cb(lctx.inp_KQ_mask_swa, "KQ_mask_swa", -1);
ggml_set_input(lctx.inp_KQ_mask_swa);

return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask_swa, GGML_TYPE_F16) : lctx.inp_KQ_mask_swa;
}

struct ggml_tensor * build_inp_mean() {
lctx.inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
cb(lctx.inp_mean, "inp_mean", -1);
Expand Down Expand Up @@ -11029,9 +11048,14 @@ struct llm_build_context {
struct ggml_tensor * inp_pos = build_inp_pos();

// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
// gemma 2 requires different mask for layers using sliding window (SWA)
struct ggml_tensor * KQ_mask = build_inp_KQ_mask(true);
struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa(true);

for (int il = 0; il < n_layer; ++il) {
// (il % 2) layers use SWA
struct ggml_tensor * KQ_mask_l = (il % 2 == 0) ? KQ_mask_swa : KQ_mask;

// norm
cur = llm_build_norm(ctx0, inpL, hparams,
model.layers[il].attn_norm, NULL,
Expand Down Expand Up @@ -11067,7 +11091,7 @@ struct llm_build_context {

cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
model.layers[il].wo, NULL,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f, cb, il);
}

cur = llm_build_norm(ctx0, cur, hparams,
Expand Down Expand Up @@ -12670,7 +12694,12 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {

GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));

float * data = (float *) lctx.inp_KQ_mask->data;
float * data = (float *) lctx.inp_KQ_mask->data;
float * data_swa = nullptr;

if (lctx.inp_KQ_mask_swa) {
data_swa = (float *) lctx.inp_KQ_mask_swa->data;
}

// For causal attention, use only the previous KV cells
// of the correct sequence for each token of the batch.
Expand All @@ -12692,6 +12721,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}
data[h*(n_kv*n_tokens) + j*n_kv + i] = f;

// may need to cut off old tokens for sliding window
if (data_swa) {
if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
f = -INFINITY;
}
data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
}
}
}

Expand Down

0 comments on commit 49122a8

Please sign in to comment.