diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 3107b69f7e42e..8ce79d14604fd 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1404,6 +1404,48 @@ def write_tensors(self): raise ValueError(f"Unprocessed experts: {experts}") +@Model.register("BitnetForCausalLM") +class BitnetModel(Model): + model_arch = gguf.MODEL_ARCH.BITNET + + def set_vocab(self): + self._set_vocab_sentencepiece() + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(1.0) + + def weight_quant(self, weight): + dtype = weight.dtype + weight = weight.float() + s = 1 / weight.abs().mean().clamp(min=1e-5) + weight = (weight * s).round().clamp(-1, 1) / s + scale = weight.abs().max().unsqueeze(0) + weight = torch.where(weight.abs().less(1e-6), 0, weight).type(dtype) + weight = torch.sign(weight).type(dtype) + return weight.type(dtype), scale.type(torch.float32) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + new_name = self.map_tensor_name(name) + + if any(self.match_model_tensor_name(new_name, key, bid) for key in [ + gguf.MODEL_TENSOR.ATTN_Q, + gguf.MODEL_TENSOR.ATTN_K, + gguf.MODEL_TENSOR.ATTN_V, + gguf.MODEL_TENSOR.ATTN_OUT, + gguf.MODEL_TENSOR.FFN_UP, + gguf.MODEL_TENSOR.FFN_DOWN, + gguf.MODEL_TENSOR.FFN_GATE, + ]): + # transform weight into 1/0/-1 (in fp32) + weight_torch, scale_torch = self.weight_quant(data_torch) + yield (new_name, weight_torch) + yield (new_name.removesuffix(".weight") + ".scale", scale_torch) + else: + yield (new_name, data_torch) + + @Model.register("GrokForCausalLM") class GrokModel(Model): model_arch = gguf.MODEL_ARCH.GROK diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index fb20cfabbcab5..d266fbd43d8d6 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -149,6 +149,7 @@ class MODEL_ARCH(IntEnum): OLMO = auto() ARCTIC = auto() DEEPSEEK2 = auto() + BITNET = auto() class MODEL_TENSOR(IntEnum): @@ -200,6 +201,8 @@ class MODEL_TENSOR(IntEnum): ATTN_KV_B = auto() ATTN_Q_A_NORM = auto() ATTN_KV_A_NORM = auto() + FFN_SUB_NORM = auto() + ATTN_SUB_NORM = auto() MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { @@ -237,6 +240,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.OLMO: "olmo", MODEL_ARCH.ARCTIC: "arctic", MODEL_ARCH.DEEPSEEK2: "deepseek2", + MODEL_ARCH.BITNET: "bitnet", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -288,6 +292,8 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b", MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm", MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm", + MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm", + MODEL_TENSOR.FFN_SUB_NORM: "blk.{bid}.ffn_sub_norm", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -808,6 +814,21 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN_SHEXP, MODEL_TENSOR.FFN_UP_SHEXP, ], + MODEL_ARCH.BITNET: [ + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.ATTN_SUB_NORM, + MODEL_TENSOR.FFN_SUB_NORM, + ], # TODO } diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 81b4992a51eed..350035bd96a17 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -413,6 +413,14 @@ class TensorNameMap: MODEL_TENSOR.ATTN_KV_A_NORM: ( "model.layers.{bid}.self_attn.kv_a_layernorm", # deepseek2 ), + + MODEL_TENSOR.ATTN_SUB_NORM: ( + "model.layers.{bid}.self_attn.inner_attn_ln", # bitnet + ), + + MODEL_TENSOR.FFN_SUB_NORM: ( + "model.layers.{bid}.mlp.ffn_layernorm", # bitnet + ), } # architecture-specific block mappings diff --git a/llama.cpp b/llama.cpp index a05a52b4234cd..c710ef82b746e 100644 --- a/llama.cpp +++ b/llama.cpp @@ -225,6 +225,7 @@ enum llm_arch { LLM_ARCH_OLMO, LLM_ARCH_ARCTIC, LLM_ARCH_DEEPSEEK2, + LLM_ARCH_BITNET, LLM_ARCH_UNKNOWN, }; @@ -263,6 +264,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_OLMO, "olmo" }, { LLM_ARCH_ARCTIC, "arctic" }, { LLM_ARCH_DEEPSEEK2, "deepseek2" }, + { LLM_ARCH_BITNET, "bitnet" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -500,6 +502,8 @@ enum llm_tensor { LLM_TENSOR_ATTN_KV_B, LLM_TENSOR_ATTN_Q_A_NORM, LLM_TENSOR_ATTN_KV_A_NORM, + LLM_TENSOR_ATTN_SUB_NORM, + LLM_TENSOR_FFN_SUB_NORM, }; static const std::map> LLM_TENSOR_NAMES = { @@ -1113,6 +1117,24 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, }, }, + { + LLM_ARCH_BITNET, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_SUB_NORM, "blk.%d.attn_sub_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_SUB_NORM, "blk.%d.ffn_sub_norm" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -2118,6 +2140,8 @@ struct llama_layer { struct ggml_tensor * attn_out_norm_b; struct ggml_tensor * attn_q_a_norm; struct ggml_tensor * attn_kv_a_norm; + struct ggml_tensor * attn_sub_norm; + struct ggml_tensor * ffn_sub_norm; // attention struct ggml_tensor * wq; @@ -2185,6 +2209,15 @@ struct llama_layer { // long rope factors struct ggml_tensor * rope_long = nullptr; struct ggml_tensor * rope_short = nullptr; + + // bitnet scale + struct ggml_tensor * wq_scale; + struct ggml_tensor * wk_scale; + struct ggml_tensor * wv_scale; + struct ggml_tensor * wo_scale; + struct ggml_tensor * ffn_gate_scale; + struct ggml_tensor * ffn_up_scale; + struct ggml_tensor * ffn_down_scale; }; struct llama_kv_cell { @@ -4710,6 +4743,15 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_BITNET: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 26: model.type = e_model::MODEL_3B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; } @@ -6655,6 +6697,44 @@ static bool llm_load_tensors( } } } break; + case LLM_ARCH_BITNET: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}); + + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wq_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}); + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wk_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "scale", i), {1}); + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wv_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "scale", i), {1}); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wo_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}); + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}); + + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_gate_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}); + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_down_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_up_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "scale", i), {1}); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -7295,7 +7375,10 @@ static struct ggml_tensor * llm_build_kqv( ggml_build_forward_expand(graph, cur); - cur = ggml_mul_mat(ctx, wo, cur); + if (wo) { + cur = ggml_mul_mat(ctx, wo, cur); + } + if (wo_b) { cb(cur, "kqv_wo", il); } @@ -11709,6 +11792,153 @@ struct llm_build_context { return gf; } + struct ggml_cgraph * build_bitnet() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + Qcur = ggml_mul(ctx0, Qcur, model.layers[il].wq_scale); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + // B1.K + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + Kcur = ggml_mul(ctx0, Kcur, model.layers[il].wk_scale); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + // B1.V + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + Vcur = ggml_mul(ctx0, Vcur, model.layers[il].wv_scale); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + nullptr, nullptr, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + + cur = llm_build_norm(ctx0, cur, hparams, + model.layers[il].attn_sub_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_sub_norm", il); + + cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur); + cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale); + if (model.layers[il].bo) { + cur = ggml_add(ctx0, cur, model.layers[il].bo); + } + cb(cur, "attn_o_out", il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward forward + if (model.layers[il].ffn_gate_inp == nullptr) { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + struct ggml_tensor *tmp = ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur); + tmp = ggml_mul(ctx0, tmp, model.layers[il].ffn_up_scale); + cb(tmp, "ffn_up", il); + + cur = ggml_mul_mat(ctx0, model.layers[il].ffn_gate, cur); + cur = ggml_mul(ctx0, cur, model.layers[il].ffn_gate_scale); + cb(cur, "ffn_gate", il); + + cur = ggml_silu(ctx0, cur); + cb(cur, "ffn_silu", il); + + cur = ggml_mul(ctx0, cur, tmp); + cb(cur, "ffn_gate_par", il); + + cur = llm_build_norm(ctx0, cur, hparams, + model.layers[il].ffn_sub_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_sub_norm", il); + + cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down, cur); + cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale); + cb(cur, "ffn_down", il); + } + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = ggml_mul_mat(ctx0, model.tok_embd, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + return gf; + } + }; static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector & ids) { @@ -11932,6 +12162,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_deepseek2(); } break; + case LLM_ARCH_BITNET: + { + result = llm.build_bitnet(); + } break; default: GGML_ASSERT(false); } @@ -16760,6 +16994,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_BERT: case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_STABLELM: + case LLM_ARCH_BITNET: case LLM_ARCH_QWEN: case LLM_ARCH_QWEN2: case LLM_ARCH_QWEN2MOE: