From 076b4a197bac94d589e3baec57111e6a8e9cf4ef Mon Sep 17 00:00:00 2001 From: Eddie-Wang1120 Date: Wed, 5 Jun 2024 16:15:28 +0800 Subject: [PATCH 01/32] hf bitnet v1 --- convert-hf-to-gguf.py | 20 ++ ggml.c | 111 +++++++- ggml.h | 7 + gguf-py/gguf/constants.py | 24 ++ gguf-py/gguf/tensor_mapping.py | 8 + llama.cpp | 248 ++++++++++++++++- tokenization_bitnet.py | 482 +++++++++++++++++++++++++++++++++ 7 files changed, 897 insertions(+), 3 deletions(-) create mode 100644 tokenization_bitnet.py diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index ad071b97404f7..c79d6a0129377 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1390,6 +1390,26 @@ def write_tensors(self): if len(experts) > 0: raise ValueError(f"Unprocessed experts: {experts}") +@Model.register("BitnetForCausalLM") +class BitnetModel(Model): + model_arch = gguf.MODEL_ARCH.BITNET + def set_vocab(self): + self._set_vocab_sentencepiece() + + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + self.gguf_writer.add_rope_dimension_count(hparams["hidden_size"] // hparams["num_attention_heads"]) + + if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: + if self.hparams["rope_scaling"].get("type") == "linear": + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) + + # def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + + # return [(self.map_tensor_name(name), data_torch)] @Model.register("GrokForCausalLM") class GrokModel(Model): diff --git a/ggml.c b/ggml.c index 8869e146ab2b8..4c3e6f72371f6 100644 --- a/ggml.c +++ b/ggml.c @@ -2621,6 +2621,22 @@ inline static void ggml_vec_argmax_f32(const int n, int * s, const float * x) { *s = idx; } +inline static void ggml_vec_absmaxclamp_f32(const int n, float * s, const float * x, float min) { + float max = min; + for (int i = 0; i < n; ++i) { + max = MAX(max, fabs(x[i])); + } + *s = max; +} +inline static void ggml_vec_scaleroundclamp_f32(const int n, float * s, const float * x, float scale, float min, float max) { + for (int i = 0; i < n; ++i) { + s[i] = round(x[i] * scale); + if (s[i] > max) s[i] = max; + if (s[i] < min) s[i] = min; + s[i] /= scale; + } +} + // // data types // @@ -2709,9 +2725,11 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS", "CROSS_ENTROPY_LOSS_BACK", + + "BITLINEAR_QUANT" }; -static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74"); +static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -2797,9 +2815,11 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss(x,y)", "cross_entropy_loss_back(x,y)", + + "bitlinear(x)", }; -static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74"); +static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -4830,6 +4850,28 @@ struct ggml_tensor * ggml_mean( return result; } +// ggml_bitlinear_quant for bitnet + +struct ggml_tensor * ggml_bitlinear_quant( + struct ggml_context * ctx, + struct ggml_tensor * a) { + bool is_node = false; + + if (a->grad) { + GGML_ASSERT(false); // TODO: implement + is_node = true; + } + + int64_t ne[GGML_MAX_DIMS] = { a->ne[0], a->ne[1], a->ne[2], a->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, ggml_n_dims(a), ne); + + result->op = GGML_OP_BITLINEAR_QUANT; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + // ggml_argmax struct ggml_tensor * ggml_argmax( @@ -10740,6 +10782,62 @@ static void ggml_compute_forward_mean( } } +static void ggml_compute_forward_bitlinear_quant_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + return; + } + + assert(src0->nb[0] == sizeof(float)); + + GGML_TENSOR_UNARY_OP_LOCALS + + assert(ne0 == ne00); + assert(ne1 == ne01); + assert(ne2 == ne02); + assert(ne3 == ne03); + + UNUSED(ne0); + UNUSED(ne1); + UNUSED(ne2); + UNUSED(ne3); + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + float rowmax = 0.00001; + ggml_vec_absmaxclamp_f32(ne00, &rowmax, (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), 0.00001); + float s = 127 / rowmax; + + ggml_vec_scaleroundclamp_f32(ne00, + (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), + (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), + s, -128, 127); + } + } + } +} + +static void ggml_compute_forward_bitlinear_quant( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_bitlinear_quant_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + // ggml_compute_forward_argmax static void ggml_compute_forward_argmax_f32( @@ -17318,6 +17416,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_mean(params, tensor); } break; + case GGML_OP_BITLINEAR_QUANT: + { + ggml_compute_forward_bitlinear_quant(params, tensor->src[0], tensor); + } break; case GGML_OP_ARGMAX: { ggml_compute_forward_argmax(params, tensor); @@ -18484,6 +18586,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ASSERT(false); // TODO: not implemented } break; + case GGML_OP_BITLINEAR_QUANT: + { + GGML_ASSERT(false); // TODO: not implemented + } break; case GGML_OP_ARGSORT: { GGML_ASSERT(false); // TODO: not implemented @@ -19249,6 +19355,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_ case GGML_OP_GET_REL_POS: case GGML_OP_MAP_UNARY: case GGML_OP_MAP_BINARY: + case GGML_OP_BITLINEAR_QUANT: case GGML_OP_MAP_CUSTOM1_F32: case GGML_OP_MAP_CUSTOM2_F32: case GGML_OP_MAP_CUSTOM3_F32: diff --git a/ggml.h b/ggml.h index f38699698b1e9..98ef961323e2e 100644 --- a/ggml.h +++ b/ggml.h @@ -506,6 +506,8 @@ extern "C" { GGML_OP_CROSS_ENTROPY_LOSS, GGML_OP_CROSS_ENTROPY_LOSS_BACK, + GGML_OP_BITLINEAR_QUANT, + GGML_OP_COUNT, }; @@ -993,6 +995,11 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + // for bitnet + GGML_API struct ggml_tensor * ggml_bitlinear_quant( + struct ggml_context * ctx, + struct ggml_tensor * a); + // argmax along rows GGML_API struct ggml_tensor * ggml_argmax( struct ggml_context * ctx, diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index a3c024c8975f5..429f3818914c9 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -148,6 +148,7 @@ class MODEL_ARCH(IntEnum): OLMO = auto() ARCTIC = auto() DEEPSEEK2 = auto() + BITNET = auto() class MODEL_TENSOR(IntEnum): @@ -199,6 +200,8 @@ class MODEL_TENSOR(IntEnum): ATTN_KV_B = auto() ATTN_Q_A_NORM = auto() ATTN_KV_A_NORM = auto() + FFN_SUB_NORM = auto() + ATTN_SUB_NORM = auto() MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { @@ -236,6 +239,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.OLMO: "olmo", MODEL_ARCH.ARCTIC: "arctic", MODEL_ARCH.DEEPSEEK2: "deepseek2", + MODEL_ARCH.BITNET: "bitnet", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -287,6 +291,8 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b", MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm", MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm", + MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm", + MODEL_TENSOR.FFN_SUB_NORM: "blk.{bid}.ffn_sub_norm", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -806,6 +812,24 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN_SHEXP, MODEL_TENSOR.FFN_UP_SHEXP, ], + MODEL_ARCH.BITNET: [ + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.ATTN_SUB_NORM, + MODEL_TENSOR.FFN_SUB_NORM, + ], # TODO } diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 83e3c4c3381a0..c81ec9d391936 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -410,6 +410,14 @@ class TensorNameMap: MODEL_TENSOR.ATTN_KV_A_NORM: ( "model.layers.{bid}.self_attn.kv_a_layernorm", # deepseek2 ), + + MODEL_TENSOR.ATTN_SUB_NORM: ( + "model.layers.{bid}.self_attn.inner_attn_ln", # bitnet + ), + + MODEL_TENSOR.FFN_SUB_NORM: ( + "model.layers.{bid}.mlp.ffn_layernorm", # bitnet + ), } # architecture-specific block mappings diff --git a/llama.cpp b/llama.cpp index a3e944874cf80..9891ea958684b 100644 --- a/llama.cpp +++ b/llama.cpp @@ -223,6 +223,7 @@ enum llm_arch { LLM_ARCH_OLMO, LLM_ARCH_ARCTIC, LLM_ARCH_DEEPSEEK2, + LLM_ARCH_BITNET, LLM_ARCH_UNKNOWN, }; @@ -261,6 +262,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_OLMO, "olmo" }, { LLM_ARCH_ARCTIC, "arctic" }, { LLM_ARCH_DEEPSEEK2, "deepseek2" }, + { LLM_ARCH_BITNET, "bitnet" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -496,6 +498,8 @@ enum llm_tensor { LLM_TENSOR_ATTN_KV_B, LLM_TENSOR_ATTN_Q_A_NORM, LLM_TENSOR_ATTN_KV_A_NORM, + LLM_TENSOR_ATTN_SUB_NORM, + LLM_TENSOR_FFN_SUB_NORM, }; static const std::map> LLM_TENSOR_NAMES = { @@ -1108,6 +1112,24 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, }, }, + { + LLM_ARCH_BITNET, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_SUB_NORM, "blk.%d.attn_sub_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_SUB_NORM, "blk.%d.ffn_sub_norm" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -1985,6 +2007,8 @@ struct llama_layer { struct ggml_tensor * attn_out_norm_b; struct ggml_tensor * attn_q_a_norm; struct ggml_tensor * attn_kv_a_norm; + struct ggml_tensor * attn_sub_norm; + struct ggml_tensor * ffn_sub_norm; // attention struct ggml_tensor * wq; @@ -4499,6 +4523,15 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_BITNET: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 26: model.type = e_model::MODEL_3B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; } @@ -6409,6 +6442,40 @@ static bool llm_load_tensors( } } } break; + case LLM_ARCH_BITNET: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + } + + const uint32_t n_ff = hparams.n_ff; + model.layers.resize(n_layer); + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}); + + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}); + + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -6761,6 +6828,15 @@ static struct ggml_tensor * llm_build_norm( return cur; } +static struct ggml_tensor * llm_build_qbitlinear( + struct ggml_context * ctx, + struct ggml_tensor * cur) + { + return ggml_bitlinear_quant(ctx, cur); + + return cur; + } + static struct ggml_tensor * llm_build_ffn( struct ggml_context * ctx, struct ggml_tensor * cur, @@ -6963,6 +7039,7 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * wo_b, struct ggml_tensor * q_cur, struct ggml_tensor * kq_mask, + struct ggml_tensor * attn_sub_norm, int32_t n_tokens, int32_t n_kv, float kq_scale, @@ -7057,6 +7134,17 @@ static struct ggml_tensor * llm_build_kqv( cb(cur, "kqv_merged_cont", il); } + if (model.arch == LLM_ARCH_BITNET) + { + cur = llm_build_norm(ctx, cur, hparams, + attn_sub_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_sub_norm", il); + + // B2 for wo + cur = llm_build_qbitlinear(ctx, cur); + } + ggml_build_forward_expand(graph, cur); cur = ggml_mul_mat(ctx, wo, cur); @@ -7102,7 +7190,7 @@ static struct ggml_tensor * llm_build_kv( struct ggml_tensor * cur; cur = llm_build_kqv(ctx, model, hparams, cparams, kv, graph, wo, wo_b, - q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il); + q_cur, kq_mask, nullptr, n_tokens, n_kv, kq_scale, cb, il); cb(cur, "kqv_out", il); return cur; @@ -11448,6 +11536,159 @@ struct llm_build_context { return gf; } + struct ggml_cgraph * build_bitnet() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + // B1.Q + cur = llm_build_qbitlinear(ctx0, cur); + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + // B1.K + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + // B1.V + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il); + cur = llm_build_kqv(ctx0, model, hparams, cparams, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, KQ_mask, model.layers[il].attn_sub_norm, n_tokens, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il + ); + cb(cur, "kqv_out", il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward forward + if (model.layers[il].ffn_gate_inp == nullptr) { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + // cur = llm_build_ffn(ctx0, cur, + // model.layers[il].ffn_up, NULL, + // model.layers[il].ffn_gate, NULL, + // model.layers[il].ffn_down, NULL, + // NULL, + // LLM_FFN_SILU, LLM_FFN_PAR, cb, il, hparams, model.layers[il].ffn_sub_norm, isbitnet); + // cb(cur, "ffn_out", il); + + + cur = llm_build_qbitlinear(ctx0, cur); + + struct ggml_tensor *tmp = ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur); + + cb(tmp, "ffn_up", il); + + cur = ggml_mul_mat(ctx0, model.layers[il].ffn_gate, cur); + + cb(cur, "ffn_gate", il); + + + cur = ggml_silu(ctx0, cur); + cb(cur, "ffn_silu", il); + + cur = ggml_mul(ctx0, cur, tmp); + cb(cur, "ffn_gate_par", il); + + cur = llm_build_norm(ctx0, cur, hparams, + model.layers[il].ffn_sub_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_sub_norm", il); + + // B4 for w2 + cur = llm_build_qbitlinear(ctx0, cur); + + cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down, cur); + cb(cur, "ffn_down", il); + + } + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = ggml_mul_mat(ctx0, model.tok_embd, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + }; static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector & ids) { @@ -11670,6 +11911,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_deepseek2(); } break; + case LLM_ARCH_BITNET: + { + result = llm.build_bitnet(); + } break; default: GGML_ASSERT(false); } @@ -16677,6 +16922,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_BERT: case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_STABLELM: + case LLM_ARCH_BITNET: case LLM_ARCH_QWEN: case LLM_ARCH_QWEN2: case LLM_ARCH_QWEN2MOE: diff --git a/tokenization_bitnet.py b/tokenization_bitnet.py new file mode 100644 index 0000000000000..09b482f72f2cd --- /dev/null +++ b/tokenization_bitnet.py @@ -0,0 +1,482 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tokenization classes for LLaMA.""" +import os +from shutil import copyfile +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple + +import sentencepiece as spm + +from transformers.convert_slow_tokenizer import import_protobuf +from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer +from transformers.utils import logging + + +if TYPE_CHECKING: + from transformers.tokenization_utils_base import TextInput + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"} + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model", + }, + "tokenizer_file": { + "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json", + }, +} +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "hf-internal-testing/llama-tokenizer": 2048, +} +SPIECE_UNDERLINE = "▁" + +B_INST, E_INST = "[INST]", "[/INST]" +B_SYS, E_SYS = "<>\n", "\n<>\n\n" + +# fmt: off +DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \ +answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\ + that your responses are socially unbiased and positive in nature. + +If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \ +correct. If you don't know the answer to a question, please don't share false information.""" +# fmt: on + + +class BitnetTokenizer(PreTrainedTokenizer): + """ + Construct a Bitnet tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is + no padding token in the original model. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The end of sequence token. + pad_token (`str` or `tokenizers.AddedToken`, *optional*): + A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by + attention mechanisms or loss computation. + sp_model_kwargs (`Dict[str, Any]`, `Optional`, *optional*): + Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for + SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, + to set: + + - `enable_sampling`: Enable subword regularization. + - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. + + - `nbest_size = {0,1}`: No sampling is performed. + - `nbest_size > 1`: samples from the nbest_size results. + - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) + using forward-filtering-and-backward-sampling algorithm. + + - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for + BPE-dropout. + + add_bos_token (`bool`, *optional*, defaults to `True`): + Whether or not to add an `bos_token` at the start of sequences. + add_eos_token (`bool`, *optional*, defaults to `False`): + Whether or not to add an `eos_token` at the end of sequences. + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): + Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like + extra spaces. + use_default_system_prompt (`bool`, *optional*, defaults to `False`): + Whether or not the default system prompt for Bitnet should be used. + spaces_between_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to add spaces between special tokens. + legacy (`bool`, *optional*): + Whether or not the `legacy` behavior of the tokenizer should be used. Legacy is before the merge of #24622 + and #25224 which includes fixes to properly handle tokens that appear after special tokens. A simple + example: + + - `legacy=True`: + ```python + >>> from transformers import T5Tokenizer + + >>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", legacy=True) + >>> tokenizer.encode("Hello .") + [8774, 32099, 3, 5, 1] + ``` + - `legacy=False`: + ```python + >>> from transformers import T5Tokenizer + + >>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", legacy=False) + >>> tokenizer.encode("Hello .") # the extra space `[3]` is no longer here + [8774, 32099, 5, 1] + ``` + Checkout the [pull request](https://github.com/huggingface/transformers/pull/24565) for more details. + add_prefix_space (`bool`, *optional*, defaults to `True`): + Whether or not to add an initial space to the input. This allows to treat the leading word just as any + other word. + + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + unk_token="", + bos_token="", + eos_token="", + pad_token=None, + sp_model_kwargs: Optional[Dict[str, Any]] = None, + add_bos_token=True, + add_eos_token=False, + clean_up_tokenization_spaces=False, + use_default_system_prompt=False, + spaces_between_special_tokens=False, + legacy=None, + add_prefix_space=True, + **kwargs, + ): + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token + eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token + unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token + pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token + + if legacy is None: + logger.warning_once( + f"You are using the default legacy behaviour of the {self.__class__}. This is" + " expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you." + " If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it" + " means, and thoroughly read the reason why this was added as explained in" + " https://github.com/huggingface/transformers/pull/24565" + ) + legacy = True + + self.legacy = legacy + self.vocab_file = vocab_file + self.add_bos_token = add_bos_token + self.add_eos_token = add_eos_token + self.use_default_system_prompt = use_default_system_prompt + self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False)) + self.add_prefix_space = add_prefix_space + + super().__init__( + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + add_bos_token=add_bos_token, + add_eos_token=add_eos_token, + sp_model_kwargs=self.sp_model_kwargs, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + use_default_system_prompt=use_default_system_prompt, + spaces_between_special_tokens=spaces_between_special_tokens, + legacy=legacy, + add_prefix_space=add_prefix_space, + **kwargs, + ) + + @property + def unk_token_length(self): + return len(self.sp_model.encode(str(self.unk_token))) + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor + def get_spm_processor(self, from_slow=False): + tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs) + if self.legacy or from_slow: # no dependency on protobuf + tokenizer.Load(self.vocab_file) + return tokenizer + + with open(self.vocab_file, "rb") as f: + sp_model = f.read() + model_pb2 = import_protobuf(f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)") + model = model_pb2.ModelProto.FromString(sp_model) + normalizer_spec = model_pb2.NormalizerSpec() + normalizer_spec.add_dummy_prefix = False + model.normalizer_spec.MergeFrom(normalizer_spec) + sp_model = model.SerializeToString() + tokenizer.LoadFromSerializedProto(sp_model) + return tokenizer + + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + state["sp_model_proto"] = self.sp_model.serialized_model_proto() + return state + + def __setstate__(self, d): + self.__dict__ = d + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.LoadFromSerializedProto(self.sp_model_proto) + + @property + def vocab_size(self): + """Returns vocab size""" + return self.sp_model.get_piece_size() + + def get_vocab(self): + """Returns vocab as a dict""" + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize + def tokenize(self, text: "TextInput", **kwargs) -> List[str]: + """ + Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the + first token is special. + """ + if self.legacy or len(text) == 0: + return super().tokenize(text, **kwargs) + + text = text.replace(SPIECE_UNDERLINE, " ") + if self.add_prefix_space: + text = SPIECE_UNDERLINE + text + + tokens = super().tokenize(text, **kwargs) + + if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens: + tokens = tokens[1:] + return tokens + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize + def _tokenize(self, text, **kwargs): + """ + Returns a tokenized string. + + We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any + SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give + `['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the + `unk_token`. Here is an example with `unk_token = ""` and `unk_token_length = 4`. + `self.tokenizer.sp_model.encode(" Hey", out_type = str)[4:]`. + """ + tokens = self.sp_model.encode(text, out_type=str) + if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")): + return tokens + + # 1. Encode string + prefix ex: " Hey" + tokens = self.sp_model.encode(self.unk_token + text, out_type=str) + # 2. Remove self.unk_token from ['<','unk','>', '▁Hey'] + return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.sp_model.piece_to_id(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + token = self.sp_model.IdToPiece(index) + return token + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + # since we manually add the prefix space, we have to remove it when decoding + if tokens[0].startswith(SPIECE_UNDERLINE) and self.add_prefix_space: + tokens[0] = tokens[0][1:] + + current_sub_tokens = [] + out_string = "" + prev_is_special = False + for i, token in enumerate(tokens): + # make sure that special tokens are not decoded using sentencepiece model + if token in self.all_special_tokens: + if not prev_is_special and i != 0 and self.legacy: + out_string += " " + out_string += self.sp_model.decode(current_sub_tokens) + token + prev_is_special = True + current_sub_tokens = [] + else: + current_sub_tokens.append(token) + prev_is_special = False + out_string += self.sp_model.decode(current_sub_tokens) + return out_string + + def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]: + """ + Save the vocabulary and special tokens file to a directory. + + Args: + save_directory (`str`): + The directory in which to save the vocabulary. + + Returns: + `Tuple(str)`: Paths to the files saved. + """ + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + bos_token_id = [self.bos_token_id] if self.add_bos_token else [] + eos_token_id = [self.eos_token_id] if self.add_eos_token else [] + + output = bos_token_id + token_ids_0 + eos_token_id + + if token_ids_1 is not None: + output = output + bos_token_id + token_ids_1 + eos_token_id + + return output + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + bos_token_id = [1] if self.add_bos_token else [] + eos_token_id = [1] if self.add_eos_token else [] + + if token_ids_1 is None: + return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + return ( + bos_token_id + + ([0] * len(token_ids_0)) + + eos_token_id + + bos_token_id + + ([0] * len(token_ids_1)) + + eos_token_id + ) + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT + sequence pair mask has the following format: + + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + + if token_ids_1 is None, only returns the first portion of the mask (0s). + + Args: + token_ids_0 (`List[int]`): + List of ids. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). + """ + bos_token_id = [self.bos_token_id] if self.add_bos_token else [] + eos_token_id = [self.eos_token_id] if self.add_eos_token else [] + + output = [0] * len(bos_token_id + token_ids_0 + eos_token_id) + + if token_ids_1 is not None: + output += [1] * len(bos_token_id + token_ids_1 + eos_token_id) + + return output + + @property + def default_chat_template(self): + """ + LLaMA uses [INST] and [/INST] to indicate user messages, and <> and <> to indicate system messages. + Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict + user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering + rather than needing special tokens. The system message is partly 'embedded' in the first user message, which + results in an unusual token ordering when it is present. This template should definitely be changed if you wish + to fine-tune a model with more flexible role ordering! + + The output should look something like: + + [INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer [INST] Prompt [/INST] Answer + [INST] Prompt [/INST] + + The reference for this chat template is [this code + snippet](https://github.com/facebookresearch/llama/blob/556949fdfb72da27c2f4a40b7f0e4cf0b8153a28/llama/generation.py#L320-L362) + in the original repository. + """ + logger.warning_once( + "\nNo chat template is defined for this tokenizer - using the default template " + f"for the {self.__class__.__name__} class. If the default is not appropriate for " + "your model, please set `tokenizer.chat_template` to an appropriate template. " + "See https://huggingface.co/docs/transformers/main/chat_templating for more information.\n" + ) + template = ( + "{% if messages[0]['role'] == 'system' %}" + "{% set loop_messages = messages[1:] %}" # Extract system message if it's present + "{% set system_message = messages[0]['content'] %}" + "{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}" + "{% set loop_messages = messages %}" # Or use the default system message if the flag is set + "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}" + "{% else %}" + "{% set loop_messages = messages %}" + "{% set system_message = false %}" + "{% endif %}" + "{% for message in loop_messages %}" # Loop over all non-system messages + "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" + "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" + "{% endif %}" + "{% if loop.index0 == 0 and system_message != false %}" # Embed system message in first message + "{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}" + "{% else %}" + "{% set content = message['content'] %}" + "{% endif %}" + "{% if message['role'] == 'user' %}" # After all of that, handle messages/roles in a fairly normal way + "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}" + "{% elif message['role'] == 'system' %}" + "{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}" + "{% elif message['role'] == 'assistant' %}" + "{{ ' ' + content.strip() + ' ' + eos_token }}" + "{% endif %}" + "{% endfor %}" + ) + template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false") + default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'") + template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message) + + return template \ No newline at end of file From 57dfc3bcdf953862e200fd12a1a46d834009dc22 Mon Sep 17 00:00:00 2001 From: Eddie-Wang Date: Wed, 5 Jun 2024 16:01:05 +0000 Subject: [PATCH 02/32] hf bitnet e2e v2 --- convert-hf-to-gguf.py | 17 ++++++++++++++--- llama.cpp | 2 -- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index c79d6a0129377..8ba7119fc675e 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1407,9 +1407,20 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) - # def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - - # return [(self.map_tensor_name(name), data_torch)] + def weight_quant(self, weight): + dtype = weight.dtype + weight = weight.float() + s = 1 / weight.abs().mean().clamp(min=1e-5) + result = (weight * s).round().clamp(-1, 1) / s + return result.type(dtype) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if name.endswith(("q_proj.weight", "k_proj.weight", "v_proj.weight", + "down_proj.weight", "up_proj.weight", "gate_proj.weight", + "o_proj.weight")): + data_torch = data_torch + (self.weight_quant(data_torch) - data_torch).detach() + + return [(self.map_tensor_name(name), data_torch)] @Model.register("GrokForCausalLM") class GrokModel(Model): diff --git a/llama.cpp b/llama.cpp index 9891ea958684b..5dc25e3de49e5 100644 --- a/llama.cpp +++ b/llama.cpp @@ -6833,8 +6833,6 @@ static struct ggml_tensor * llm_build_qbitlinear( struct ggml_tensor * cur) { return ggml_bitlinear_quant(ctx, cur); - - return cur; } static struct ggml_tensor * llm_build_ffn( From 1f2e0ee01226bf3fd1717f5b8e1c9b688c216f83 Mon Sep 17 00:00:00 2001 From: Eddie-Wang1120 Date: Thu, 6 Jun 2024 12:28:11 +0800 Subject: [PATCH 03/32] finish bitnet e2e --- convert-hf-to-gguf.py | 21 ++++++++++++++------- llama.cpp | 12 ++++++------ 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 8ba7119fc675e..d39bb3bd1d96a 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1398,14 +1398,21 @@ def set_vocab(self): def set_gguf_parameters(self): super().set_gguf_parameters() - hparams = self.hparams - self.gguf_writer.add_vocab_size(hparams["vocab_size"]) - self.gguf_writer.add_rope_dimension_count(hparams["hidden_size"] // hparams["num_attention_heads"]) + self.gguf_writer.add_name("Bitnet") + self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) + self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) + self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"]) + self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) + self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) + self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"]) + self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"]) - if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: - if self.hparams["rope_scaling"].get("type") == "linear": - self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) - self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) + self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) + + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(1.0) + self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"]) def weight_quant(self, weight): dtype = weight.dtype diff --git a/llama.cpp b/llama.cpp index 5dc25e3de49e5..38dbf31e0ed5d 100644 --- a/llama.cpp +++ b/llama.cpp @@ -11587,16 +11587,16 @@ struct llm_build_context { cb(Vcur, "Vcur", il); } - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, - n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, + n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, - n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, + n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); From 5e59660173d050628647f9a8c6f96830f57ce052 Mon Sep 17 00:00:00 2001 From: Eddie-Wang1120 Date: Fri, 7 Jun 2024 14:42:52 +0800 Subject: [PATCH 04/32] finish f16 hf bitnet e2e --- convert-hf-to-gguf.py | 124 +++++++++++++++++++++- ggml-common.h | 67 ++++++++++++ ggml-quants.c | 22 ++++ ggml-quants.h | 1 + ggml.c | 202 +++++++++++++++++++++++++++++++++++- ggml.h | 1 + gguf-py/gguf/constants.py | 3 + gguf-py/gguf/gguf_writer.py | 13 ++- llama.cpp | 17 +-- llama.h | 1 + 10 files changed, 440 insertions(+), 11 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index d39bb3bd1d96a..42d67aca4547c 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1421,7 +1421,31 @@ def weight_quant(self, weight): result = (weight * s).round().clamp(-1, 1) / s return result.type(dtype) + def transform_to_i2(self, x): + from gguf.lazy import LazyNumpyTensor + x = LazyNumpyTensor.to_eager(x) + x_num = np.prod(x.shape) + x = np.reshape(x, x_num) + scale = 1 + for i in range(x_num): + if x[i] != 0: + scale = x[i] + break + x = np.divide(x, scale) + x = x.astype(np.uint8) + x = np.reshape(x, [x.shape[0] // 4, 4]) + keep_bit = {0:192, 1:48, 2:12, 3:3} + ans = np.zeros([x_num // 4], dtype=np.uint8) + for i in range(4): + x_bit_col = x[:, i] + x_bit_shift = np.left_shift(x_bit_col, 6 - i * 2) + x_bit_shift = np.bitwise_and(x_bit_shift, keep_bit[i]) + ans = np.bitwise_or(ans, x_bit_shift) + scale = np.tile(scale, 8) + return ans, scale + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # quant weight to i2 (in fp16) if name.endswith(("q_proj.weight", "k_proj.weight", "v_proj.weight", "down_proj.weight", "up_proj.weight", "gate_proj.weight", "o_proj.weight")): @@ -1429,6 +1453,103 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [(self.map_tensor_name(name), data_torch)] + def write_tensors(self): + max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,") + + for name, data_torch in self.get_tensors(): + # we don't need these + if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")): + continue + + old_dtype = data_torch.dtype + + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + + # use the first number-like part of the tensor name as the block id + bid = None + for part in name.split("."): + if part.isdecimal(): + bid = int(part) + break + + for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)): + data: np.ndarray = data # type hint + data_shape = data.shape + n_dims = len(data.shape) + data_dtype = data.dtype + data_qtype: gguf.GGMLQuantizationType | None = None + + # when both are True, f32 should win + extra_f32 = self.extra_f32_tensors(name, new_name, bid, n_dims) + extra_f16 = self.extra_f16_tensors(name, new_name, bid, n_dims) + + # Most of the codebase that takes in 1D tensors or norms only handles F32 tensors + # Conditions should closely match those in llama_model_quantize_internal in llama.cpp + extra_f32 = any(cond for cond in ( + extra_f32, + n_dims == 1, + new_name.endswith("_norm.weight"), + )) + + # Some tensor types are always in float32 + extra_f32 = extra_f32 or any(self.match_model_tensor_name(new_name, key, bid) for key in ( + gguf.MODEL_TENSOR.FFN_GATE_INP, + gguf.MODEL_TENSOR.POS_EMBD, + gguf.MODEL_TENSOR.TOKEN_TYPES, + )) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + extra_f16 = any(cond for cond in ( + extra_f16, + (name.endswith(".weight") and n_dims >= 2), + )) + + suit_i2 = True + if name.endswith('embed_tokens.weight') or name.endswith('norm.weight'): + suit_i2 = False + + i2_scale = None + if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32: + if self.ftype == gguf.LlamaFileType.MOSTLY_I2 and suit_i2: + data, i2_scale = self.transform_to_i2(data) + assert data.dtype == np.uint8 + assert i2_scale.dtype == np.float32 + data_qtype = gguf.GGMLQuantizationType.I2 + + elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16: + data = gguf.quantize_bf16(data) + assert data.dtype == np.int16 + data_qtype = gguf.GGMLQuantizationType.BF16 + + elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0 and gguf.can_quantize_to_q8_0(data): + data = gguf.quantize_q8_0(data) + assert data.dtype == np.uint8 + data_qtype = gguf.GGMLQuantizationType.Q8_0 + + else: # default to float16 for quantized tensors + if data_dtype != np.float16: + data = data.astype(np.float16) + data_qtype = gguf.GGMLQuantizationType.F16 + + if data_qtype is None: # by default, convert to float32 + if data_dtype != np.float32: + data = data.astype(np.float32) + data_qtype = gguf.GGMLQuantizationType.F32 + + shape = data_shape + # shape = gguf.quant_shape_from_byte_shape(data.shape, data_qtype) if data.dtype == np.uint8 else data.shape + # reverse shape to make it similar to the internal ggml dimension order + shape_str = f"{{{', '.join(str(n) for n in reversed(shape))}}}" + + # n_dims is implicit in the shape + logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}") + + self.gguf_writer.add_tensor(new_name, data, raw_shape=shape, raw_dtype=data_qtype) + if i2_scale is not None: + self.gguf_writer.add_tensor(new_name + "_scale", i2_scale, raw_dtype=gguf.GGMLQuantizationType.F32) + @Model.register("GrokForCausalLM") class GrokModel(Model): model_arch = gguf.MODEL_ARCH.GROK @@ -2804,7 +2925,7 @@ def parse_args() -> argparse.Namespace: help="path to write to; default: based on input. {ftype} will be replaced by the outtype.", ) parser.add_argument( - "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16", + "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "i2", "auto"], default="f16", help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type", ) parser.add_argument( @@ -2864,6 +2985,7 @@ def main() -> None: "f16": gguf.LlamaFileType.MOSTLY_F16, "bf16": gguf.LlamaFileType.MOSTLY_BF16, "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0, + "i2" : gguf.LlamaFileType.MOSTLY_I2, "auto": gguf.LlamaFileType.GUESSED, } diff --git a/ggml-common.h b/ggml-common.h index 77e6bfba4b11b..a3d4f7a56a03b 100644 --- a/ggml-common.h +++ b/ggml-common.h @@ -1016,6 +1016,73 @@ GGML_TABLE_BEGIN(uint32_t, iq3s_grid, 512) 0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101, GGML_TABLE_END() +GGML_TABLE_BEGIN(uint32_t, i2_q8, 256) +0x00000000, 0x01000000, 0x00000000, 0xff000000, +0x00010000, 0x01010000, 0x00010000, 0xff010000, +0x00000000, 0x01000000, 0x00000000, 0xff000000, +0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000, +0x00000100, 0x01000100, 0x00000100, 0xff000100, +0x00010100, 0x01010100, 0x00010100, 0xff010100, +0x00000100, 0x01000100, 0x00000100, 0xff000100, +0x00ff0100, 0x01ff0100, 0x00ff0100, 0xffff0100, +0x00000000, 0x01000000, 0x00000000, 0xff000000, +0x00010000, 0x01010000, 0x00010000, 0xff010000, +0x00000000, 0x01000000, 0x00000000, 0xff000000, +0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000, +0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00, +0x0001ff00, 0x0101ff00, 0x0001ff00, 0xff01ff00, +0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00, +0x00ffff00, 0x01ffff00, 0x00ffff00, 0xffffff00, +0x00000001, 0x01000001, 0x00000001, 0xff000001, +0x00010001, 0x01010001, 0x00010001, 0xff010001, +0x00000001, 0x01000001, 0x00000001, 0xff000001, +0x00ff0001, 0x01ff0001, 0x00ff0001, 0xffff0001, +0x00000101, 0x01000101, 0x00000101, 0xff000101, +0x00010101, 0x01010101, 0x00010101, 0xff010101, +0x00000101, 0x01000101, 0x00000101, 0xff000101, +0x00ff0101, 0x01ff0101, 0x00ff0101, 0xffff0101, +0x00000001, 0x01000001, 0x00000001, 0xff000001, +0x00010001, 0x01010001, 0x00010001, 0xff010001, +0x00000001, 0x01000001, 0x00000001, 0xff000001, +0x00ff0001, 0x01ff0001, 0x00ff0001, 0xffff0001, +0x0000ff01, 0x0100ff01, 0x0000ff01, 0xff00ff01, +0x0001ff01, 0x0101ff01, 0x0001ff01, 0xff01ff01, +0x0000ff01, 0x0100ff01, 0x0000ff01, 0xff00ff01, +0x00ffff01, 0x01ffff01, 0x00ffff01, 0xffffff01, +0x00000000, 0x01000000, 0x00000000, 0xff000000, +0x00010000, 0x01010000, 0x00010000, 0xff010000, +0x00000000, 0x01000000, 0x00000000, 0xff000000, +0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000, +0x00000100, 0x01000100, 0x00000100, 0xff000100, +0x00010100, 0x01010100, 0x00010100, 0xff010100, +0x00000100, 0x01000100, 0x00000100, 0xff000100, +0x00ff0100, 0x01ff0100, 0x00ff0100, 0xffff0100, +0x00000000, 0x01000000, 0x00000000, 0xff000000, +0x00010000, 0x01010000, 0x00010000, 0xff010000, +0x00000000, 0x01000000, 0x00000000, 0xff000000, +0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000, +0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00, +0x0001ff00, 0x0101ff00, 0x0001ff00, 0xff01ff00, +0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00, +0x00ffff00, 0x01ffff00, 0x00ffff00, 0xffffff00, +0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff, +0x000100ff, 0x010100ff, 0x000100ff, 0xff0100ff, +0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff, +0x00ff00ff, 0x01ff00ff, 0x00ff00ff, 0xffff00ff, +0x000001ff, 0x010001ff, 0x000001ff, 0xff0001ff, +0x000101ff, 0x010101ff, 0x000101ff, 0xff0101ff, +0x000001ff, 0x010001ff, 0x000001ff, 0xff0001ff, +0x00ff01ff, 0x01ff01ff, 0x00ff01ff, 0xffff01ff, +0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff, +0x000100ff, 0x010100ff, 0x000100ff, 0xff0100ff, +0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff, +0x00ff00ff, 0x01ff00ff, 0x00ff00ff, 0xffff00ff, +0x0000ffff, 0x0100ffff, 0x0000ffff, 0xff00ffff, +0x0001ffff, 0x0101ffff, 0x0001ffff, 0xff01ffff, +0x0000ffff, 0x0100ffff, 0x0000ffff, 0xff00ffff, +0x00ffffff, 0x01ffffff, 0x00ffffff, 0xffffffff, +GGML_TABLE_END() + #define NGRID_IQ1S 2048 #define IQ1S_DELTA 0.125f #define IQ1M_DELTA 0.125f diff --git a/ggml-quants.c b/ggml-quants.c index 9f864e5c479ea..a132113076aab 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -3726,6 +3726,28 @@ static inline __m128i get_scale_shuffle(int i) { } #endif +//====================================== I2 =============================================== + +void ggml_vec_dot_i2_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + const int qk = QK8_0; + const int nb = n / qk; + + const uint8_t * restrict x = vx; + const int8_t * restrict y = vy; + + int sumi = 0; + + for (int i = 0; i < n / 4; i++) { + int8_t* weight = (const int8_t *)(i2_q8 + x[i]); + sumi += (int)y[i*4+0] * weight[0]; + sumi += (int)y[i*4+1] * weight[1]; + sumi += (int)y[i*4+2] * weight[2]; + sumi += (int)y[i*4+3] * weight[3]; + } + *s = (float)(sumi); + +} + void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; diff --git a/ggml-quants.h b/ggml-quants.h index 4d436a8f06b3e..1c8e3839d7166 100644 --- a/ggml-quants.h +++ b/ggml-quants.h @@ -99,6 +99,7 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_i2_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); // Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization") size_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); diff --git a/ggml.c b/ggml.c index 4c3e6f72371f6..335cb1cddbf30 100644 --- a/ggml.c +++ b/ggml.c @@ -569,6 +569,15 @@ static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc); static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { + [GGML_TYPE_I2] = { + .type_name = "i2", + .blck_size = 1, + .type_size = sizeof(int8_t), + .is_quantized = false, + .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_i2_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, + .nrows = 1, + }, [GGML_TYPE_I8] = { .type_name = "i8", .blck_size = 1, @@ -1805,6 +1814,7 @@ inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; } inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; } inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; } +inline static void ggml_vec_mul_f32_bitnet (const int n, float * y, const float x) { for (int i = 0; i < n; ++i) y[i] = y[i] * x; } static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) { assert(nrc == 1); @@ -2636,6 +2646,16 @@ inline static void ggml_vec_scaleroundclamp_f32(const int n, float * s, const fl s[i] /= scale; } } +inline static void ggml_vec_scaleroundclamp_f32_v2(const int n, float * s, int8_t* inp, float scale, float min, float max) { + + for (int i = 0; i < n; ++i) { + s[i] = round(s[i] * scale); + if (s[i] > max) s[i] = max; + if (s[i] < min) s[i] = min; + inp[i] = (int8_t)(s[i]); + } + +} // // data types @@ -3081,6 +3101,10 @@ GGML_CALL size_t ggml_nbytes(const struct ggml_tensor * tensor) { for (int i = 0; i < GGML_MAX_DIMS; ++i) { nbytes += (tensor->ne[i] - 1)*tensor->nb[i]; } + if(tensor->type == 31){ + nbytes = nbytes / 4 + 32; + } + } else { nbytes = tensor->ne[0]*tensor->nb[0]/blck_size; @@ -12411,7 +12435,10 @@ static void ggml_compute_forward_mul_mat_one_chunk( } const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; - const size_t row_size = ggml_row_size(vec_dot_type, ne10); + size_t row_size = ggml_row_size(vec_dot_type, ne10); + if (src0->type == 31) { + row_size = ne10; + } assert(ne12 % ne02 == 0); assert(ne13 % ne03 == 0); @@ -12425,6 +12452,9 @@ static void ggml_compute_forward_mul_mat_one_chunk( // attempt to reduce false-sharing (does not seem to make a difference) // 16 * 2, accounting for mmla kernels float tmp[32]; + uint8_t *i_weight = (uint8_t*) (src0->data); + float * scale = (float * )((i_weight) + (ne00 * ne01 / 4)); + float* act_scales = (float*) ((char *) wdata + ((ne11*nb11) / 4)); for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) { for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) { @@ -12458,9 +12488,15 @@ static void ggml_compute_forward_mul_mat_one_chunk( //} for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) { + if (src0->type == 31) { + vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01 / 4, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot); + tmp[ir0 - iir0] = tmp[ir0 - iir0] * (*scale) * (act_scales[i11]); + }else { vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot); } + } + for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) { memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float)); } @@ -12469,6 +12505,164 @@ static void ggml_compute_forward_mul_mat_one_chunk( } } + +static void ggml_compute_forward_bitnet_mul_mat( + const struct ggml_compute_params * params, + struct ggml_tensor * dst, + struct ggml_compute_state * state) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + GGML_TENSOR_BINARY_OP_LOCALS + + const int ith = params->ith; + const int nth = params->nth; + + const enum ggml_type type = src0->type; + const bool src1_cont = ggml_is_contiguous(src1); + + GGML_ASSERT(ne0 == ne01); + GGML_ASSERT(ne1 == ne11); + GGML_ASSERT(ne2 == ne12); + GGML_ASSERT(ne3 == ne13); + + // we don't support permuted src0 or src1 + GGML_ASSERT(nb00 == ggml_type_size(type)); + GGML_ASSERT(nb10 == ggml_type_size(src1->type)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + // broadcast factors + const int64_t r2 = ne12 / ne02; + const int64_t r3 = ne13 / ne03; + UNUSED(r2); + UNUSED(r3); + + // nb01 >= nb00 - src0 is not transposed + // compute by src0 rows + if (params->type == GGML_TASK_TYPE_INIT) { + if (ith != 0) { + return; + } + atomic_store(&state->shared->current_chunk, nth); + char * wdata = params->wdata; + float* act_scales = (float*) ((char *) wdata + ((ne11*nb11) / 4)); + for (int64_t i13 = 0; i13 < ne13; i13++) { + for (int64_t i12 = 0; i12 < ne12; i12++) { + for (int64_t i11 = 0; i11 < ne11; i11++) { + float rowmax = 0.00001; + ggml_vec_absmaxclamp_f32(ne10, &rowmax, (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13), 0.00001); + float s = 127 / rowmax; + act_scales[i11] = 1/s; + ggml_vec_scaleroundclamp_f32_v2(ne10, + (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13), + (int8_t*) ((char *) wdata + ((i11*nb11 + i12*nb12 + i13*nb13) / 4)), + s, -128, 127); + } + } + } + // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start. + // atomic_store(&state->shared->current_chunk, nth); + // // char * wdata = params->wdata; + // const size_t row_size = ggml_row_size(GGML_TYPE_Q8_0, ne10); + // // printf("vec_dot_type:%d\n", vec_dot_type); + // // printf("row_size:%ld\n", row_size); + // assert(params->wsize >= ne11*ne12*ne13*row_size); + // GGML_ASSERT(src1->type == GGML_TYPE_F32); + + // for (int64_t i13 = 0; i13 < ne13; ++i13) { + // for (int64_t i12 = 0; i12 < ne12; ++i12) { + // for (int64_t i11 = 0; i11 < ne11; ++i11) { + // quantize_row_q8_0((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10); + // wdata += row_size; + // } + // } + // } + + + return; + } + + if (params->type == GGML_TASK_TYPE_FINALIZE) { + return; + } + + // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers) + const int64_t nr0 = ne0; + + // This is the size of the rest of the dimensions of the result + const int64_t nr1 = ne1 * ne2 * ne3; + + // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols + int64_t num_rows_per_vec_dot = 1; + // TODO: currently the mmla kernels support only even numbered rows/cols. + // this check can be removed once they are extended to support odd numbered rows/cols too + if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) { + num_rows_per_vec_dot = 1; + } + + // Now select a reasonable chunk size. + int chunk_size = 16; + + // We need to step up the size if it's small + if (nr0 == 1 || nr1 == 1) { + chunk_size = 64; + } + + // distribute the work across the inner or outer loop based on which one is larger + // The number of chunks in the 0/1 dim. + // CEIL(nr0/chunk_size) + int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size; + int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size; + + // If the chunking is poor for the number of threads on this setup, scrap the whole plan. Re-chunk it by thread. + // Also, chunking by thread was measured to have perform better on NUMA systems. See https://github.com/ggerganov/llama.cpp/pull/6915 + // In theory, chunking should be just as useful on NUMA and non NUMA systems, but testing disagreed with that. + if (nchunk0 * nchunk1 < nth * 4 || ggml_is_numa()) { + // distribute the thread work across the inner or outer loop based on which one is larger + nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows + nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows + } + + // The number of elements in each chunk + const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0; + const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1; + + //if (ith == 0) + // printf("MUL_MAT = [%d, %d, %d, %d] x [%d, %d, %d, %d] = %d x %d = %d. Fp Ops/Ch %d\n", ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nchunk0, nchunk1, nchunk0 * nchunk1, ne00 * nr0 * nr1 / nchunk0 / nchunk1); + + // The first chunk comes from our thread_id, the rest will get auto-assigned. + int current_chunk = ith; + + while (current_chunk < nchunk0 * nchunk1) { + const int64_t ith0 = current_chunk % nchunk0; + const int64_t ith1 = current_chunk / nchunk0; + + const int64_t ir0_start = dr0 * ith0; + const int64_t ir0_end = MIN(ir0_start + dr0, nr0); + + const int64_t ir1_start = dr1 * ith1; + const int64_t ir1_end = MIN(ir1_start + dr1, nr1); + + ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end); + + if (nth >= nchunk0 * nchunk1) { + break; + } + + current_chunk = atomic_fetch_add(&state->shared->current_chunk, 1); + } + +} + static void ggml_compute_forward_mul_mat( const struct ggml_compute_params * params, struct ggml_tensor * dst, @@ -12482,6 +12676,11 @@ static void ggml_compute_forward_mul_mat( GGML_TENSOR_BINARY_OP_LOCALS + if (src0->type == 31) { + ggml_compute_forward_bitnet_mul_mat(params, dst, state); + return; + } + const int ith = params->ith; const int nth = params->nth; @@ -14349,6 +14548,7 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_I32: case GGML_TYPE_I64: case GGML_TYPE_F64: + case GGML_TYPE_I2: case GGML_TYPE_COUNT: { GGML_ASSERT(false); diff --git a/ggml.h b/ggml.h index 98ef961323e2e..5d540aa305c38 100644 --- a/ggml.h +++ b/ggml.h @@ -377,6 +377,7 @@ extern "C" { GGML_TYPE_F64 = 28, GGML_TYPE_IQ1_M = 29, GGML_TYPE_BF16 = 30, + GGML_TYPE_I2 = 31, GGML_TYPE_COUNT, }; diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 429f3818914c9..5e94edb22fa2b 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -925,6 +925,7 @@ class GGMLQuantizationType(IntEnum): F64 = 28 IQ1_M = 29 BF16 = 30 + I2 = 31 # TODO: add GGMLFileType from ggml_ftype in ggml.h @@ -966,6 +967,7 @@ class LlamaFileType(IntEnum): MOSTLY_IQ4_XS = 30 # except 1d tensors MOSTLY_IQ1_M = 31 # except 1d tensors MOSTLY_BF16 = 32 # except 1d tensors + MOSTLY_I2 = 33 # except 1d tensors GUESSED = 1024 # not specified in the model file @@ -1032,6 +1034,7 @@ def get_type(val: Any) -> GGUFValueType: GGMLQuantizationType.IQ3_S: (256, 2 + QK_K // 4 + QK_K // 8 + QK_K // 32 + 4), GGMLQuantizationType.IQ2_S: (256, 2 + QK_K // 4 + QK_K // 16), GGMLQuantizationType.IQ4_XS: (256, 2 + 2 + QK_K // 2 + QK_K // 64), + GGMLQuantizationType.I2: (1, 1), GGMLQuantizationType.I8: (1, 1), GGMLQuantizationType.I16: (1, 2), GGMLQuantizationType.I32: (1, 4), diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index b93747aff58b3..2d19cd44c2412 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -225,8 +225,10 @@ def add_tensor_info( dtype = GGMLQuantizationType.I32 elif tensor_dtype == np.int64: dtype = GGMLQuantizationType.I64 + elif tensor_dtype == np.uint8: + dtype = GGMLQuantizationType.I2 else: - raise ValueError("Only F16, F32, F64, I8, I16, I32, I64 tensors are supported for now") + raise ValueError("Only F16, F32, F64, I8, I16, I32, I64, I2 tensors are supported for now") else: dtype = raw_dtype if tensor_dtype == np.uint8: @@ -237,7 +239,10 @@ def add_tensor_info( self.ti_data += self._pack("Q", tensor_shape[n_dims - 1 - i]) self.ti_data += self._pack("I", dtype) self.ti_data += self._pack("Q", self.offset_tensor) - self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment) + if dtype == GGMLQuantizationType.I2: + self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment) + self.data_alignment + else: + self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment) self.ti_data_count += 1 def add_tensor( @@ -252,7 +257,9 @@ def add_tensor( self.temp_file = fp shape: Sequence[int] = raw_shape if raw_shape is not None else tensor.shape - self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype = raw_dtype) + + if (raw_dtype != GGMLQuantizationType.F32 or not name.endswith("scale")): + self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype = raw_dtype) if self.temp_file is None: self.tensors.append(tensor) diff --git a/llama.cpp b/llama.cpp index 38dbf31e0ed5d..53f0473333d24 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3192,8 +3192,9 @@ struct llama_model_loader { llama_tensor_weight(const llama_file * file, uint16_t idx, const char * name, const struct gguf_context * gguf_ctx, ggml_tensor * tensor) : idx(idx), tensor(tensor) { const int tensor_idx = gguf_find_tensor(gguf_ctx, name); + printf("name:%s\n", name); offs = gguf_get_data_offset(gguf_ctx) + gguf_get_tensor_offset(gguf_ctx, tensor_idx); - + printf("offs:%ld\n", offs + ggml_nbytes(tensor)); if (offs + ggml_nbytes(tensor) < offs || offs + ggml_nbytes(tensor) > file->size) { throw std::runtime_error(format("tensor '%s' data is not within the file bounds, model is corrupted or incomplete", name)); } @@ -7140,7 +7141,7 @@ static struct ggml_tensor * llm_build_kqv( cb(cur, "attn_sub_norm", il); // B2 for wo - cur = llm_build_qbitlinear(ctx, cur); + // cur = llm_build_qbitlinear(ctx, cur); } ggml_build_forward_expand(graph, cur); @@ -11563,7 +11564,7 @@ struct llm_build_context { { // compute Q and K and RoPE them // B1.Q - cur = llm_build_qbitlinear(ctx0, cur); + // cur = llm_build_qbitlinear(ctx0, cur); struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); if (model.layers[il].bq) { @@ -11635,7 +11636,7 @@ struct llm_build_context { // cb(cur, "ffn_out", il); - cur = llm_build_qbitlinear(ctx0, cur); + // cur = llm_build_qbitlinear(ctx0, cur); struct ggml_tensor *tmp = ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur); @@ -11658,7 +11659,7 @@ struct llm_build_context { cb(cur, "ffn_sub_norm", il); // B4 for w2 - cur = llm_build_qbitlinear(ctx0, cur); + // cur = llm_build_qbitlinear(ctx0, cur); cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down, cur); cb(cur, "ffn_down", il); @@ -15684,6 +15685,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break; case LLAMA_FTYPE_MOSTLY_IQ3_S: default_type = GGML_TYPE_IQ3_S; break; case LLAMA_FTYPE_MOSTLY_IQ3_M: default_type = GGML_TYPE_IQ3_S; break; + case LLAMA_FTYPE_MOSTLY_I2 : default_type = GGML_TYPE_I2; break; default: throw std::runtime_error(format("invalid output file type %d\n", ftype)); } @@ -15921,7 +15923,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s if (params->output_tensor_type < GGML_TYPE_COUNT && strcmp(tensor->name, "output.weight") == 0) { new_type = params->output_tensor_type; } - + if (tensor->type == 31) { + // no need quantize for i2 + new_type = tensor->type; + } // If we've decided to quantize to the same type the tensor is already // in then there's nothing to do. quantize = tensor->type != new_type; diff --git a/llama.h b/llama.h index a78ccdaf557d0..1a225fa618a32 100644 --- a/llama.h +++ b/llama.h @@ -156,6 +156,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors + LLAMA_FTYPE_MOSTLY_I2 = 33, LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file }; From 2a01a7ce0d12c51f6672b73013db6442f9cf9577 Mon Sep 17 00:00:00 2001 From: Eddie-Wang1120 Date: Fri, 7 Jun 2024 18:29:59 +0800 Subject: [PATCH 05/32] remove unsed --- ggml-quants.c | 3 --- ggml.c | 58 +++++++++++++++++++++++++++++++++------------------ llama.cpp | 2 -- 3 files changed, 38 insertions(+), 25 deletions(-) diff --git a/ggml-quants.c b/ggml-quants.c index a132113076aab..8c3daf3328293 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -3729,9 +3729,6 @@ static inline __m128i get_scale_shuffle(int i) { //====================================== I2 =============================================== void ggml_vec_dot_i2_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - const int qk = QK8_0; - const int nb = n / qk; - const uint8_t * restrict x = vx; const int8_t * restrict y = vy; diff --git a/ggml.c b/ggml.c index 335cb1cddbf30..fcc5ed09b8448 100644 --- a/ggml.c +++ b/ggml.c @@ -1814,7 +1814,6 @@ inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; } inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; } inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; } -inline static void ggml_vec_mul_f32_bitnet (const int n, float * y, const float x) { for (int i = 0; i < n; ++i) y[i] = y[i] * x; } static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) { assert(nrc == 1); @@ -12434,7 +12433,7 @@ static void ggml_compute_forward_mul_mat_one_chunk( return; } - const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; + void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; size_t row_size = ggml_row_size(vec_dot_type, ne10); if (src0->type == 31) { row_size = ne10; @@ -12454,7 +12453,17 @@ static void ggml_compute_forward_mul_mat_one_chunk( float tmp[32]; uint8_t *i_weight = (uint8_t*) (src0->data); float * scale = (float * )((i_weight) + (ne00 * ne01 / 4)); - float* act_scales = (float*) ((char *) wdata + ((ne11*nb11) / 4)); + float * act_scales = (float*) ((char *) wdata + ((ne11*nb11) / 4)); + // printf("src0->name:%s\n", src0->name); + // printf("src1->name:%s\n", src1->name); + // printf("ne03:%ld\n", ne03); + // printf("ne02:%ld\n", ne02); + // printf("ne01:%ld\n", ne01); + // printf("ne00:%ld\n", ne00); + // printf("ne13:%ld\n", ne13); + // printf("ne12:%ld\n", ne12); + // printf("ne11:%ld\n", ne11); + // printf("ne10:%ld\n", ne10); for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) { for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) { @@ -12472,7 +12481,9 @@ static void ggml_compute_forward_mul_mat_one_chunk( const int64_t i3 = i13; const char * src0_row = (const char*)src0->data + (0 + i02 * nb02 + i03 * nb03); - + // if (src0->type == 31) { + // printf("src0->%ld\n", (0 + i02 * nb02 + i03 * nb03)); + // } // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using // the original src1 data pointer, so we should index using the indices directly @@ -12481,22 +12492,29 @@ static void ggml_compute_forward_mul_mat_one_chunk( (src1_cont || src1->type != vec_dot_type ? (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size : (i11 * nb11 + i12 * nb12 + i13 * nb13)); + // if (src0->type == 31) { + // printf("src1->%ld\n", (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size); + // } float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3)); //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) { // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col); //} - + // if (src0->type == 31) { + // printf("dst->%ld\n", (i1 * nb1 + i2 * nb2 + i3 * nb3)); + // } for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) { if (src0->type == 31) { + // printf("row->%ld\n", (ir0 * nb01 / 4)); vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01 / 4, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot); - tmp[ir0 - iir0] = tmp[ir0 - iir0] * (*scale) * (act_scales[i11]); - }else { + tmp[ir0 - iir0] = tmp[ir0 - iir0] / (act_scales[i11]) * (*scale); + } else { vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot); } } - + // printf("num_rows_per_vec_dot->%ld\n", num_rows_per_vec_dot); + // printf("iir0->%ld\n", iir0); for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) { memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float)); } @@ -12552,23 +12570,23 @@ static void ggml_compute_forward_bitnet_mul_mat( if (ith != 0) { return; } - atomic_store(&state->shared->current_chunk, nth); - char * wdata = params->wdata; - float* act_scales = (float*) ((char *) wdata + ((ne11*nb11) / 4)); - for (int64_t i13 = 0; i13 < ne13; i13++) { - for (int64_t i12 = 0; i12 < ne12; i12++) { - for (int64_t i11 = 0; i11 < ne11; i11++) { - float rowmax = 0.00001; - ggml_vec_absmaxclamp_f32(ne10, &rowmax, (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13), 0.00001); - float s = 127 / rowmax; - act_scales[i11] = 1/s; - ggml_vec_scaleroundclamp_f32_v2(ne10, + atomic_store(&state->shared->current_chunk, nth); + char * wdata = params->wdata; + float* act_scales = (float*) ((char *) wdata + ((ne11*nb11) / 4)); + for (int64_t i13 = 0; i13 < ne13; i13++) { + for (int64_t i12 = 0; i12 < ne12; i12++) { + for (int64_t i11 = 0; i11 < ne11; i11++) { + float rowmax = 0.00001; + ggml_vec_absmaxclamp_f32(ne10, &rowmax, (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13), 0.00001); + float s = 127 / rowmax; + act_scales[i11] = s; + ggml_vec_scaleroundclamp_f32_v2(ne10, (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13), (int8_t*) ((char *) wdata + ((i11*nb11 + i12*nb12 + i13*nb13) / 4)), s, -128, 127); + } } } - } // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start. // atomic_store(&state->shared->current_chunk, nth); // // char * wdata = params->wdata; diff --git a/llama.cpp b/llama.cpp index 53f0473333d24..170fe550a5806 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3192,9 +3192,7 @@ struct llama_model_loader { llama_tensor_weight(const llama_file * file, uint16_t idx, const char * name, const struct gguf_context * gguf_ctx, ggml_tensor * tensor) : idx(idx), tensor(tensor) { const int tensor_idx = gguf_find_tensor(gguf_ctx, name); - printf("name:%s\n", name); offs = gguf_get_data_offset(gguf_ctx) + gguf_get_tensor_offset(gguf_ctx, tensor_idx); - printf("offs:%ld\n", offs + ggml_nbytes(tensor)); if (offs + ggml_nbytes(tensor) < offs || offs + ggml_nbytes(tensor) > file->size) { throw std::runtime_error(format("tensor '%s' data is not within the file bounds, model is corrupted or incomplete", name)); } From 4e1ab5062856f6a6b5910ab9935bea571941c802 Mon Sep 17 00:00:00 2001 From: Eddie-Wang Date: Sat, 8 Jun 2024 12:44:13 +0000 Subject: [PATCH 06/32] finish bitnet i2 e2e --- convert-hf-to-gguf.py | 2 +- ggml-quants.c | 2 +- ggml.c | 133 ++---------- ggml.h | 7 - llama.cpp | 27 +-- tokenization_bitnet.py | 482 ----------------------------------------- 6 files changed, 16 insertions(+), 637 deletions(-) delete mode 100644 tokenization_bitnet.py diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 42d67aca4547c..ea993d720cde5 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1431,7 +1431,7 @@ def transform_to_i2(self, x): if x[i] != 0: scale = x[i] break - x = np.divide(x, scale) + x = np.where(x * scale > 0, 1, np.where(x * scale < 0, -1, x)) x = x.astype(np.uint8) x = np.reshape(x, [x.shape[0] // 4, 4]) keep_bit = {0:192, 1:48, 2:12, 3:3} diff --git a/ggml-quants.c b/ggml-quants.c index 8c3daf3328293..1353671ccad8e 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -3741,7 +3741,7 @@ void ggml_vec_dot_i2_q8_0(int n, float * restrict s, size_t bs, const void * res sumi += (int)y[i*4+2] * weight[2]; sumi += (int)y[i*4+3] * weight[3]; } - *s = (float)(sumi); + *s = (float)sumi; } diff --git a/ggml.c b/ggml.c index fcc5ed09b8448..06aa601b2ad33 100644 --- a/ggml.c +++ b/ggml.c @@ -2630,7 +2630,7 @@ inline static void ggml_vec_argmax_f32(const int n, int * s, const float * x) { *s = idx; } -inline static void ggml_vec_absmaxclamp_f32(const int n, float * s, const float * x, float min) { +inline static void ggml_vec_absmaxclamp_f32(const int n, float * s, float * x, float min) { float max = min; for (int i = 0; i < n; ++i) { max = MAX(max, fabs(x[i])); @@ -2646,12 +2646,12 @@ inline static void ggml_vec_scaleroundclamp_f32(const int n, float * s, const fl } } inline static void ggml_vec_scaleroundclamp_f32_v2(const int n, float * s, int8_t* inp, float scale, float min, float max) { - + float temp; for (int i = 0; i < n; ++i) { - s[i] = round(s[i] * scale); - if (s[i] > max) s[i] = max; - if (s[i] < min) s[i] = min; - inp[i] = (int8_t)(s[i]); + temp = round(s[i] * scale); + if (temp > max) temp = max; + if (temp < min) temp = min; + inp[i] = (int8_t)(temp); } } @@ -2745,10 +2745,9 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS", "CROSS_ENTROPY_LOSS_BACK", - "BITLINEAR_QUANT" }; -static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75"); +static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -2835,10 +2834,9 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss(x,y)", "cross_entropy_loss_back(x,y)", - "bitlinear(x)", }; -static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75"); +static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -4873,28 +4871,6 @@ struct ggml_tensor * ggml_mean( return result; } -// ggml_bitlinear_quant for bitnet - -struct ggml_tensor * ggml_bitlinear_quant( - struct ggml_context * ctx, - struct ggml_tensor * a) { - bool is_node = false; - - if (a->grad) { - GGML_ASSERT(false); // TODO: implement - is_node = true; - } - - int64_t ne[GGML_MAX_DIMS] = { a->ne[0], a->ne[1], a->ne[2], a->ne[3] }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, ggml_n_dims(a), ne); - - result->op = GGML_OP_BITLINEAR_QUANT; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - - return result; -} - // ggml_argmax struct ggml_tensor * ggml_argmax( @@ -10805,62 +10781,6 @@ static void ggml_compute_forward_mean( } } -static void ggml_compute_forward_bitlinear_quant_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - assert(params->ith == 0); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - - assert(src0->nb[0] == sizeof(float)); - - GGML_TENSOR_UNARY_OP_LOCALS - - assert(ne0 == ne00); - assert(ne1 == ne01); - assert(ne2 == ne02); - assert(ne3 == ne03); - - UNUSED(ne0); - UNUSED(ne1); - UNUSED(ne2); - UNUSED(ne3); - - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - float rowmax = 0.00001; - ggml_vec_absmaxclamp_f32(ne00, &rowmax, (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), 0.00001); - float s = 127 / rowmax; - - ggml_vec_scaleroundclamp_f32(ne00, - (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), - (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), - s, -128, 127); - } - } - } -} - -static void ggml_compute_forward_bitlinear_quant( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_bitlinear_quant_f32(params, src0, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - // ggml_compute_forward_argmax static void ggml_compute_forward_argmax_f32( @@ -12453,17 +12373,7 @@ static void ggml_compute_forward_mul_mat_one_chunk( float tmp[32]; uint8_t *i_weight = (uint8_t*) (src0->data); float * scale = (float * )((i_weight) + (ne00 * ne01 / 4)); - float * act_scales = (float*) ((char *) wdata + ((ne11*nb11) / 4)); - // printf("src0->name:%s\n", src0->name); - // printf("src1->name:%s\n", src1->name); - // printf("ne03:%ld\n", ne03); - // printf("ne02:%ld\n", ne02); - // printf("ne01:%ld\n", ne01); - // printf("ne00:%ld\n", ne00); - // printf("ne13:%ld\n", ne13); - // printf("ne12:%ld\n", ne12); - // printf("ne11:%ld\n", ne11); - // printf("ne10:%ld\n", ne10); + float * act_scales = (float*) ((char *) wdata + (ne11 * ne10)); for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) { for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) { @@ -12481,9 +12391,7 @@ static void ggml_compute_forward_mul_mat_one_chunk( const int64_t i3 = i13; const char * src0_row = (const char*)src0->data + (0 + i02 * nb02 + i03 * nb03); - // if (src0->type == 31) { - // printf("src0->%ld\n", (0 + i02 * nb02 + i03 * nb03)); - // } + // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using // the original src1 data pointer, so we should index using the indices directly @@ -12492,17 +12400,13 @@ static void ggml_compute_forward_mul_mat_one_chunk( (src1_cont || src1->type != vec_dot_type ? (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size : (i11 * nb11 + i12 * nb12 + i13 * nb13)); - // if (src0->type == 31) { - // printf("src1->%ld\n", (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size); - // } + float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3)); //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) { // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col); //} - // if (src0->type == 31) { - // printf("dst->%ld\n", (i1 * nb1 + i2 * nb2 + i3 * nb3)); - // } + for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) { if (src0->type == 31) { // printf("row->%ld\n", (ir0 * nb01 / 4)); @@ -12513,8 +12417,6 @@ static void ggml_compute_forward_mul_mat_one_chunk( } } - // printf("num_rows_per_vec_dot->%ld\n", num_rows_per_vec_dot); - // printf("iir0->%ld\n", iir0); for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) { memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float)); } @@ -12572,7 +12474,7 @@ static void ggml_compute_forward_bitnet_mul_mat( } atomic_store(&state->shared->current_chunk, nth); char * wdata = params->wdata; - float* act_scales = (float*) ((char *) wdata + ((ne11*nb11) / 4)); + float* act_scales = (float*) ((char *) wdata + (ne11 * ne10)); for (int64_t i13 = 0; i13 < ne13; i13++) { for (int64_t i12 = 0; i12 < ne12; i12++) { for (int64_t i11 = 0; i11 < ne11; i11++) { @@ -17634,10 +17536,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_mean(params, tensor); } break; - case GGML_OP_BITLINEAR_QUANT: - { - ggml_compute_forward_bitlinear_quant(params, tensor->src[0], tensor); - } break; case GGML_OP_ARGMAX: { ggml_compute_forward_argmax(params, tensor); @@ -18804,10 +18702,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ASSERT(false); // TODO: not implemented } break; - case GGML_OP_BITLINEAR_QUANT: - { - GGML_ASSERT(false); // TODO: not implemented - } break; case GGML_OP_ARGSORT: { GGML_ASSERT(false); // TODO: not implemented @@ -19573,7 +19467,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_ case GGML_OP_GET_REL_POS: case GGML_OP_MAP_UNARY: case GGML_OP_MAP_BINARY: - case GGML_OP_BITLINEAR_QUANT: case GGML_OP_MAP_CUSTOM1_F32: case GGML_OP_MAP_CUSTOM2_F32: case GGML_OP_MAP_CUSTOM3_F32: diff --git a/ggml.h b/ggml.h index 5d540aa305c38..eb9b124879706 100644 --- a/ggml.h +++ b/ggml.h @@ -507,8 +507,6 @@ extern "C" { GGML_OP_CROSS_ENTROPY_LOSS, GGML_OP_CROSS_ENTROPY_LOSS_BACK, - GGML_OP_BITLINEAR_QUANT, - GGML_OP_COUNT, }; @@ -996,11 +994,6 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); - // for bitnet - GGML_API struct ggml_tensor * ggml_bitlinear_quant( - struct ggml_context * ctx, - struct ggml_tensor * a); - // argmax along rows GGML_API struct ggml_tensor * ggml_argmax( struct ggml_context * ctx, diff --git a/llama.cpp b/llama.cpp index 170fe550a5806..4db25c45eb773 100644 --- a/llama.cpp +++ b/llama.cpp @@ -6827,13 +6827,6 @@ static struct ggml_tensor * llm_build_norm( return cur; } -static struct ggml_tensor * llm_build_qbitlinear( - struct ggml_context * ctx, - struct ggml_tensor * cur) - { - return ggml_bitlinear_quant(ctx, cur); - } - static struct ggml_tensor * llm_build_ffn( struct ggml_context * ctx, struct ggml_tensor * cur, @@ -7137,9 +7130,7 @@ static struct ggml_tensor * llm_build_kqv( attn_sub_norm, NULL, LLM_NORM_RMS, cb, il); cb(cur, "attn_sub_norm", il); - - // B2 for wo - // cur = llm_build_qbitlinear(ctx, cur); + } ggml_build_forward_expand(graph, cur); @@ -11561,8 +11552,6 @@ struct llm_build_context { // self-attention { // compute Q and K and RoPE them - // B1.Q - // cur = llm_build_qbitlinear(ctx0, cur); struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); if (model.layers[il].bq) { @@ -11625,17 +11614,6 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - // cur = llm_build_ffn(ctx0, cur, - // model.layers[il].ffn_up, NULL, - // model.layers[il].ffn_gate, NULL, - // model.layers[il].ffn_down, NULL, - // NULL, - // LLM_FFN_SILU, LLM_FFN_PAR, cb, il, hparams, model.layers[il].ffn_sub_norm, isbitnet); - // cb(cur, "ffn_out", il); - - - // cur = llm_build_qbitlinear(ctx0, cur); - struct ggml_tensor *tmp = ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur); cb(tmp, "ffn_up", il); @@ -11656,9 +11634,6 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_sub_norm", il); - // B4 for w2 - // cur = llm_build_qbitlinear(ctx0, cur); - cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down, cur); cb(cur, "ffn_down", il); diff --git a/tokenization_bitnet.py b/tokenization_bitnet.py deleted file mode 100644 index 09b482f72f2cd..0000000000000 --- a/tokenization_bitnet.py +++ /dev/null @@ -1,482 +0,0 @@ -# coding=utf-8 -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tokenization classes for LLaMA.""" -import os -from shutil import copyfile -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple - -import sentencepiece as spm - -from transformers.convert_slow_tokenizer import import_protobuf -from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer -from transformers.utils import logging - - -if TYPE_CHECKING: - from transformers.tokenization_utils_base import TextInput - -logger = logging.get_logger(__name__) - -VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"} - -PRETRAINED_VOCAB_FILES_MAP = { - "vocab_file": { - "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model", - }, - "tokenizer_file": { - "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json", - }, -} -PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { - "hf-internal-testing/llama-tokenizer": 2048, -} -SPIECE_UNDERLINE = "▁" - -B_INST, E_INST = "[INST]", "[/INST]" -B_SYS, E_SYS = "<>\n", "\n<>\n\n" - -# fmt: off -DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \ -answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\ - that your responses are socially unbiased and positive in nature. - -If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \ -correct. If you don't know the answer to a question, please don't share false information.""" -# fmt: on - - -class BitnetTokenizer(PreTrainedTokenizer): - """ - Construct a Bitnet tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is - no padding token in the original model. - - Args: - vocab_file (`str`): - Path to the vocabulary file. - unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): - The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this - token instead. - bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): - The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. - eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): - The end of sequence token. - pad_token (`str` or `tokenizers.AddedToken`, *optional*): - A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by - attention mechanisms or loss computation. - sp_model_kwargs (`Dict[str, Any]`, `Optional`, *optional*): - Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for - SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, - to set: - - - `enable_sampling`: Enable subword regularization. - - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. - - - `nbest_size = {0,1}`: No sampling is performed. - - `nbest_size > 1`: samples from the nbest_size results. - - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) - using forward-filtering-and-backward-sampling algorithm. - - - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for - BPE-dropout. - - add_bos_token (`bool`, *optional*, defaults to `True`): - Whether or not to add an `bos_token` at the start of sequences. - add_eos_token (`bool`, *optional*, defaults to `False`): - Whether or not to add an `eos_token` at the end of sequences. - clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): - Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like - extra spaces. - use_default_system_prompt (`bool`, *optional*, defaults to `False`): - Whether or not the default system prompt for Bitnet should be used. - spaces_between_special_tokens (`bool`, *optional*, defaults to `False`): - Whether or not to add spaces between special tokens. - legacy (`bool`, *optional*): - Whether or not the `legacy` behavior of the tokenizer should be used. Legacy is before the merge of #24622 - and #25224 which includes fixes to properly handle tokens that appear after special tokens. A simple - example: - - - `legacy=True`: - ```python - >>> from transformers import T5Tokenizer - - >>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", legacy=True) - >>> tokenizer.encode("Hello .") - [8774, 32099, 3, 5, 1] - ``` - - `legacy=False`: - ```python - >>> from transformers import T5Tokenizer - - >>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", legacy=False) - >>> tokenizer.encode("Hello .") # the extra space `[3]` is no longer here - [8774, 32099, 5, 1] - ``` - Checkout the [pull request](https://github.com/huggingface/transformers/pull/24565) for more details. - add_prefix_space (`bool`, *optional*, defaults to `True`): - Whether or not to add an initial space to the input. This allows to treat the leading word just as any - other word. - - """ - - vocab_files_names = VOCAB_FILES_NAMES - pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP - max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES - model_input_names = ["input_ids", "attention_mask"] - - def __init__( - self, - vocab_file, - unk_token="", - bos_token="", - eos_token="", - pad_token=None, - sp_model_kwargs: Optional[Dict[str, Any]] = None, - add_bos_token=True, - add_eos_token=False, - clean_up_tokenization_spaces=False, - use_default_system_prompt=False, - spaces_between_special_tokens=False, - legacy=None, - add_prefix_space=True, - **kwargs, - ): - self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs - bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token - eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token - unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token - pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token - - if legacy is None: - logger.warning_once( - f"You are using the default legacy behaviour of the {self.__class__}. This is" - " expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you." - " If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it" - " means, and thoroughly read the reason why this was added as explained in" - " https://github.com/huggingface/transformers/pull/24565" - ) - legacy = True - - self.legacy = legacy - self.vocab_file = vocab_file - self.add_bos_token = add_bos_token - self.add_eos_token = add_eos_token - self.use_default_system_prompt = use_default_system_prompt - self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False)) - self.add_prefix_space = add_prefix_space - - super().__init__( - bos_token=bos_token, - eos_token=eos_token, - unk_token=unk_token, - pad_token=pad_token, - add_bos_token=add_bos_token, - add_eos_token=add_eos_token, - sp_model_kwargs=self.sp_model_kwargs, - clean_up_tokenization_spaces=clean_up_tokenization_spaces, - use_default_system_prompt=use_default_system_prompt, - spaces_between_special_tokens=spaces_between_special_tokens, - legacy=legacy, - add_prefix_space=add_prefix_space, - **kwargs, - ) - - @property - def unk_token_length(self): - return len(self.sp_model.encode(str(self.unk_token))) - - # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor - def get_spm_processor(self, from_slow=False): - tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs) - if self.legacy or from_slow: # no dependency on protobuf - tokenizer.Load(self.vocab_file) - return tokenizer - - with open(self.vocab_file, "rb") as f: - sp_model = f.read() - model_pb2 = import_protobuf(f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)") - model = model_pb2.ModelProto.FromString(sp_model) - normalizer_spec = model_pb2.NormalizerSpec() - normalizer_spec.add_dummy_prefix = False - model.normalizer_spec.MergeFrom(normalizer_spec) - sp_model = model.SerializeToString() - tokenizer.LoadFromSerializedProto(sp_model) - return tokenizer - - def __getstate__(self): - state = self.__dict__.copy() - state["sp_model"] = None - state["sp_model_proto"] = self.sp_model.serialized_model_proto() - return state - - def __setstate__(self, d): - self.__dict__ = d - self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) - self.sp_model.LoadFromSerializedProto(self.sp_model_proto) - - @property - def vocab_size(self): - """Returns vocab size""" - return self.sp_model.get_piece_size() - - def get_vocab(self): - """Returns vocab as a dict""" - vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} - vocab.update(self.added_tokens_encoder) - return vocab - - # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize - def tokenize(self, text: "TextInput", **kwargs) -> List[str]: - """ - Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the - first token is special. - """ - if self.legacy or len(text) == 0: - return super().tokenize(text, **kwargs) - - text = text.replace(SPIECE_UNDERLINE, " ") - if self.add_prefix_space: - text = SPIECE_UNDERLINE + text - - tokens = super().tokenize(text, **kwargs) - - if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens: - tokens = tokens[1:] - return tokens - - # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize - def _tokenize(self, text, **kwargs): - """ - Returns a tokenized string. - - We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any - SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give - `['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the - `unk_token`. Here is an example with `unk_token = ""` and `unk_token_length = 4`. - `self.tokenizer.sp_model.encode(" Hey", out_type = str)[4:]`. - """ - tokens = self.sp_model.encode(text, out_type=str) - if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")): - return tokens - - # 1. Encode string + prefix ex: " Hey" - tokens = self.sp_model.encode(self.unk_token + text, out_type=str) - # 2. Remove self.unk_token from ['<','unk','>', '▁Hey'] - return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens - - def _convert_token_to_id(self, token): - """Converts a token (str) in an id using the vocab.""" - return self.sp_model.piece_to_id(token) - - def _convert_id_to_token(self, index): - """Converts an index (integer) in a token (str) using the vocab.""" - token = self.sp_model.IdToPiece(index) - return token - - def convert_tokens_to_string(self, tokens): - """Converts a sequence of tokens (string) in a single string.""" - # since we manually add the prefix space, we have to remove it when decoding - if tokens[0].startswith(SPIECE_UNDERLINE) and self.add_prefix_space: - tokens[0] = tokens[0][1:] - - current_sub_tokens = [] - out_string = "" - prev_is_special = False - for i, token in enumerate(tokens): - # make sure that special tokens are not decoded using sentencepiece model - if token in self.all_special_tokens: - if not prev_is_special and i != 0 and self.legacy: - out_string += " " - out_string += self.sp_model.decode(current_sub_tokens) + token - prev_is_special = True - current_sub_tokens = [] - else: - current_sub_tokens.append(token) - prev_is_special = False - out_string += self.sp_model.decode(current_sub_tokens) - return out_string - - def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]: - """ - Save the vocabulary and special tokens file to a directory. - - Args: - save_directory (`str`): - The directory in which to save the vocabulary. - - Returns: - `Tuple(str)`: Paths to the files saved. - """ - if not os.path.isdir(save_directory): - logger.error(f"Vocabulary path ({save_directory}) should be a directory") - return - out_vocab_file = os.path.join( - save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] - ) - - if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): - copyfile(self.vocab_file, out_vocab_file) - elif not os.path.isfile(self.vocab_file): - with open(out_vocab_file, "wb") as fi: - content_spiece_model = self.sp_model.serialized_model_proto() - fi.write(content_spiece_model) - - return (out_vocab_file,) - - def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): - bos_token_id = [self.bos_token_id] if self.add_bos_token else [] - eos_token_id = [self.eos_token_id] if self.add_eos_token else [] - - output = bos_token_id + token_ids_0 + eos_token_id - - if token_ids_1 is not None: - output = output + bos_token_id + token_ids_1 + eos_token_id - - return output - - def get_special_tokens_mask( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False - ) -> List[int]: - """ - Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding - special tokens using the tokenizer `prepare_for_model` method. - - Args: - token_ids_0 (`List[int]`): - List of IDs. - token_ids_1 (`List[int]`, *optional*): - Optional second list of IDs for sequence pairs. - already_has_special_tokens (`bool`, *optional*, defaults to `False`): - Whether or not the token list is already formatted with special tokens for the model. - - Returns: - `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. - """ - if already_has_special_tokens: - return super().get_special_tokens_mask( - token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True - ) - - bos_token_id = [1] if self.add_bos_token else [] - eos_token_id = [1] if self.add_eos_token else [] - - if token_ids_1 is None: - return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id - return ( - bos_token_id - + ([0] * len(token_ids_0)) - + eos_token_id - + bos_token_id - + ([0] * len(token_ids_1)) - + eos_token_id - ) - - def create_token_type_ids_from_sequences( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None - ) -> List[int]: - """ - Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT - sequence pair mask has the following format: - - ``` - 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 - | first sequence | second sequence | - ``` - - if token_ids_1 is None, only returns the first portion of the mask (0s). - - Args: - token_ids_0 (`List[int]`): - List of ids. - token_ids_1 (`List[int]`, *optional*): - Optional second list of IDs for sequence pairs. - - Returns: - `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). - """ - bos_token_id = [self.bos_token_id] if self.add_bos_token else [] - eos_token_id = [self.eos_token_id] if self.add_eos_token else [] - - output = [0] * len(bos_token_id + token_ids_0 + eos_token_id) - - if token_ids_1 is not None: - output += [1] * len(bos_token_id + token_ids_1 + eos_token_id) - - return output - - @property - def default_chat_template(self): - """ - LLaMA uses [INST] and [/INST] to indicate user messages, and <> and <> to indicate system messages. - Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict - user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering - rather than needing special tokens. The system message is partly 'embedded' in the first user message, which - results in an unusual token ordering when it is present. This template should definitely be changed if you wish - to fine-tune a model with more flexible role ordering! - - The output should look something like: - - [INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer [INST] Prompt [/INST] Answer - [INST] Prompt [/INST] - - The reference for this chat template is [this code - snippet](https://github.com/facebookresearch/llama/blob/556949fdfb72da27c2f4a40b7f0e4cf0b8153a28/llama/generation.py#L320-L362) - in the original repository. - """ - logger.warning_once( - "\nNo chat template is defined for this tokenizer - using the default template " - f"for the {self.__class__.__name__} class. If the default is not appropriate for " - "your model, please set `tokenizer.chat_template` to an appropriate template. " - "See https://huggingface.co/docs/transformers/main/chat_templating for more information.\n" - ) - template = ( - "{% if messages[0]['role'] == 'system' %}" - "{% set loop_messages = messages[1:] %}" # Extract system message if it's present - "{% set system_message = messages[0]['content'] %}" - "{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}" - "{% set loop_messages = messages %}" # Or use the default system message if the flag is set - "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}" - "{% else %}" - "{% set loop_messages = messages %}" - "{% set system_message = false %}" - "{% endif %}" - "{% for message in loop_messages %}" # Loop over all non-system messages - "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" - "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" - "{% endif %}" - "{% if loop.index0 == 0 and system_message != false %}" # Embed system message in first message - "{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}" - "{% else %}" - "{% set content = message['content'] %}" - "{% endif %}" - "{% if message['role'] == 'user' %}" # After all of that, handle messages/roles in a fairly normal way - "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}" - "{% elif message['role'] == 'system' %}" - "{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}" - "{% elif message['role'] == 'assistant' %}" - "{{ ' ' + content.strip() + ' ' + eos_token }}" - "{% endif %}" - "{% endfor %}" - ) - template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false") - default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'") - template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message) - - return template \ No newline at end of file From ca0908559308553472f56ae58d7783ef0442cae6 Mon Sep 17 00:00:00 2001 From: Eddie-Wang Date: Sun, 9 Jun 2024 02:43:38 +0000 Subject: [PATCH 07/32] move i2s to quantize v1 --- convert-hf-to-gguf.py | 19 ++++++++++++------- examples/quantize/quantize.cpp | 1 + ggml-quants.c | 27 +++++++++++++++++++++++++++ ggml-quants.h | 1 + ggml.c | 6 ++++-- llama.cpp | 6 +----- 6 files changed, 46 insertions(+), 14 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index ea993d720cde5..735630b9c8933 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1418,6 +1418,10 @@ def weight_quant(self, weight): dtype = weight.dtype weight = weight.float() s = 1 / weight.abs().mean().clamp(min=1e-5) + # from gguf.lazy import LazyNumpyTensor + # np_s = LazyNumpyTensor.to_eager(s.numpy()) + + # print(np_s) result = (weight * s).round().clamp(-1, 1) / s return result.type(dtype) @@ -1444,14 +1448,15 @@ def transform_to_i2(self, x): scale = np.tile(scale, 8) return ans, scale - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - # quant weight to i2 (in fp16) - if name.endswith(("q_proj.weight", "k_proj.weight", "v_proj.weight", - "down_proj.weight", "up_proj.weight", "gate_proj.weight", - "o_proj.weight")): - data_torch = data_torch + (self.weight_quant(data_torch) - data_torch).detach() + # def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # # quant weight to i2 (in fp16) + # if name.endswith(("q_proj.weight", "k_proj.weight", "v_proj.weight", + # "down_proj.weight", "up_proj.weight", "gate_proj.weight", + # "o_proj.weight")): + # print(name) + # data_torch = data_torch + (self.weight_quant(data_torch) - data_torch).detach() - return [(self.map_tensor_name(name), data_torch)] + # return [(self.map_tensor_name(name), data_torch)] def write_tensors(self): max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,") diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 28584e14b788c..bc2cc24359d87 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -26,6 +26,7 @@ static const std::vector QUANT_OPTIONS = { { "IQ2_M", LLAMA_FTYPE_MOSTLY_IQ2_M, " 2.7 bpw quantization", }, { "IQ1_S", LLAMA_FTYPE_MOSTLY_IQ1_S, " 1.56 bpw quantization", }, { "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M, " 1.75 bpw quantization", }, + { "I2_S", LLAMA_FTYPE_MOSTLY_I2, " 2 bpw per-tensor", }, { "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", }, { "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", }, { "IQ3_XXS",LLAMA_FTYPE_MOSTLY_IQ3_XXS," 3.06 bpw quantization", }, diff --git a/ggml-quants.c b/ggml-quants.c index 1353671ccad8e..96d3c88f620c5 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -3306,6 +3306,33 @@ size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nr return nrow * row_size; } +size_t quantize_i2_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + // 2 bits per weight + size_t row_size = ggml_row_size(GGML_TYPE_I2, n_per_row) / 4; + char * qrow = (char *)dst; + printf("n_row:%d\n", nrow); + printf("n_per_row:%d\n", n_per_row); + int n = nrow * n_per_row; + float accu = 0.0; + float min = 0.00001; + for (int i = 0; i < n; ++i) { + accu += fabs(src[i]); + } + accu = accu > min ? accu : min; + float scale = n / accu; + + printf("\nscale:%f\n", scale); + + // for (int64_t row = 0; row < nrow; ++row) { + // quantize_row_q4_0_impl(src, (block_q4_0*)qrow, n_per_row, quant_weights); + // src += n_per_row; + // qrow += row_size; + // } + + // 32B for scale + return nrow * row_size + 32; +} + // ====================== "True" 2-bit (de)-quantization void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int64_t k) { diff --git a/ggml-quants.h b/ggml-quants.h index 1c8e3839d7166..fea0b41ad2382 100644 --- a/ggml-quants.h +++ b/ggml-quants.h @@ -122,6 +122,7 @@ size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_i2_s(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); void iq2xs_init_impl(enum ggml_type type); void iq2xs_free_impl(enum ggml_type type); diff --git a/ggml.c b/ggml.c index 06aa601b2ad33..378042537e22b 100644 --- a/ggml.c +++ b/ggml.c @@ -573,7 +573,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .type_name = "i2", .blck_size = 1, .type_size = sizeof(int8_t), - .is_quantized = false, + .is_quantized = true, .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_i2_q8_0, .vec_dot_type = GGML_TYPE_Q8_0, .nrows = 1, @@ -2637,6 +2637,7 @@ inline static void ggml_vec_absmaxclamp_f32(const int n, float * s, float * x, f } *s = max; } + inline static void ggml_vec_scaleroundclamp_f32(const int n, float * s, const float * x, float scale, float min, float max) { for (int i = 0; i < n; ++i) { s[i] = round(x[i] * scale); @@ -2645,6 +2646,7 @@ inline static void ggml_vec_scaleroundclamp_f32(const int n, float * s, const fl s[i] /= scale; } } + inline static void ggml_vec_scaleroundclamp_f32_v2(const int n, float * s, int8_t* inp, float scale, float min, float max) { float temp; for (int i = 0; i < n; ++i) { @@ -2653,7 +2655,6 @@ inline static void ggml_vec_scaleroundclamp_f32_v2(const int n, float * s, int8_ if (temp < min) temp = min; inp[i] = (int8_t)(temp); } - } // @@ -21726,6 +21727,7 @@ size_t ggml_quantize_chunk( case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_I2: result = quantize_i2_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_F16: { size_t elemsize = sizeof(ggml_fp16_t); diff --git a/llama.cpp b/llama.cpp index 4db25c45eb773..109ac4034304a 100644 --- a/llama.cpp +++ b/llama.cpp @@ -15634,6 +15634,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_F16: default_type = GGML_TYPE_F16; break; case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break; case LLAMA_FTYPE_ALL_F32: default_type = GGML_TYPE_F32; break; + case LLAMA_FTYPE_MOSTLY_I2: default_type = GGML_TYPE_I2; break; // K-quants case LLAMA_FTYPE_MOSTLY_Q2_K_S: @@ -15658,7 +15659,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break; case LLAMA_FTYPE_MOSTLY_IQ3_S: default_type = GGML_TYPE_IQ3_S; break; case LLAMA_FTYPE_MOSTLY_IQ3_M: default_type = GGML_TYPE_IQ3_S; break; - case LLAMA_FTYPE_MOSTLY_I2 : default_type = GGML_TYPE_I2; break; default: throw std::runtime_error(format("invalid output file type %d\n", ftype)); } @@ -15896,10 +15896,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s if (params->output_tensor_type < GGML_TYPE_COUNT && strcmp(tensor->name, "output.weight") == 0) { new_type = params->output_tensor_type; } - if (tensor->type == 31) { - // no need quantize for i2 - new_type = tensor->type; - } // If we've decided to quantize to the same type the tensor is already // in then there's nothing to do. quantize = tensor->type != new_type; From dbee0a86c1ea144b3f9546e964cee9d7151498e9 Mon Sep 17 00:00:00 2001 From: root Date: Sun, 9 Jun 2024 18:20:32 +0800 Subject: [PATCH 08/32] move i2 to quantize --- convert-hf-to-gguf.py | 139 +++--------------------------------------- ggml-quants.c | 53 ++++++++++------ ggml.c | 6 +- 3 files changed, 48 insertions(+), 150 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 735630b9c8933..d98967e25d704 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1418,142 +1418,17 @@ def weight_quant(self, weight): dtype = weight.dtype weight = weight.float() s = 1 / weight.abs().mean().clamp(min=1e-5) - # from gguf.lazy import LazyNumpyTensor - # np_s = LazyNumpyTensor.to_eager(s.numpy()) - - # print(np_s) result = (weight * s).round().clamp(-1, 1) / s return result.type(dtype) - def transform_to_i2(self, x): - from gguf.lazy import LazyNumpyTensor - x = LazyNumpyTensor.to_eager(x) - x_num = np.prod(x.shape) - x = np.reshape(x, x_num) - scale = 1 - for i in range(x_num): - if x[i] != 0: - scale = x[i] - break - x = np.where(x * scale > 0, 1, np.where(x * scale < 0, -1, x)) - x = x.astype(np.uint8) - x = np.reshape(x, [x.shape[0] // 4, 4]) - keep_bit = {0:192, 1:48, 2:12, 3:3} - ans = np.zeros([x_num // 4], dtype=np.uint8) - for i in range(4): - x_bit_col = x[:, i] - x_bit_shift = np.left_shift(x_bit_col, 6 - i * 2) - x_bit_shift = np.bitwise_and(x_bit_shift, keep_bit[i]) - ans = np.bitwise_or(ans, x_bit_shift) - scale = np.tile(scale, 8) - return ans, scale - - # def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - # # quant weight to i2 (in fp16) - # if name.endswith(("q_proj.weight", "k_proj.weight", "v_proj.weight", - # "down_proj.weight", "up_proj.weight", "gate_proj.weight", - # "o_proj.weight")): - # print(name) - # data_torch = data_torch + (self.weight_quant(data_torch) - data_torch).detach() - - # return [(self.map_tensor_name(name), data_torch)] - - def write_tensors(self): - max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,") - - for name, data_torch in self.get_tensors(): - # we don't need these - if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")): - continue - - old_dtype = data_torch.dtype - - # convert any unsupported data types to float32 - if data_torch.dtype not in (torch.float16, torch.float32): - data_torch = data_torch.to(torch.float32) - - # use the first number-like part of the tensor name as the block id - bid = None - for part in name.split("."): - if part.isdecimal(): - bid = int(part) - break - - for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)): - data: np.ndarray = data # type hint - data_shape = data.shape - n_dims = len(data.shape) - data_dtype = data.dtype - data_qtype: gguf.GGMLQuantizationType | None = None - - # when both are True, f32 should win - extra_f32 = self.extra_f32_tensors(name, new_name, bid, n_dims) - extra_f16 = self.extra_f16_tensors(name, new_name, bid, n_dims) - - # Most of the codebase that takes in 1D tensors or norms only handles F32 tensors - # Conditions should closely match those in llama_model_quantize_internal in llama.cpp - extra_f32 = any(cond for cond in ( - extra_f32, - n_dims == 1, - new_name.endswith("_norm.weight"), - )) - - # Some tensor types are always in float32 - extra_f32 = extra_f32 or any(self.match_model_tensor_name(new_name, key, bid) for key in ( - gguf.MODEL_TENSOR.FFN_GATE_INP, - gguf.MODEL_TENSOR.POS_EMBD, - gguf.MODEL_TENSOR.TOKEN_TYPES, - )) - - # if f16 desired, convert any float32 2-dim weight tensors to float16 - extra_f16 = any(cond for cond in ( - extra_f16, - (name.endswith(".weight") and n_dims >= 2), - )) - - suit_i2 = True - if name.endswith('embed_tokens.weight') or name.endswith('norm.weight'): - suit_i2 = False - - i2_scale = None - if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32: - if self.ftype == gguf.LlamaFileType.MOSTLY_I2 and suit_i2: - data, i2_scale = self.transform_to_i2(data) - assert data.dtype == np.uint8 - assert i2_scale.dtype == np.float32 - data_qtype = gguf.GGMLQuantizationType.I2 - - elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16: - data = gguf.quantize_bf16(data) - assert data.dtype == np.int16 - data_qtype = gguf.GGMLQuantizationType.BF16 - - elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0 and gguf.can_quantize_to_q8_0(data): - data = gguf.quantize_q8_0(data) - assert data.dtype == np.uint8 - data_qtype = gguf.GGMLQuantizationType.Q8_0 - - else: # default to float16 for quantized tensors - if data_dtype != np.float16: - data = data.astype(np.float16) - data_qtype = gguf.GGMLQuantizationType.F16 - - if data_qtype is None: # by default, convert to float32 - if data_dtype != np.float32: - data = data.astype(np.float32) - data_qtype = gguf.GGMLQuantizationType.F32 - - shape = data_shape - # shape = gguf.quant_shape_from_byte_shape(data.shape, data_qtype) if data.dtype == np.uint8 else data.shape - # reverse shape to make it similar to the internal ggml dimension order - shape_str = f"{{{', '.join(str(n) for n in reversed(shape))}}}" - - # n_dims is implicit in the shape - logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}") + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # transform weight into 1/0/-1 (in fp32) + if name.endswith(("q_proj.weight", "k_proj.weight", "v_proj.weight", + "down_proj.weight", "up_proj.weight", "gate_proj.weight", + "o_proj.weight")): + data_torch = data_torch + (self.weight_quant(data_torch) - data_torch).detach() - self.gguf_writer.add_tensor(new_name, data, raw_shape=shape, raw_dtype=data_qtype) - if i2_scale is not None: - self.gguf_writer.add_tensor(new_name + "_scale", i2_scale, raw_dtype=gguf.GGMLQuantizationType.F32) + return [(self.map_tensor_name(name), data_torch)] @Model.register("GrokForCausalLM") class GrokModel(Model): diff --git a/ggml-quants.c b/ggml-quants.c index 96d3c88f620c5..a4a72c8474c36 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -3308,29 +3308,47 @@ size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nr size_t quantize_i2_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { // 2 bits per weight - size_t row_size = ggml_row_size(GGML_TYPE_I2, n_per_row) / 4; - char * qrow = (char *)dst; - printf("n_row:%d\n", nrow); - printf("n_per_row:%d\n", n_per_row); + size_t row_size = ggml_row_size(GGML_TYPE_I2, n_per_row); + int n = nrow * n_per_row; - float accu = 0.0; - float min = 0.00001; - for (int i = 0; i < n; ++i) { - accu += fabs(src[i]); + + // f32 -> q8 + double i2_scale = 0; + for (int i=0; i 1e-6) { + i2_scale = src[i]; + } + } + + uint8_t* q8 = (uint8_t*)dst; + for (int i=0; i 0 ? 1 : 3; } - accu = accu > min ? accu : min; - float scale = n / accu; - printf("\nscale:%f\n", scale); + // q8 -> 0, 1, 3 + // | | | + // 0, 1,-1 - // for (int64_t row = 0; row < nrow; ++row) { - // quantize_row_q4_0_impl(src, (block_q4_0*)qrow, n_per_row, quant_weights); - // src += n_per_row; - // qrow += row_size; - // } + uint8_t* i2_weight = (uint8_t*)dst; + for (int i=0; i Date: Sun, 9 Jun 2024 20:22:03 +0800 Subject: [PATCH 09/32] clean code --- examples/quantize/quantize.cpp | 2 +- ggml-quants.c | 35 ++++- ggml-quants.h | 3 +- ggml.c | 233 ++++----------------------------- ggml.h | 3 +- llama.cpp | 2 +- llama.h | 2 +- 7 files changed, 64 insertions(+), 216 deletions(-) diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index bc2cc24359d87..16cfd1717f7d6 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -26,7 +26,7 @@ static const std::vector QUANT_OPTIONS = { { "IQ2_M", LLAMA_FTYPE_MOSTLY_IQ2_M, " 2.7 bpw quantization", }, { "IQ1_S", LLAMA_FTYPE_MOSTLY_IQ1_S, " 1.56 bpw quantization", }, { "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M, " 1.75 bpw quantization", }, - { "I2_S", LLAMA_FTYPE_MOSTLY_I2, " 2 bpw per-tensor", }, + { "I2_S", LLAMA_FTYPE_MOSTLY_I2_S, " 2 bpw per-tensor quantization", }, { "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", }, { "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", }, { "IQ3_XXS",LLAMA_FTYPE_MOSTLY_IQ3_XXS," 3.06 bpw quantization", }, diff --git a/ggml-quants.c b/ggml-quants.c index a4a72c8474c36..6a825cd74d99c 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -659,6 +659,24 @@ static inline __m128i packNibbles( __m256i bytes ) { } #endif //__loongarch_asx +void quantize_row_i8_s(const float * x, void * y, int64_t n, float* act_scales) { + int8_t* dst = (int8_t*)y; + double min = 0.00001; + double max = min; + for (int i = 0; i < n; ++i) { + max = MAX(max, (double)fabs(x[i])); + } + float s = 127 / max; + act_scales[0] = s; + float temp; + for (int i = 0; i < n; ++i) { + temp = round(x[i] * s); + if (temp > 127) temp = 127; + if (temp < -128) temp = -128; + dst[i] = (int8_t)(temp); + } +} + // reference implementation for deterministic creation of model files void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) { static const int qk = QK4_0; @@ -3308,7 +3326,9 @@ size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nr size_t quantize_i2_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { // 2 bits per weight - size_t row_size = ggml_row_size(GGML_TYPE_I2, n_per_row); + UNUSED(quant_weights); + + size_t row_size = ggml_row_size(GGML_TYPE_I2_S, n_per_row); int n = nrow * n_per_row; @@ -3326,7 +3346,7 @@ size_t quantize_i2_s(const float * restrict src, void * restrict dst, int64_t nr q8[i] = 0; continue; } - q8[i] = src[i] * i2_scale > 0 ? 1 : 3; + q8[i] = (double)src[i] * i2_scale > 0 ? 1 : 3; } // q8 -> 0, 1, 3 @@ -3773,14 +3793,19 @@ static inline __m128i get_scale_shuffle(int i) { //====================================== I2 =============================================== -void ggml_vec_dot_i2_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { +void ggml_vec_dot_i2_i8_s(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { const uint8_t * restrict x = vx; const int8_t * restrict y = vy; + UNUSED(bs); + UNUSED(bx); + UNUSED(by); + UNUSED(nrc); + int sumi = 0; for (int i = 0; i < n / 4; i++) { - int8_t* weight = (const int8_t *)(i2_q8 + x[i]); + const int8_t* weight = (const int8_t *)(i2_q8 + x[i]); sumi += (int)y[i*4+0] * weight[0]; sumi += (int)y[i*4+1] * weight[1]; sumi += (int)y[i*4+2] * weight[2]; @@ -14431,7 +14456,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_I64: - case GGML_TYPE_I2: + case GGML_TYPE_I2_S: // nothing to validate break; default: diff --git a/ggml-quants.h b/ggml-quants.h index fea0b41ad2382..a4d0c0cecc5c7 100644 --- a/ggml-quants.h +++ b/ggml-quants.h @@ -51,6 +51,7 @@ void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_iq3_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_iq2_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_i8_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k, float* n); // Dequantization void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); @@ -99,7 +100,7 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); -void ggml_vec_dot_i2_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_i2_i8_s (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); // Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization") size_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); diff --git a/ggml.c b/ggml.c index f8752c0f87c46..55aa823c89074 100644 --- a/ggml.c +++ b/ggml.c @@ -569,15 +569,6 @@ static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc); static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { - [GGML_TYPE_I2] = { - .type_name = "i2", - .blck_size = 1, - .type_size = sizeof(int8_t), - .is_quantized = true, - .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_i2_q8_0, - .vec_dot_type = GGML_TYPE_Q8_0, - .nrows = 1, - }, [GGML_TYPE_I8] = { .type_name = "i8", .blck_size = 1, @@ -922,6 +913,21 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16, .vec_dot_type = GGML_TYPE_BF16, .nrows = 1, + }, + [GGML_TYPE_I2_S] = { + .type_name = "i2_s", + .blck_size = 1, + .type_size = sizeof(int8_t), + .is_quantized = true, + .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_i2_i8_s, + .vec_dot_type = GGML_TYPE_I8_S, + .nrows = 1, + }, + [GGML_TYPE_I8_S] = { + .type_name = "i8_s", + .blck_size = 1, + .type_size = sizeof(int8_t), + .is_quantized = true, } }; @@ -2630,33 +2636,6 @@ inline static void ggml_vec_argmax_f32(const int n, int * s, const float * x) { *s = idx; } -inline static void ggml_vec_absmaxclamp_f32(const int n, float * s, float * x, float min) { - float max = min; - for (int i = 0; i < n; ++i) { - max = MAX(max, fabs(x[i])); - } - *s = max; -} - -inline static void ggml_vec_scaleroundclamp_f32(const int n, float * s, const float * x, float scale, float min, float max) { - for (int i = 0; i < n; ++i) { - s[i] = round(x[i] * scale); - if (s[i] > max) s[i] = max; - if (s[i] < min) s[i] = min; - s[i] /= scale; - } -} - -inline static void ggml_vec_scaleroundclamp_f32_v2(const int n, float * s, int8_t* inp, float scale, float min, float max) { - float temp; - for (int i = 0; i < n; ++i) { - temp = round(s[i] * scale); - if (temp > max) temp = max; - if (temp < min) temp = min; - inp[i] = (int8_t)(temp); - } -} - // // data types // @@ -12409,8 +12388,7 @@ static void ggml_compute_forward_mul_mat_one_chunk( //} for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) { - if (src0->type == 31) { - // printf("row->%ld\n", (ir0 * nb01 / 4)); + if (src0->type == GGML_TYPE_I2_S) { vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01 / 4, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot); tmp[ir0 - iir0] = tmp[ir0 - iir0] / (act_scales[i11]) * (*scale); } else { @@ -12426,164 +12404,6 @@ static void ggml_compute_forward_mul_mat_one_chunk( } } - -static void ggml_compute_forward_bitnet_mul_mat( - const struct ggml_compute_params * params, - struct ggml_tensor * dst, - struct ggml_compute_state * state) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - - GGML_TENSOR_BINARY_OP_LOCALS - - const int ith = params->ith; - const int nth = params->nth; - - const enum ggml_type type = src0->type; - const bool src1_cont = ggml_is_contiguous(src1); - - GGML_ASSERT(ne0 == ne01); - GGML_ASSERT(ne1 == ne11); - GGML_ASSERT(ne2 == ne12); - GGML_ASSERT(ne3 == ne13); - - // we don't support permuted src0 or src1 - GGML_ASSERT(nb00 == ggml_type_size(type)); - GGML_ASSERT(nb10 == ggml_type_size(src1->type)); - - // dst cannot be transposed or permuted - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(nb0 <= nb1); - GGML_ASSERT(nb1 <= nb2); - GGML_ASSERT(nb2 <= nb3); - - // broadcast factors - const int64_t r2 = ne12 / ne02; - const int64_t r3 = ne13 / ne03; - UNUSED(r2); - UNUSED(r3); - - // nb01 >= nb00 - src0 is not transposed - // compute by src0 rows - if (params->type == GGML_TASK_TYPE_INIT) { - if (ith != 0) { - return; - } - atomic_store(&state->shared->current_chunk, nth); - char * wdata = params->wdata; - float* act_scales = (float*) ((char *) wdata + (ne11 * ne10)); - for (int64_t i13 = 0; i13 < ne13; i13++) { - for (int64_t i12 = 0; i12 < ne12; i12++) { - for (int64_t i11 = 0; i11 < ne11; i11++) { - float rowmax = 0.00001; - ggml_vec_absmaxclamp_f32(ne10, &rowmax, (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13), 0.00001); - float s = 127 / rowmax; - act_scales[i11] = s; - ggml_vec_scaleroundclamp_f32_v2(ne10, - (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13), - (int8_t*) ((char *) wdata + ((i11*nb11 + i12*nb12 + i13*nb13) / 4)), - s, -128, 127); - } - } - } - // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start. - // atomic_store(&state->shared->current_chunk, nth); - // // char * wdata = params->wdata; - // const size_t row_size = ggml_row_size(GGML_TYPE_Q8_0, ne10); - // // printf("vec_dot_type:%d\n", vec_dot_type); - // // printf("row_size:%ld\n", row_size); - // assert(params->wsize >= ne11*ne12*ne13*row_size); - // GGML_ASSERT(src1->type == GGML_TYPE_F32); - - // for (int64_t i13 = 0; i13 < ne13; ++i13) { - // for (int64_t i12 = 0; i12 < ne12; ++i12) { - // for (int64_t i11 = 0; i11 < ne11; ++i11) { - // quantize_row_q8_0((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10); - // wdata += row_size; - // } - // } - // } - - - return; - } - - if (params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - - // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers) - const int64_t nr0 = ne0; - - // This is the size of the rest of the dimensions of the result - const int64_t nr1 = ne1 * ne2 * ne3; - - // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols - int64_t num_rows_per_vec_dot = 1; - // TODO: currently the mmla kernels support only even numbered rows/cols. - // this check can be removed once they are extended to support odd numbered rows/cols too - if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) { - num_rows_per_vec_dot = 1; - } - - // Now select a reasonable chunk size. - int chunk_size = 16; - - // We need to step up the size if it's small - if (nr0 == 1 || nr1 == 1) { - chunk_size = 64; - } - - // distribute the work across the inner or outer loop based on which one is larger - // The number of chunks in the 0/1 dim. - // CEIL(nr0/chunk_size) - int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size; - int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size; - - // If the chunking is poor for the number of threads on this setup, scrap the whole plan. Re-chunk it by thread. - // Also, chunking by thread was measured to have perform better on NUMA systems. See https://github.com/ggerganov/llama.cpp/pull/6915 - // In theory, chunking should be just as useful on NUMA and non NUMA systems, but testing disagreed with that. - if (nchunk0 * nchunk1 < nth * 4 || ggml_is_numa()) { - // distribute the thread work across the inner or outer loop based on which one is larger - nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows - nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows - } - - // The number of elements in each chunk - const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0; - const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1; - - //if (ith == 0) - // printf("MUL_MAT = [%d, %d, %d, %d] x [%d, %d, %d, %d] = %d x %d = %d. Fp Ops/Ch %d\n", ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nchunk0, nchunk1, nchunk0 * nchunk1, ne00 * nr0 * nr1 / nchunk0 / nchunk1); - - // The first chunk comes from our thread_id, the rest will get auto-assigned. - int current_chunk = ith; - - while (current_chunk < nchunk0 * nchunk1) { - const int64_t ith0 = current_chunk % nchunk0; - const int64_t ith1 = current_chunk / nchunk0; - - const int64_t ir0_start = dr0 * ith0; - const int64_t ir0_end = MIN(ir0_start + dr0, nr0); - - const int64_t ir1_start = dr1 * ith1; - const int64_t ir1_end = MIN(ir1_start + dr1, nr1); - - ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end); - - if (nth >= nchunk0 * nchunk1) { - break; - } - - current_chunk = atomic_fetch_add(&state->shared->current_chunk, 1); - } - -} - static void ggml_compute_forward_mul_mat( const struct ggml_compute_params * params, struct ggml_tensor * dst, @@ -12597,11 +12417,6 @@ static void ggml_compute_forward_mul_mat( GGML_TENSOR_BINARY_OP_LOCALS - if (src0->type == 31) { - ggml_compute_forward_bitnet_mul_mat(params, dst, state); - return; - } - const int ith = params->ith; const int nth = params->nth; @@ -12751,8 +12566,13 @@ UseGgmlGemm1:; for (int64_t i13 = 0; i13 < ne13; ++i13) { for (int64_t i12 = 0; i12 < ne12; ++i12) { for (int64_t i11 = 0; i11 < ne11; ++i11) { - from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10); - wdata += row_size; + if (src0->type == GGML_TYPE_I2_S) { + float* act_scales = (float*) ((char *) wdata + (ne11 * ne10)); + quantize_row_i8_s((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (char *) wdata + ((i11*nb11 + i12*nb12 + i13*nb13) / 4), ne10, act_scales + i11); + } else { + from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10); + wdata += row_size; + } } } } @@ -14469,7 +14289,8 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_I32: case GGML_TYPE_I64: case GGML_TYPE_F64: - case GGML_TYPE_I2: + case GGML_TYPE_I2_S: + case GGML_TYPE_I8_S: case GGML_TYPE_COUNT: { GGML_ASSERT(false); @@ -21727,7 +21548,7 @@ size_t ggml_quantize_chunk( case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_I2: result = quantize_i2_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_I2_S: result = quantize_i2_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_F16: { size_t elemsize = sizeof(ggml_fp16_t); @@ -21750,7 +21571,7 @@ size_t ggml_quantize_chunk( assert(false); } - if (type == GGML_TYPE_I2) { + if (type == GGML_TYPE_I2_S) { result = nrows * row_size / 4 + 32; } else { GGML_ASSERT(result == nrows * row_size); diff --git a/ggml.h b/ggml.h index eb9b124879706..9edc84f5a19f4 100644 --- a/ggml.h +++ b/ggml.h @@ -377,7 +377,8 @@ extern "C" { GGML_TYPE_F64 = 28, GGML_TYPE_IQ1_M = 29, GGML_TYPE_BF16 = 30, - GGML_TYPE_I2 = 31, + GGML_TYPE_I2_S = 31, + GGML_TYPE_I8_S = 32, GGML_TYPE_COUNT, }; diff --git a/llama.cpp b/llama.cpp index 109ac4034304a..865011f679d7c 100644 --- a/llama.cpp +++ b/llama.cpp @@ -15634,7 +15634,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_F16: default_type = GGML_TYPE_F16; break; case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break; case LLAMA_FTYPE_ALL_F32: default_type = GGML_TYPE_F32; break; - case LLAMA_FTYPE_MOSTLY_I2: default_type = GGML_TYPE_I2; break; + case LLAMA_FTYPE_MOSTLY_I2_S: default_type = GGML_TYPE_I2_S; break; // K-quants case LLAMA_FTYPE_MOSTLY_Q2_K_S: diff --git a/llama.h b/llama.h index 1a225fa618a32..5bfab5e034874 100644 --- a/llama.h +++ b/llama.h @@ -156,7 +156,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors - LLAMA_FTYPE_MOSTLY_I2 = 33, + LLAMA_FTYPE_MOSTLY_I2_S = 33, LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file }; From 3a0f8b06976ac7732490e70770d38be1fe107338 Mon Sep 17 00:00:00 2001 From: root Date: Sun, 9 Jun 2024 21:15:02 +0800 Subject: [PATCH 10/32] clean code 2 --- convert-hf-to-gguf.py | 3 +- ggml.c | 35 ++++++------- gguf-py/gguf/constants.py | 3 -- gguf-py/gguf/gguf_writer.py | 13 ++--- llama.cpp | 97 +++++++++++++++++++++++++++++++------ 5 files changed, 100 insertions(+), 51 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index d98967e25d704..15a23ef6cf690 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2805,7 +2805,7 @@ def parse_args() -> argparse.Namespace: help="path to write to; default: based on input. {ftype} will be replaced by the outtype.", ) parser.add_argument( - "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "i2", "auto"], default="f16", + "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16", help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type", ) parser.add_argument( @@ -2865,7 +2865,6 @@ def main() -> None: "f16": gguf.LlamaFileType.MOSTLY_F16, "bf16": gguf.LlamaFileType.MOSTLY_BF16, "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0, - "i2" : gguf.LlamaFileType.MOSTLY_I2, "auto": gguf.LlamaFileType.GUESSED, } diff --git a/ggml.c b/ggml.c index 55aa823c89074..e59093cdf21cc 100644 --- a/ggml.c +++ b/ggml.c @@ -2724,7 +2724,6 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS", "CROSS_ENTROPY_LOSS_BACK", - }; static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74"); @@ -2813,7 +2812,6 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss(x,y)", "cross_entropy_loss_back(x,y)", - }; static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74"); @@ -3078,10 +3076,9 @@ GGML_CALL size_t ggml_nbytes(const struct ggml_tensor * tensor) { for (int i = 0; i < GGML_MAX_DIMS; ++i) { nbytes += (tensor->ne[i] - 1)*tensor->nb[i]; } - if(tensor->type == 31){ + if(tensor->type == GGML_TYPE_I2_S){ nbytes = nbytes / 4 + 32; } - } else { nbytes = tensor->ne[0]*tensor->nb[0]/blck_size; @@ -3107,6 +3104,7 @@ GGML_CALL size_t ggml_type_size(enum ggml_type type) { GGML_CALL size_t ggml_row_size(enum ggml_type type, int64_t ne) { assert(ne % ggml_blck_size(type) == 0); + if (type == GGML_TYPE_I2_S) ne /= 4; return ggml_type_size(type)*ne/ggml_blck_size(type); } @@ -12333,11 +12331,11 @@ static void ggml_compute_forward_mul_mat_one_chunk( return; } - void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; - size_t row_size = ggml_row_size(vec_dot_type, ne10); - if (src0->type == 31) { - row_size = ne10; - } + const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; + const size_t row_size = ggml_row_size(vec_dot_type, ne10); + // if (src0->type == 31) { + // row_size = ne10; + // } assert(ne12 % ne02 == 0); assert(ne13 % ne03 == 0); @@ -12351,9 +12349,8 @@ static void ggml_compute_forward_mul_mat_one_chunk( // attempt to reduce false-sharing (does not seem to make a difference) // 16 * 2, accounting for mmla kernels float tmp[32]; - uint8_t *i_weight = (uint8_t*) (src0->data); - float * scale = (float * )((i_weight) + (ne00 * ne01 / 4)); - float * act_scales = (float*) ((char *) wdata + (ne11 * ne10)); + float * scale = (float * )((uint8_t*) (src0->data) + (ne00 * ne01 / 4)); + const float * act_scales = (const float*) ((const char *) wdata + (ne11 * ne10)); for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) { for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) { @@ -12380,7 +12377,6 @@ static void ggml_compute_forward_mul_mat_one_chunk( (src1_cont || src1->type != vec_dot_type ? (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size : (i11 * nb11 + i12 * nb12 + i13 * nb13)); - float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3)); //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) { @@ -12388,13 +12384,12 @@ static void ggml_compute_forward_mul_mat_one_chunk( //} for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) { - if (src0->type == GGML_TYPE_I2_S) { - vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01 / 4, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot); - tmp[ir0 - iir0] = tmp[ir0 - iir0] / (act_scales[i11]) * (*scale); - } else { - vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot); - } - + if (src0->type == GGML_TYPE_I2_S) { + vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01 / 4, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot); + tmp[ir0 - iir0] = tmp[ir0 - iir0] / (act_scales[i11]) * (*scale); + } else { + vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot); + } } for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) { memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float)); diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 5e94edb22fa2b..429f3818914c9 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -925,7 +925,6 @@ class GGMLQuantizationType(IntEnum): F64 = 28 IQ1_M = 29 BF16 = 30 - I2 = 31 # TODO: add GGMLFileType from ggml_ftype in ggml.h @@ -967,7 +966,6 @@ class LlamaFileType(IntEnum): MOSTLY_IQ4_XS = 30 # except 1d tensors MOSTLY_IQ1_M = 31 # except 1d tensors MOSTLY_BF16 = 32 # except 1d tensors - MOSTLY_I2 = 33 # except 1d tensors GUESSED = 1024 # not specified in the model file @@ -1034,7 +1032,6 @@ def get_type(val: Any) -> GGUFValueType: GGMLQuantizationType.IQ3_S: (256, 2 + QK_K // 4 + QK_K // 8 + QK_K // 32 + 4), GGMLQuantizationType.IQ2_S: (256, 2 + QK_K // 4 + QK_K // 16), GGMLQuantizationType.IQ4_XS: (256, 2 + 2 + QK_K // 2 + QK_K // 64), - GGMLQuantizationType.I2: (1, 1), GGMLQuantizationType.I8: (1, 1), GGMLQuantizationType.I16: (1, 2), GGMLQuantizationType.I32: (1, 4), diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 2d19cd44c2412..b93747aff58b3 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -225,10 +225,8 @@ def add_tensor_info( dtype = GGMLQuantizationType.I32 elif tensor_dtype == np.int64: dtype = GGMLQuantizationType.I64 - elif tensor_dtype == np.uint8: - dtype = GGMLQuantizationType.I2 else: - raise ValueError("Only F16, F32, F64, I8, I16, I32, I64, I2 tensors are supported for now") + raise ValueError("Only F16, F32, F64, I8, I16, I32, I64 tensors are supported for now") else: dtype = raw_dtype if tensor_dtype == np.uint8: @@ -239,10 +237,7 @@ def add_tensor_info( self.ti_data += self._pack("Q", tensor_shape[n_dims - 1 - i]) self.ti_data += self._pack("I", dtype) self.ti_data += self._pack("Q", self.offset_tensor) - if dtype == GGMLQuantizationType.I2: - self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment) + self.data_alignment - else: - self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment) + self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment) self.ti_data_count += 1 def add_tensor( @@ -257,9 +252,7 @@ def add_tensor( self.temp_file = fp shape: Sequence[int] = raw_shape if raw_shape is not None else tensor.shape - - if (raw_dtype != GGMLQuantizationType.F32 or not name.endswith("scale")): - self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype = raw_dtype) + self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype = raw_dtype) if self.temp_file is None: self.tensors.append(tensor) diff --git a/llama.cpp b/llama.cpp index 865011f679d7c..0d271c748b108 100644 --- a/llama.cpp +++ b/llama.cpp @@ -262,7 +262,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_OLMO, "olmo" }, { LLM_ARCH_ARCTIC, "arctic" }, { LLM_ARCH_DEEPSEEK2, "deepseek2" }, - { LLM_ARCH_BITNET, "bitnet" }, + { LLM_ARCH_BITNET, "bitnet" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -3193,6 +3193,7 @@ struct llama_model_loader { llama_tensor_weight(const llama_file * file, uint16_t idx, const char * name, const struct gguf_context * gguf_ctx, ggml_tensor * tensor) : idx(idx), tensor(tensor) { const int tensor_idx = gguf_find_tensor(gguf_ctx, name); offs = gguf_get_data_offset(gguf_ctx) + gguf_get_tensor_offset(gguf_ctx, tensor_idx); + if (offs + ggml_nbytes(tensor) < offs || offs + ggml_nbytes(tensor) > file->size) { throw std::runtime_error(format("tensor '%s' data is not within the file bounds, model is corrupted or incomplete", name)); } @@ -7029,7 +7030,6 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * wo_b, struct ggml_tensor * q_cur, struct ggml_tensor * kq_mask, - struct ggml_tensor * attn_sub_norm, int32_t n_tokens, int32_t n_kv, float kq_scale, @@ -7124,15 +7124,6 @@ static struct ggml_tensor * llm_build_kqv( cb(cur, "kqv_merged_cont", il); } - if (model.arch == LLM_ARCH_BITNET) - { - cur = llm_build_norm(ctx, cur, hparams, - attn_sub_norm, NULL, - LLM_NORM_RMS, cb, il); - cb(cur, "attn_sub_norm", il); - - } - ggml_build_forward_expand(graph, cur); cur = ggml_mul_mat(ctx, wo, cur); @@ -7178,7 +7169,7 @@ static struct ggml_tensor * llm_build_kv( struct ggml_tensor * cur; cur = llm_build_kqv(ctx, model, hparams, cparams, kv, graph, wo, wo_b, - q_cur, kq_mask, nullptr, n_tokens, n_kv, kq_scale, cb, il); + q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il); cb(cur, "kqv_out", il); return cur; @@ -11590,10 +11581,84 @@ struct llm_build_context { cb(Kcur, "Kcur", il); llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il); - cur = llm_build_kqv(ctx0, model, hparams, cparams, kv_self, gf, - model.layers[il].wo, model.layers[il].bo, - Qcur, KQ_mask, model.layers[il].attn_sub_norm, n_tokens, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il - ); + + const int64_t n_ctx = cparams.n_ctx; + const int64_t n_head = hparams.n_head; + const int64_t n_head_kv = hparams.n_head_kv; + const int64_t n_embd_head_k = hparams.n_embd_head_k; + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const int64_t n_embd_head_v = hparams.n_embd_head_v; + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); + + struct ggml_tensor * q_cur = Qcur; + struct ggml_tensor * kq_mask = KQ_mask; + float kq_scale = 1.0f/sqrtf(float(n_embd_head)); + struct ggml_tensor * attn_sub_norm = model.layers[il].attn_sub_norm; + struct ggml_cgraph * graph = gf; + struct ggml_tensor * wo = model.layers[il].wo; + struct ggml_tensor * cur_attn; + struct ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3); + cb(q, "q", il); + + struct ggml_tensor * k = + ggml_view_3d(ctx0, kv_self.k_l[il], + n_embd_head_k, n_kv, n_head_kv, + ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), + ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k), + 0); + cb(k, "k", il); + + if (cparams.flash_attn) { + + // split cached v into n_head heads (not transposed) + struct ggml_tensor * v = + ggml_view_3d(ctx0, kv_self.v_l[il], + n_embd_head_v, n_kv, n_head_kv, + ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa), + ggml_row_size(kv_self.v_l[il]->type, n_embd_head_v), + 0); + cb(v, "v", il); + + cur_attn = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias); + + cur_attn = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens); + } else { + struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); + cb(kq, "kq", il); + + kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias); + cb(kq, "kq_soft_max_ext", il); + + GGML_ASSERT(kv_self.size == n_ctx); + + // split cached v into n_head heads + struct ggml_tensor * v = + ggml_view_3d(ctx0, kv_self.v_l[il], + n_kv, n_embd_head_v, n_head_kv, + ggml_element_size(kv_self.v_l[il])*n_ctx, + ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v, + 0); + cb(v, "v", il); + + struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq); + cb(kqv, "kqv", il); + + struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + cb(kqv_merged, "kqv_merged", il); + + cur_attn = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens); + cb(cur_attn, "kqv_merged_cont", il); + } + + cur_attn = llm_build_norm(ctx0, cur_attn, hparams, + attn_sub_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur_attn, "attn_sub_norm", il); + + ggml_build_forward_expand(graph, cur_attn); + + cur = ggml_mul_mat(ctx0, wo, cur_attn); + cb(cur, "kqv_out", il); } From 97d22be58cb0552f0f1cab32cc9930ae1148fb2b Mon Sep 17 00:00:00 2001 From: Eddie-Wang1120 Date: Sun, 9 Jun 2024 21:22:50 +0800 Subject: [PATCH 11/32] fix codestyle --- convert-hf-to-gguf.py | 7 ++++--- ggml.c | 2 +- llama.cpp | 1 + 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 15a23ef6cf690..aa52cee64f4b1 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1390,12 +1390,14 @@ def write_tensors(self): if len(experts) > 0: raise ValueError(f"Unprocessed experts: {experts}") + @Model.register("BitnetForCausalLM") class BitnetModel(Model): model_arch = gguf.MODEL_ARCH.BITNET + def set_vocab(self): self._set_vocab_sentencepiece() - + def set_gguf_parameters(self): super().set_gguf_parameters() self.gguf_writer.add_name("Bitnet") @@ -1407,9 +1409,7 @@ def set_gguf_parameters(self): self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"]) self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"]) - self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) - self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) self.gguf_writer.add_rope_scaling_factor(1.0) self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"]) @@ -1430,6 +1430,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [(self.map_tensor_name(name), data_torch)] + @Model.register("GrokForCausalLM") class GrokModel(Model): model_arch = gguf.MODEL_ARCH.GROK diff --git a/ggml.c b/ggml.c index e59093cdf21cc..562415d6007e4 100644 --- a/ggml.c +++ b/ggml.c @@ -12349,7 +12349,7 @@ static void ggml_compute_forward_mul_mat_one_chunk( // attempt to reduce false-sharing (does not seem to make a difference) // 16 * 2, accounting for mmla kernels float tmp[32]; - float * scale = (float * )((uint8_t*) (src0->data) + (ne00 * ne01 / 4)); + const float * scale = (float * )((uint8_t*) (src0->data) + (ne00 * ne01 / 4)); const float * act_scales = (const float*) ((const char *) wdata + (ne11 * ne10)); for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) { diff --git a/llama.cpp b/llama.cpp index 0d271c748b108..c775ae79b3edb 100644 --- a/llama.cpp +++ b/llama.cpp @@ -15961,6 +15961,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s if (params->output_tensor_type < GGML_TYPE_COUNT && strcmp(tensor->name, "output.weight") == 0) { new_type = params->output_tensor_type; } + // If we've decided to quantize to the same type the tensor is already // in then there's nothing to do. quantize = tensor->type != new_type; From 344467f2b89bda09bbbc2943250ea032cc78282f Mon Sep 17 00:00:00 2001 From: Eddie-Wang1120 Date: Mon, 10 Jun 2024 00:00:52 +0800 Subject: [PATCH 12/32] fix code --- convert-hf-to-gguf.py | 2 +- ggml-quants.c | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index aa52cee64f4b1..05207afe5127f 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1417,7 +1417,7 @@ def set_gguf_parameters(self): def weight_quant(self, weight): dtype = weight.dtype weight = weight.float() - s = 1 / weight.abs().mean().clamp(min=1e-5) + s = 1 / weight.abs().mean().clamp(min=1e-5) result = (weight * s).round().clamp(-1, 1) / s return result.type(dtype) diff --git a/ggml-quants.c b/ggml-quants.c index 6a825cd74d99c..7deeb367fb396 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -3802,6 +3802,10 @@ void ggml_vec_dot_i2_i8_s(int n, float * restrict s, size_t bs, const void * res UNUSED(by); UNUSED(nrc); +#if defined(__AVX2__) + // TODO +#else + int sumi = 0; for (int i = 0; i < n / 4; i++) { @@ -3812,7 +3816,7 @@ void ggml_vec_dot_i2_i8_s(int n, float * restrict s, size_t bs, const void * res sumi += (int)y[i*4+3] * weight[3]; } *s = (float)sumi; - +#endif } void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { From 65ac3a362710028a9b485b42d123935570ac052e Mon Sep 17 00:00:00 2001 From: Eddie-Wang1120 Date: Mon, 10 Jun 2024 00:06:09 +0800 Subject: [PATCH 13/32] fix --- ggml-quants.c | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ggml-quants.c b/ggml-quants.c index 7deeb367fb396..273cdee70bf7d 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -3802,9 +3802,9 @@ void ggml_vec_dot_i2_i8_s(int n, float * restrict s, size_t bs, const void * res UNUSED(by); UNUSED(nrc); -#if defined(__AVX2__) - // TODO -#else +// #if defined(__AVX2__) +// // TODO +// #else int sumi = 0; @@ -3816,7 +3816,7 @@ void ggml_vec_dot_i2_i8_s(int n, float * restrict s, size_t bs, const void * res sumi += (int)y[i*4+3] * weight[3]; } *s = (float)sumi; -#endif +// #endif } void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { From abd798d70f2a53ecf47a883b6da2ab030462e281 Mon Sep 17 00:00:00 2001 From: Eddie-Wang Date: Mon, 10 Jun 2024 02:50:14 +0000 Subject: [PATCH 14/32] fix code --- ggml-quants.c | 63 ++++++++++++++++++++++++++++++++++++++++++++++----- ggml.c | 1 - 2 files changed, 57 insertions(+), 7 deletions(-) diff --git a/ggml-quants.c b/ggml-quants.c index 273cdee70bf7d..72149d4a0a55a 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -664,13 +664,13 @@ void quantize_row_i8_s(const float * x, void * y, int64_t n, float* act_scales) double min = 0.00001; double max = min; for (int i = 0; i < n; ++i) { - max = MAX(max, (double)fabs(x[i])); + max = MAX(max, (double)fabs((double)x[i])); } float s = 127 / max; act_scales[0] = s; float temp; for (int i = 0; i < n; ++i) { - temp = round(x[i] * s); + temp = round((double)(x[i] * s)); if (temp > 127) temp = 127; if (temp < -128) temp = -128; dst[i] = (int8_t)(temp); @@ -3335,14 +3335,14 @@ size_t quantize_i2_s(const float * restrict src, void * restrict dst, int64_t nr // f32 -> q8 double i2_scale = 0; for (int i=0; i 1e-6) { - i2_scale = src[i]; + if (fabs((double)(src[i])) > 1e-6) { + i2_scale = (double)src[i]; } } uint8_t* q8 = (uint8_t*)dst; for (int i=0; i Date: Mon, 10 Jun 2024 03:07:38 +0000 Subject: [PATCH 15/32] fix merge --- convert-hf-to-gguf.py | 11 ----------- llama.cpp | 4 ++-- 2 files changed, 2 insertions(+), 13 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 937d0e3280be0..d8ae13c06e193 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1406,19 +1406,8 @@ def set_vocab(self): def set_gguf_parameters(self): super().set_gguf_parameters() - self.gguf_writer.add_name("Bitnet") - self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) - self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) - self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"]) - self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) - self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) - self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) - self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"]) - self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"]) - self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) self.gguf_writer.add_rope_scaling_factor(1.0) - self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"]) def weight_quant(self, weight): dtype = weight.dtype diff --git a/llama.cpp b/llama.cpp index 61d6ae5a3c815..5ebdeb024e599 100644 --- a/llama.cpp +++ b/llama.cpp @@ -11569,14 +11569,14 @@ struct llm_build_context { Qcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); Kcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); From de1d5073e428335a61166b8d24f3537641f886b5 Mon Sep 17 00:00:00 2001 From: Eddie-Wang1120 Date: Tue, 11 Jun 2024 10:23:20 +0800 Subject: [PATCH 16/32] remove unused --- ggml.c | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ggml.c b/ggml.c index a98fc608c4d47..324a7494c5530 100644 --- a/ggml.c +++ b/ggml.c @@ -12275,9 +12275,6 @@ static void ggml_compute_forward_mul_mat_one_chunk( const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; const size_t row_size = ggml_row_size(vec_dot_type, ne10); - // if (src0->type == 31) { - // row_size = ne10; - // } assert(ne12 % ne02 == 0); assert(ne13 % ne03 == 0); @@ -12291,6 +12288,8 @@ static void ggml_compute_forward_mul_mat_one_chunk( // attempt to reduce false-sharing (does not seem to make a difference) // 16 * 2, accounting for mmla kernels float tmp[32]; + + // for per-tensor quant const float * scale = (float * )((uint8_t*) (src0->data) + (ne00 * ne01 / 4)); const float * act_scales = (const float*) ((const char *) wdata + (ne11 * ne10)); From f395dd9ca0288976e9ebe200f38a7ea8c7f1a346 Mon Sep 17 00:00:00 2001 From: Eddie-Wang1120 Date: Wed, 12 Jun 2024 14:28:24 +0800 Subject: [PATCH 17/32] change table name --- ggml-common.h | 128 +++++++++++++++++++++++++------------------------- ggml-quants.c | 18 +++---- 2 files changed, 73 insertions(+), 73 deletions(-) diff --git a/ggml-common.h b/ggml-common.h index 1b7b2133bc2f9..66d984b5a5c04 100644 --- a/ggml-common.h +++ b/ggml-common.h @@ -1022,70 +1022,70 @@ GGML_TABLE_BEGIN(uint32_t, iq3s_grid, 512) 0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101, GGML_TABLE_END() -GGML_TABLE_BEGIN(uint32_t, i2_q8, 256) -0x00000000, 0x01000000, 0x00000000, 0xff000000, -0x00010000, 0x01010000, 0x00010000, 0xff010000, -0x00000000, 0x01000000, 0x00000000, 0xff000000, -0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000, -0x00000100, 0x01000100, 0x00000100, 0xff000100, -0x00010100, 0x01010100, 0x00010100, 0xff010100, -0x00000100, 0x01000100, 0x00000100, 0xff000100, -0x00ff0100, 0x01ff0100, 0x00ff0100, 0xffff0100, -0x00000000, 0x01000000, 0x00000000, 0xff000000, -0x00010000, 0x01010000, 0x00010000, 0xff010000, -0x00000000, 0x01000000, 0x00000000, 0xff000000, -0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000, -0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00, -0x0001ff00, 0x0101ff00, 0x0001ff00, 0xff01ff00, -0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00, -0x00ffff00, 0x01ffff00, 0x00ffff00, 0xffffff00, -0x00000001, 0x01000001, 0x00000001, 0xff000001, -0x00010001, 0x01010001, 0x00010001, 0xff010001, -0x00000001, 0x01000001, 0x00000001, 0xff000001, -0x00ff0001, 0x01ff0001, 0x00ff0001, 0xffff0001, -0x00000101, 0x01000101, 0x00000101, 0xff000101, -0x00010101, 0x01010101, 0x00010101, 0xff010101, -0x00000101, 0x01000101, 0x00000101, 0xff000101, -0x00ff0101, 0x01ff0101, 0x00ff0101, 0xffff0101, -0x00000001, 0x01000001, 0x00000001, 0xff000001, -0x00010001, 0x01010001, 0x00010001, 0xff010001, -0x00000001, 0x01000001, 0x00000001, 0xff000001, -0x00ff0001, 0x01ff0001, 0x00ff0001, 0xffff0001, -0x0000ff01, 0x0100ff01, 0x0000ff01, 0xff00ff01, -0x0001ff01, 0x0101ff01, 0x0001ff01, 0xff01ff01, -0x0000ff01, 0x0100ff01, 0x0000ff01, 0xff00ff01, -0x00ffff01, 0x01ffff01, 0x00ffff01, 0xffffff01, -0x00000000, 0x01000000, 0x00000000, 0xff000000, -0x00010000, 0x01010000, 0x00010000, 0xff010000, -0x00000000, 0x01000000, 0x00000000, 0xff000000, -0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000, -0x00000100, 0x01000100, 0x00000100, 0xff000100, -0x00010100, 0x01010100, 0x00010100, 0xff010100, -0x00000100, 0x01000100, 0x00000100, 0xff000100, -0x00ff0100, 0x01ff0100, 0x00ff0100, 0xffff0100, -0x00000000, 0x01000000, 0x00000000, 0xff000000, -0x00010000, 0x01010000, 0x00010000, 0xff010000, -0x00000000, 0x01000000, 0x00000000, 0xff000000, -0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000, -0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00, -0x0001ff00, 0x0101ff00, 0x0001ff00, 0xff01ff00, -0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00, -0x00ffff00, 0x01ffff00, 0x00ffff00, 0xffffff00, -0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff, -0x000100ff, 0x010100ff, 0x000100ff, 0xff0100ff, -0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff, -0x00ff00ff, 0x01ff00ff, 0x00ff00ff, 0xffff00ff, -0x000001ff, 0x010001ff, 0x000001ff, 0xff0001ff, -0x000101ff, 0x010101ff, 0x000101ff, 0xff0101ff, -0x000001ff, 0x010001ff, 0x000001ff, 0xff0001ff, -0x00ff01ff, 0x01ff01ff, 0x00ff01ff, 0xffff01ff, -0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff, -0x000100ff, 0x010100ff, 0x000100ff, 0xff0100ff, -0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff, -0x00ff00ff, 0x01ff00ff, 0x00ff00ff, 0xffff00ff, -0x0000ffff, 0x0100ffff, 0x0000ffff, 0xff00ffff, -0x0001ffff, 0x0101ffff, 0x0001ffff, 0xff01ffff, -0x0000ffff, 0x0100ffff, 0x0000ffff, 0xff00ffff, +GGML_TABLE_BEGIN(uint32_t, i2s_i8s, 256) +0x00000000, 0x01000000, 0x00000000, 0xff000000, +0x00010000, 0x01010000, 0x00010000, 0xff010000, +0x00000000, 0x01000000, 0x00000000, 0xff000000, +0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000, +0x00000100, 0x01000100, 0x00000100, 0xff000100, +0x00010100, 0x01010100, 0x00010100, 0xff010100, +0x00000100, 0x01000100, 0x00000100, 0xff000100, +0x00ff0100, 0x01ff0100, 0x00ff0100, 0xffff0100, +0x00000000, 0x01000000, 0x00000000, 0xff000000, +0x00010000, 0x01010000, 0x00010000, 0xff010000, +0x00000000, 0x01000000, 0x00000000, 0xff000000, +0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000, +0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00, +0x0001ff00, 0x0101ff00, 0x0001ff00, 0xff01ff00, +0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00, +0x00ffff00, 0x01ffff00, 0x00ffff00, 0xffffff00, +0x00000001, 0x01000001, 0x00000001, 0xff000001, +0x00010001, 0x01010001, 0x00010001, 0xff010001, +0x00000001, 0x01000001, 0x00000001, 0xff000001, +0x00ff0001, 0x01ff0001, 0x00ff0001, 0xffff0001, +0x00000101, 0x01000101, 0x00000101, 0xff000101, +0x00010101, 0x01010101, 0x00010101, 0xff010101, +0x00000101, 0x01000101, 0x00000101, 0xff000101, +0x00ff0101, 0x01ff0101, 0x00ff0101, 0xffff0101, +0x00000001, 0x01000001, 0x00000001, 0xff000001, +0x00010001, 0x01010001, 0x00010001, 0xff010001, +0x00000001, 0x01000001, 0x00000001, 0xff000001, +0x00ff0001, 0x01ff0001, 0x00ff0001, 0xffff0001, +0x0000ff01, 0x0100ff01, 0x0000ff01, 0xff00ff01, +0x0001ff01, 0x0101ff01, 0x0001ff01, 0xff01ff01, +0x0000ff01, 0x0100ff01, 0x0000ff01, 0xff00ff01, +0x00ffff01, 0x01ffff01, 0x00ffff01, 0xffffff01, +0x00000000, 0x01000000, 0x00000000, 0xff000000, +0x00010000, 0x01010000, 0x00010000, 0xff010000, +0x00000000, 0x01000000, 0x00000000, 0xff000000, +0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000, +0x00000100, 0x01000100, 0x00000100, 0xff000100, +0x00010100, 0x01010100, 0x00010100, 0xff010100, +0x00000100, 0x01000100, 0x00000100, 0xff000100, +0x00ff0100, 0x01ff0100, 0x00ff0100, 0xffff0100, +0x00000000, 0x01000000, 0x00000000, 0xff000000, +0x00010000, 0x01010000, 0x00010000, 0xff010000, +0x00000000, 0x01000000, 0x00000000, 0xff000000, +0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000, +0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00, +0x0001ff00, 0x0101ff00, 0x0001ff00, 0xff01ff00, +0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00, +0x00ffff00, 0x01ffff00, 0x00ffff00, 0xffffff00, +0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff, +0x000100ff, 0x010100ff, 0x000100ff, 0xff0100ff, +0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff, +0x00ff00ff, 0x01ff00ff, 0x00ff00ff, 0xffff00ff, +0x000001ff, 0x010001ff, 0x000001ff, 0xff0001ff, +0x000101ff, 0x010101ff, 0x000101ff, 0xff0101ff, +0x000001ff, 0x010001ff, 0x000001ff, 0xff0001ff, +0x00ff01ff, 0x01ff01ff, 0x00ff01ff, 0xffff01ff, +0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff, +0x000100ff, 0x010100ff, 0x000100ff, 0xff0100ff, +0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff, +0x00ff00ff, 0x01ff00ff, 0x00ff00ff, 0xffff00ff, +0x0000ffff, 0x0100ffff, 0x0000ffff, 0xff00ffff, +0x0001ffff, 0x0101ffff, 0x0001ffff, 0xff01ffff, +0x0000ffff, 0x0100ffff, 0x0000ffff, 0xff00ffff, 0x00ffffff, 0x01ffffff, 0x00ffffff, 0xffffffff, GGML_TABLE_END() diff --git a/ggml-quants.c b/ggml-quants.c index 72149d4a0a55a..4f37310677bcb 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -3807,14 +3807,14 @@ void ggml_vec_dot_i2_i8_s(int n, float * restrict s, size_t bs, const void * res // __m256i accu = _mm256_setzero_si256(); // for (int i=0; i Date: Wed, 12 Jun 2024 16:25:46 +0800 Subject: [PATCH 18/32] fix whitespace --- ggml.c | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml.c b/ggml.c index 324a7494c5530..6a53e3dacd207 100644 --- a/ggml.c +++ b/ggml.c @@ -12332,6 +12332,7 @@ static void ggml_compute_forward_mul_mat_one_chunk( vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot); } } + for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) { memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float)); } From 7a8961fff56d6cbb870c0bcd678bc5cf88e6d07f Mon Sep 17 00:00:00 2001 From: Eddie-Wang1120 Date: Fri, 14 Jun 2024 12:30:27 +0800 Subject: [PATCH 19/32] delete redundant --- convert-hf-to-gguf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index d8ae13c06e193..754edf014e1dd 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1421,7 +1421,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if name.endswith(("q_proj.weight", "k_proj.weight", "v_proj.weight", "down_proj.weight", "up_proj.weight", "gate_proj.weight", "o_proj.weight")): - data_torch = data_torch + (self.weight_quant(data_torch) - data_torch).detach() + data_torch = self.weight_quant(data_torch) return [(self.map_tensor_name(name), data_torch)] From 95dced07e4843f91e8d50d29060244ebb12d8090 Mon Sep 17 00:00:00 2001 From: Eddie-Wang1120 Date: Sat, 15 Jun 2024 10:10:40 +0800 Subject: [PATCH 20/32] i2_s to absmax --- ggml-quants.c | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/ggml-quants.c b/ggml-quants.c index 4f37310677bcb..665e381a3a2cd 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -3333,12 +3333,11 @@ size_t quantize_i2_s(const float * restrict src, void * restrict dst, int64_t nr int n = nrow * n_per_row; // f32 -> q8 - double i2_scale = 0; - for (int i=0; i 1e-6) { - i2_scale = (double)src[i]; - } + double max = 0; + for (int i = 0; i < n; ++i) { + max = MAX(max, (double)fabs((double)src[i])); } + double i2_scale = max; uint8_t* q8 = (uint8_t*)dst; for (int i=0; i Date: Sat, 15 Jun 2024 14:01:26 +0000 Subject: [PATCH 21/32] finish i2_s/i8_s vec_dot x86 simd --- ggml-common.h | 128 +++++++++++++++++++++++++------------------------- ggml-quants.c | 111 +++++++++++++++++++++---------------------- llama.cpp | 2 - 3 files changed, 120 insertions(+), 121 deletions(-) diff --git a/ggml-common.h b/ggml-common.h index 66d984b5a5c04..409fcf29eab5a 100644 --- a/ggml-common.h +++ b/ggml-common.h @@ -1023,70 +1023,70 @@ GGML_TABLE_BEGIN(uint32_t, iq3s_grid, 512) GGML_TABLE_END() GGML_TABLE_BEGIN(uint32_t, i2s_i8s, 256) -0x00000000, 0x01000000, 0x00000000, 0xff000000, -0x00010000, 0x01010000, 0x00010000, 0xff010000, -0x00000000, 0x01000000, 0x00000000, 0xff000000, -0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000, -0x00000100, 0x01000100, 0x00000100, 0xff000100, -0x00010100, 0x01010100, 0x00010100, 0xff010100, -0x00000100, 0x01000100, 0x00000100, 0xff000100, -0x00ff0100, 0x01ff0100, 0x00ff0100, 0xffff0100, -0x00000000, 0x01000000, 0x00000000, 0xff000000, -0x00010000, 0x01010000, 0x00010000, 0xff010000, -0x00000000, 0x01000000, 0x00000000, 0xff000000, -0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000, -0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00, -0x0001ff00, 0x0101ff00, 0x0001ff00, 0xff01ff00, -0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00, -0x00ffff00, 0x01ffff00, 0x00ffff00, 0xffffff00, -0x00000001, 0x01000001, 0x00000001, 0xff000001, -0x00010001, 0x01010001, 0x00010001, 0xff010001, -0x00000001, 0x01000001, 0x00000001, 0xff000001, -0x00ff0001, 0x01ff0001, 0x00ff0001, 0xffff0001, -0x00000101, 0x01000101, 0x00000101, 0xff000101, -0x00010101, 0x01010101, 0x00010101, 0xff010101, -0x00000101, 0x01000101, 0x00000101, 0xff000101, -0x00ff0101, 0x01ff0101, 0x00ff0101, 0xffff0101, -0x00000001, 0x01000001, 0x00000001, 0xff000001, -0x00010001, 0x01010001, 0x00010001, 0xff010001, -0x00000001, 0x01000001, 0x00000001, 0xff000001, -0x00ff0001, 0x01ff0001, 0x00ff0001, 0xffff0001, -0x0000ff01, 0x0100ff01, 0x0000ff01, 0xff00ff01, -0x0001ff01, 0x0101ff01, 0x0001ff01, 0xff01ff01, -0x0000ff01, 0x0100ff01, 0x0000ff01, 0xff00ff01, -0x00ffff01, 0x01ffff01, 0x00ffff01, 0xffffff01, -0x00000000, 0x01000000, 0x00000000, 0xff000000, -0x00010000, 0x01010000, 0x00010000, 0xff010000, -0x00000000, 0x01000000, 0x00000000, 0xff000000, -0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000, -0x00000100, 0x01000100, 0x00000100, 0xff000100, -0x00010100, 0x01010100, 0x00010100, 0xff010100, -0x00000100, 0x01000100, 0x00000100, 0xff000100, -0x00ff0100, 0x01ff0100, 0x00ff0100, 0xffff0100, -0x00000000, 0x01000000, 0x00000000, 0xff000000, -0x00010000, 0x01010000, 0x00010000, 0xff010000, -0x00000000, 0x01000000, 0x00000000, 0xff000000, -0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000, -0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00, -0x0001ff00, 0x0101ff00, 0x0001ff00, 0xff01ff00, -0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00, -0x00ffff00, 0x01ffff00, 0x00ffff00, 0xffffff00, -0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff, -0x000100ff, 0x010100ff, 0x000100ff, 0xff0100ff, -0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff, -0x00ff00ff, 0x01ff00ff, 0x00ff00ff, 0xffff00ff, -0x000001ff, 0x010001ff, 0x000001ff, 0xff0001ff, -0x000101ff, 0x010101ff, 0x000101ff, 0xff0101ff, -0x000001ff, 0x010001ff, 0x000001ff, 0xff0001ff, -0x00ff01ff, 0x01ff01ff, 0x00ff01ff, 0xffff01ff, -0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff, -0x000100ff, 0x010100ff, 0x000100ff, 0xff0100ff, -0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff, -0x00ff00ff, 0x01ff00ff, 0x00ff00ff, 0xffff00ff, -0x0000ffff, 0x0100ffff, 0x0000ffff, 0xff00ffff, -0x0001ffff, 0x0101ffff, 0x0001ffff, 0xff01ffff, -0x0000ffff, 0x0100ffff, 0x0000ffff, 0xff00ffff, -0x00ffffff, 0x01ffffff, 0x00ffffff, 0xffffffff, + 0x00000000, 0x01000000, 0x00000000, 0xff000000, + 0x00010000, 0x01010000, 0x00010000, 0xff010000, + 0x00000000, 0x01000000, 0x00000000, 0xff000000, + 0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000, + 0x00000100, 0x01000100, 0x00000100, 0xff000100, + 0x00010100, 0x01010100, 0x00010100, 0xff010100, + 0x00000100, 0x01000100, 0x00000100, 0xff000100, + 0x00ff0100, 0x01ff0100, 0x00ff0100, 0xffff0100, + 0x00000000, 0x01000000, 0x00000000, 0xff000000, + 0x00010000, 0x01010000, 0x00010000, 0xff010000, + 0x00000000, 0x01000000, 0x00000000, 0xff000000, + 0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000, + 0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00, + 0x0001ff00, 0x0101ff00, 0x0001ff00, 0xff01ff00, + 0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00, + 0x00ffff00, 0x01ffff00, 0x00ffff00, 0xffffff00, + 0x00000001, 0x01000001, 0x00000001, 0xff000001, + 0x00010001, 0x01010001, 0x00010001, 0xff010001, + 0x00000001, 0x01000001, 0x00000001, 0xff000001, + 0x00ff0001, 0x01ff0001, 0x00ff0001, 0xffff0001, + 0x00000101, 0x01000101, 0x00000101, 0xff000101, + 0x00010101, 0x01010101, 0x00010101, 0xff010101, + 0x00000101, 0x01000101, 0x00000101, 0xff000101, + 0x00ff0101, 0x01ff0101, 0x00ff0101, 0xffff0101, + 0x00000001, 0x01000001, 0x00000001, 0xff000001, + 0x00010001, 0x01010001, 0x00010001, 0xff010001, + 0x00000001, 0x01000001, 0x00000001, 0xff000001, + 0x00ff0001, 0x01ff0001, 0x00ff0001, 0xffff0001, + 0x0000ff01, 0x0100ff01, 0x0000ff01, 0xff00ff01, + 0x0001ff01, 0x0101ff01, 0x0001ff01, 0xff01ff01, + 0x0000ff01, 0x0100ff01, 0x0000ff01, 0xff00ff01, + 0x00ffff01, 0x01ffff01, 0x00ffff01, 0xffffff01, + 0x00000000, 0x01000000, 0x00000000, 0xff000000, + 0x00010000, 0x01010000, 0x00010000, 0xff010000, + 0x00000000, 0x01000000, 0x00000000, 0xff000000, + 0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000, + 0x00000100, 0x01000100, 0x00000100, 0xff000100, + 0x00010100, 0x01010100, 0x00010100, 0xff010100, + 0x00000100, 0x01000100, 0x00000100, 0xff000100, + 0x00ff0100, 0x01ff0100, 0x00ff0100, 0xffff0100, + 0x00000000, 0x01000000, 0x00000000, 0xff000000, + 0x00010000, 0x01010000, 0x00010000, 0xff010000, + 0x00000000, 0x01000000, 0x00000000, 0xff000000, + 0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000, + 0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00, + 0x0001ff00, 0x0101ff00, 0x0001ff00, 0xff01ff00, + 0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00, + 0x00ffff00, 0x01ffff00, 0x00ffff00, 0xffffff00, + 0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff, + 0x000100ff, 0x010100ff, 0x000100ff, 0xff0100ff, + 0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff, + 0x00ff00ff, 0x01ff00ff, 0x00ff00ff, 0xffff00ff, + 0x000001ff, 0x010001ff, 0x000001ff, 0xff0001ff, + 0x000101ff, 0x010101ff, 0x000101ff, 0xff0101ff, + 0x000001ff, 0x010001ff, 0x000001ff, 0xff0001ff, + 0x00ff01ff, 0x01ff01ff, 0x00ff01ff, 0xffff01ff, + 0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff, + 0x000100ff, 0x010100ff, 0x000100ff, 0xff0100ff, + 0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff, + 0x00ff00ff, 0x01ff00ff, 0x00ff00ff, 0xffff00ff, + 0x0000ffff, 0x0100ffff, 0x0000ffff, 0xff00ffff, + 0x0001ffff, 0x0101ffff, 0x0001ffff, 0xff01ffff, + 0x0000ffff, 0x0100ffff, 0x0000ffff, 0xff00ffff, + 0x00ffffff, 0x01ffffff, 0x00ffffff, 0xffffffff, GGML_TABLE_END() #define NGRID_IQ1S 2048 diff --git a/ggml-quants.c b/ggml-quants.c index 665e381a3a2cd..4b5209279a9fb 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -3799,60 +3799,61 @@ void ggml_vec_dot_i2_i8_s(int n, float * restrict s, size_t bs, const void * res UNUSED(by); UNUSED(nrc); -// TODO -// #if defined(__AVX2__) -// __m256i accu = _mm256_setzero_si256(); - -// for (int i=0; i Date: Mon, 17 Jun 2024 20:33:09 +0800 Subject: [PATCH 22/32] i2s->q22 --- convert-hf-to-gguf.py | 38 +++++- examples/quantize/quantize.cpp | 2 +- ggml-common.h | 9 +- ggml-quants.c | 203 +++++++++++++-------------------- ggml-quants.h | 7 +- ggml.c | 60 +++------- ggml.h | 3 +- gguf-py/gguf/constants.py | 24 +++- llama.cpp | 65 ++++++++--- llama.h | 2 +- 10 files changed, 220 insertions(+), 193 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 754edf014e1dd..9a217c1c74f2d 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1413,17 +1413,47 @@ def weight_quant(self, weight): dtype = weight.dtype weight = weight.float() s = 1 / weight.abs().mean().clamp(min=1e-5) - result = (weight * s).round().clamp(-1, 1) / s - return result.type(dtype) + weight = (weight * s).round().clamp(-1, 1) / s + scale = weight.abs().max().unsqueeze(0) + weight = torch.where(weight.abs().less(1e-6), 0, weight).type(dtype) + weight = torch.sign(weight).type(dtype) + return weight.type(dtype), scale.type(torch.float32) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: # transform weight into 1/0/-1 (in fp32) if name.endswith(("q_proj.weight", "k_proj.weight", "v_proj.weight", "down_proj.weight", "up_proj.weight", "gate_proj.weight", "o_proj.weight")): - data_torch = self.weight_quant(data_torch) + weight_torch, scale_torch = self.weight_quant(data_torch) - return [(self.map_tensor_name(name), data_torch)] + tensors: list[tuple[str, Tensor]] = [] + + if name.endswith("q_proj.weight"): + tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), weight_torch)) + tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q_SCALE, bid), scale_torch)) + elif name.endswith("k_proj.weight"): + tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), weight_torch)) + tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K_SCALE, bid), scale_torch)) + elif name.endswith("v_proj.weight"): + tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), weight_torch)) + tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V_SCALE, bid), scale_torch)) + elif name.endswith("o_proj.weight"): + tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_OUT, bid), weight_torch)) + tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_OUT_SCALE, bid), scale_torch)) + elif name.endswith("up_proj.weight"): + tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid), weight_torch)) + tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SCALE, bid), scale_torch)) + elif name.endswith("down_proj.weight"): + tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN, bid), weight_torch)) + tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN_SCALE, bid), scale_torch)) + elif name.endswith("gate_proj.weight"): + tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid), weight_torch)) + tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SCALE, bid), scale_torch)) + + if len(tensors) == 0: + tensors.append((self.map_tensor_name(name), data_torch)) + + return tensors @Model.register("GrokForCausalLM") diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 16cfd1717f7d6..05df330c0846f 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -26,7 +26,7 @@ static const std::vector QUANT_OPTIONS = { { "IQ2_M", LLAMA_FTYPE_MOSTLY_IQ2_M, " 2.7 bpw quantization", }, { "IQ1_S", LLAMA_FTYPE_MOSTLY_IQ1_S, " 1.56 bpw quantization", }, { "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M, " 1.75 bpw quantization", }, - { "I2_S", LLAMA_FTYPE_MOSTLY_I2_S, " 2 bpw per-tensor quantization", }, + { "Q2_2", LLAMA_FTYPE_MOSTLY_Q2_2, " 2 bpw quantization", }, { "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", }, { "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", }, { "IQ3_XXS",LLAMA_FTYPE_MOSTLY_IQ3_XXS," 3.06 bpw quantization", }, diff --git a/ggml-common.h b/ggml-common.h index 409fcf29eab5a..be88daa369e9d 100644 --- a/ggml-common.h +++ b/ggml-common.h @@ -137,6 +137,13 @@ typedef sycl::half2 ggml_half2; #endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP +#define QK2_2 32 +typedef struct { + ggml_half d; // delta + uint8_t qs[QK2_2 / 4]; // nibbles / quants +} block_q2_2; +static_assert(sizeof(block_q2_2) == sizeof(ggml_half) + QK2_2 / 4, "wrong q4_0 block size/padding"); + #define QK4_0 32 typedef struct { ggml_half d; // delta @@ -1022,7 +1029,7 @@ GGML_TABLE_BEGIN(uint32_t, iq3s_grid, 512) 0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101, GGML_TABLE_END() -GGML_TABLE_BEGIN(uint32_t, i2s_i8s, 256) +GGML_TABLE_BEGIN(uint32_t, q22_grid, 256) 0x00000000, 0x01000000, 0x00000000, 0xff000000, 0x00010000, 0x01010000, 0x00010000, 0xff010000, 0x00000000, 0x01000000, 0x00000000, 0xff000000, diff --git a/ggml-quants.c b/ggml-quants.c index 4b5209279a9fb..aebeb02170f0b 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -659,25 +659,44 @@ static inline __m128i packNibbles( __m256i bytes ) { } #endif //__loongarch_asx -void quantize_row_i8_s(const float * x, void * y, int64_t n, float* act_scales) { - int8_t* dst = (int8_t*)y; - double min = 0.00001; - double max = min; - for (int i = 0; i < n; ++i) { - max = MAX(max, (double)fabs((double)x[i])); - } - float s = 127 / max; - act_scales[0] = s; - float temp; - for (int i = 0; i < n; ++i) { - temp = round((double)(x[i] * s)); - if (temp > 127) temp = 127; - if (temp < -128) temp = -128; - dst[i] = (int8_t)(temp); +void quantize_row_q2_2_reference(const float * restrict x, block_q2_2 * restrict y, int64_t k) { + static const int qk = QK2_2; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + + const float d = 1.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + for (int j = 0; j < qk/4; ++j) { + int8_t x0 = (int8_t)x[i*qk + j*4 + 0]; + int8_t x1 = (int8_t)x[i*qk + j*4 + 1]; + int8_t x2 = (int8_t)x[i*qk + j*4 + 2]; + int8_t x3 = (int8_t)x[i*qk + j*4 + 3]; + + const uint8_t xi0 = x0 >= 0 ? x0 : 3; + const uint8_t xi1 = x1 >= 0 ? x1 : 3; + const uint8_t xi2 = x2 >= 0 ? x2 : 3; + const uint8_t xi3 = x3 >= 0 ? x3 : 3; + + y[i].qs[j] = 0; + y[i].qs[j] |= (xi0 << 6); + y[i].qs[j] |= (xi1 << 4); + y[i].qs[j] |= (xi2 << 2); + y[i].qs[j] |= (xi3 << 0); + } } } // reference implementation for deterministic creation of model files +void quantize_row_q2_2(const float * restrict x, void * restrict y, int64_t k) { + quantize_row_q2_2_reference(x, y, k); +} + void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) { static const int qk = QK4_0; @@ -3324,48 +3343,11 @@ size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nr return nrow * row_size; } -size_t quantize_i2_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { - // 2 bits per weight - UNUSED(quant_weights); - - size_t row_size = ggml_row_size(GGML_TYPE_I2_S, n_per_row); - - int n = nrow * n_per_row; - - // f32 -> q8 - double max = 0; - for (int i = 0; i < n; ++i) { - max = MAX(max, (double)fabs((double)src[i])); - } - double i2_scale = max; - - uint8_t* q8 = (uint8_t*)dst; - for (int i=0; i 0 ? 1 : 3; - } - - // q8 -> 0, 1, 3 - // | | | - // 0, 1,-1 - - uint8_t* i2_weight = (uint8_t*)dst; - for (int i=0; ine[i] - 1)*tensor->nb[i]; } - if(tensor->type == GGML_TYPE_I2_S){ - nbytes = nbytes / 4 + 32; - } } else { nbytes = tensor->ne[0]*tensor->nb[0]/blck_size; @@ -12289,10 +12282,6 @@ static void ggml_compute_forward_mul_mat_one_chunk( // 16 * 2, accounting for mmla kernels float tmp[32]; - // for per-tensor quant - const float * scale = (float * )((uint8_t*) (src0->data) + (ne00 * ne01 / 4)); - const float * act_scales = (const float*) ((const char *) wdata + (ne11 * ne10)); - for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) { for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) { for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) { @@ -12325,12 +12314,7 @@ static void ggml_compute_forward_mul_mat_one_chunk( //} for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) { - if (src0->type == GGML_TYPE_I2_S) { - vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01 / 4, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot); - tmp[ir0 - iir0] = tmp[ir0 - iir0] / (act_scales[i11]) * (*scale); - } else { - vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot); - } + vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot); } for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) { @@ -12494,13 +12478,8 @@ UseGgmlGemm1:; for (int64_t i13 = 0; i13 < ne13; ++i13) { for (int64_t i12 = 0; i12 < ne12; ++i12) { for (int64_t i11 = 0; i11 < ne11; ++i11) { - if (src0->type == GGML_TYPE_I2_S) { - float* act_scales = (float*) ((char *) wdata + (ne11 * ne10)); - quantize_row_i8_s((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (char *) wdata + ((i11*nb11 + i12*nb12 + i13*nb13) / 4), ne10, act_scales + i11); - } else { - from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10); - wdata += row_size; - } + from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10); + wdata += row_size; } } } @@ -14189,6 +14168,7 @@ static void ggml_compute_forward_clamp( } break; case GGML_TYPE_F16: case GGML_TYPE_BF16: + case GGML_TYPE_Q2_2: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -14215,8 +14195,6 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_I32: case GGML_TYPE_I64: case GGML_TYPE_F64: - case GGML_TYPE_I2_S: - case GGML_TYPE_I8_S: case GGML_TYPE_COUNT: { GGML_ASSERT(false); @@ -21340,6 +21318,7 @@ size_t ggml_quantize_chunk( size_t result = 0; switch (type) { + case GGML_TYPE_Q2_2: result = quantize_q2_2(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_0: result = quantize_q4_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_1: result = quantize_q4_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q5_0: result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; @@ -21359,7 +21338,6 @@ size_t ggml_quantize_chunk( case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_I2_S: result = quantize_i2_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_F16: { size_t elemsize = sizeof(ggml_fp16_t); @@ -21382,11 +21360,7 @@ size_t ggml_quantize_chunk( assert(false); } - if (type == GGML_TYPE_I2_S) { - result = nrows * row_size / 4 + 32; - } else { - GGML_ASSERT(result == nrows * row_size); - } + GGML_ASSERT(result == nrows * row_size); return result; } diff --git a/ggml.h b/ggml.h index c2e6859f53709..4ec555ccb39b2 100644 --- a/ggml.h +++ b/ggml.h @@ -377,8 +377,7 @@ extern "C" { GGML_TYPE_F64 = 28, GGML_TYPE_IQ1_M = 29, GGML_TYPE_BF16 = 30, - GGML_TYPE_I2_S = 31, - GGML_TYPE_I8_S = 32, + GGML_TYPE_Q2_2 = 31, GGML_TYPE_COUNT, }; diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 78c7290d24487..7f2c10601f900 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -202,6 +202,13 @@ class MODEL_TENSOR(IntEnum): ATTN_KV_A_NORM = auto() FFN_SUB_NORM = auto() ATTN_SUB_NORM = auto() + ATTN_Q_SCALE = auto() + ATTN_K_SCALE = auto() + ATTN_V_SCALE = auto() + ATTN_OUT_SCALE = auto() + FFN_UP_SCALE = auto() + FFN_DOWN_SCALE = auto() + FFN_GATE_SCALE = auto() MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { @@ -293,6 +300,13 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm", MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm", MODEL_TENSOR.FFN_SUB_NORM: "blk.{bid}.ffn_sub_norm", + MODEL_TENSOR.ATTN_Q_SCALE: "blk.{bid}.attn_q_scale", + MODEL_TENSOR.ATTN_K_SCALE: "blk.{bid}.attn_k_scale", + MODEL_TENSOR.ATTN_V_SCALE: "blk.{bid}.attn_v_scale", + MODEL_TENSOR.ATTN_OUT_SCALE: "blk.{bid}.attn_output_scale", + MODEL_TENSOR.FFN_UP_SCALE: "blk.{bid}.ffn_up_scale", + MODEL_TENSOR.FFN_DOWN_SCALE: "blk.{bid}.ffn_down_scale", + MODEL_TENSOR.FFN_GATE_SCALE: "blk.{bid}.ffn_gate_scale", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -819,17 +833,21 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_V, MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, - MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, - MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_OUT, - MODEL_TENSOR.ATTN_ROT_EMBD, MODEL_TENSOR.FFN_NORM, MODEL_TENSOR.FFN_GATE, MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, MODEL_TENSOR.ATTN_SUB_NORM, MODEL_TENSOR.FFN_SUB_NORM, + MODEL_TENSOR.ATTN_Q_SCALE, + MODEL_TENSOR.ATTN_K_SCALE, + MODEL_TENSOR.ATTN_V_SCALE, + MODEL_TENSOR.ATTN_OUT_SCALE, + MODEL_TENSOR.FFN_UP_SCALE, + MODEL_TENSOR.FFN_DOWN_SCALE, + MODEL_TENSOR.FFN_GATE_SCALE, ], # TODO } diff --git a/llama.cpp b/llama.cpp index 16ae07dd33ab6..28854e8cfb326 100644 --- a/llama.cpp +++ b/llama.cpp @@ -498,6 +498,13 @@ enum llm_tensor { LLM_TENSOR_ATTN_KV_A_NORM, LLM_TENSOR_ATTN_SUB_NORM, LLM_TENSOR_FFN_SUB_NORM, + LLM_TENSOR_ATTN_Q_SCALE, + LLM_TENSOR_ATTN_K_SCALE, + LLM_TENSOR_ATTN_V_SCALE, + LLM_TENSOR_ATTN_OUTPUT_SCALE, + LLM_TENSOR_FFN_UP_SCALE, + LLM_TENSOR_FFN_DOWN_SCALE, + LLM_TENSOR_FFN_GATE_SCALE, }; static const std::map> LLM_TENSOR_NAMES = { @@ -1114,19 +1121,26 @@ static const std::map> LLM_TENSOR_NA { LLM_ARCH_BITNET, { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_SUB_NORM, "blk.%d.attn_sub_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_SUB_NORM, "blk.%d.ffn_sub_norm" }, + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_SUB_NORM, "blk.%d.attn_sub_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_SUB_NORM, "blk.%d.ffn_sub_norm" }, + { LLM_TENSOR_ATTN_Q_SCALE, "blk.%d.attn_q_scale" }, + { LLM_TENSOR_ATTN_K_SCALE, "blk.%d.attn_q_scale" }, + { LLM_TENSOR_ATTN_V_SCALE, "blk.%d.attn_q_scale" }, + { LLM_TENSOR_ATTN_OUTPUT_SCALE, "blk.%d.attn_output_scale" }, + { LLM_TENSOR_FFN_UP_SCALE, "blk.%d.ffn_up_scale" }, + { LLM_TENSOR_FFN_DOWN_SCALE, "blk.%d.ffn_down_scale" }, + { LLM_TENSOR_FFN_GATE_SCALE, "blk.%d.ffn_gate_scale" }, }, }, { @@ -2075,6 +2089,15 @@ struct llama_layer { // long rope factors struct ggml_tensor * rope_long = nullptr; struct ggml_tensor * rope_short = nullptr; + + // bitnet scale + struct ggml_tensor * wq_scale; + struct ggml_tensor * wk_scale; + struct ggml_tensor * wv_scale; + struct ggml_tensor * wo_scale; + struct ggml_tensor * ffn_gate_scale; + struct ggml_tensor * ffn_up_scale; + struct ggml_tensor * ffn_down_scale; }; struct llama_kv_cell { @@ -6460,16 +6483,23 @@ static bool llm_load_tensors( layer.attn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}); layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wq_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_SCALE, "weight", i), {1}); layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wk_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K_SCALE, "weight", i), {1}); layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wv_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V_SCALE, "weight", i), {1}); layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wo_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUTPUT_SCALE, "weight", i), {1}); layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); layer.ffn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}); layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_gate_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SCALE, "weight", i), {1}); layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_down_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SCALE, "weight", i), {1}); layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_up_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SCALE, "weight", i), {1}); } } break; default: @@ -11545,6 +11575,7 @@ struct llm_build_context { { // compute Q and K and RoPE them struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + Qcur = ggml_mul(ctx0, Qcur, model.layers[il].wq_scale); cb(Qcur, "Qcur", il); if (model.layers[il].bq) { Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); @@ -11553,6 +11584,7 @@ struct llm_build_context { // B1.K struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + Kcur = ggml_mul(ctx0, Kcur, model.layers[il].wk_scale); cb(Kcur, "Kcur", il); if (model.layers[il].bk) { Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); @@ -11561,6 +11593,7 @@ struct llm_build_context { // B1.V struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + Vcur = ggml_mul(ctx0, Vcur, model.layers[il].wv_scale); cb(Vcur, "Vcur", il); if (model.layers[il].bv) { Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); @@ -11659,6 +11692,7 @@ struct llm_build_context { ggml_build_forward_expand(graph, cur_attn); cur = ggml_mul_mat(ctx0, wo, cur_attn); + cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale); cb(cur, "kqv_out", il); } @@ -11681,10 +11715,12 @@ struct llm_build_context { cb(cur, "ffn_norm", il); struct ggml_tensor *tmp = ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur); + tmp = ggml_mul(ctx0, tmp, model.layers[il].ffn_up_scale); cb(tmp, "ffn_up", il); cur = ggml_mul_mat(ctx0, model.layers[il].ffn_gate, cur); + cur = ggml_mul(ctx0, cur, model.layers[il].ffn_gate_scale); cb(cur, "ffn_gate", il); @@ -11701,6 +11737,7 @@ struct llm_build_context { cb(cur, "ffn_sub_norm", il); cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down, cur); + cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale); cb(cur, "ffn_down", il); } cur = ggml_add(ctx0, cur, ffn_inp); @@ -15444,6 +15481,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s llama_ftype ftype = params->ftype; switch (params->ftype) { + case LLAMA_FTYPE_MOSTLY_Q2_2: default_type = GGML_TYPE_Q2_2; break; case LLAMA_FTYPE_MOSTLY_Q4_0: default_type = GGML_TYPE_Q4_0; break; case LLAMA_FTYPE_MOSTLY_Q4_1: default_type = GGML_TYPE_Q4_1; break; case LLAMA_FTYPE_MOSTLY_Q5_0: default_type = GGML_TYPE_Q5_0; break; @@ -15452,7 +15490,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_F16: default_type = GGML_TYPE_F16; break; case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break; case LLAMA_FTYPE_ALL_F32: default_type = GGML_TYPE_F32; break; - case LLAMA_FTYPE_MOSTLY_I2_S: default_type = GGML_TYPE_I2_S; break; // K-quants case LLAMA_FTYPE_MOSTLY_Q2_K_S: diff --git a/llama.h b/llama.h index f7cd33edc7444..7a2e0e31c1bb4 100644 --- a/llama.h +++ b/llama.h @@ -156,7 +156,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors - LLAMA_FTYPE_MOSTLY_I2_S = 33, + LLAMA_FTYPE_MOSTLY_Q2_2 = 33, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file }; From 4edc958fec93b3b68ab713ca2f0fdc3687e2fa48 Mon Sep 17 00:00:00 2001 From: Eddie-Wang1120 Date: Tue, 18 Jun 2024 22:16:16 +0800 Subject: [PATCH 23/32] fix code --- convert-hf-to-gguf.py | 14 +++++------ examples/quantize/quantize.cpp | 2 +- ggml-common.h | 2 +- gguf-py/gguf/constants.py | 21 ---------------- llama.cpp | 45 ++++++++++------------------------ 5 files changed, 22 insertions(+), 62 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 9a217c1c74f2d..0b19e470cfd7d 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1430,25 +1430,25 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if name.endswith("q_proj.weight"): tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), weight_torch)) - tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q_SCALE, bid), scale_torch)) + tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid, suffix=".scale"), scale_torch)) elif name.endswith("k_proj.weight"): tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), weight_torch)) - tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K_SCALE, bid), scale_torch)) + tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid, suffix=".scale"), scale_torch)) elif name.endswith("v_proj.weight"): tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), weight_torch)) - tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V_SCALE, bid), scale_torch)) + tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid, suffix=".scale"), scale_torch)) elif name.endswith("o_proj.weight"): tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_OUT, bid), weight_torch)) - tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_OUT_SCALE, bid), scale_torch)) + tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_OUT, bid, suffix=".scale"), scale_torch)) elif name.endswith("up_proj.weight"): tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid), weight_torch)) - tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SCALE, bid), scale_torch)) + tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid, suffix=".scale"), scale_torch)) elif name.endswith("down_proj.weight"): tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN, bid), weight_torch)) - tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN_SCALE, bid), scale_torch)) + tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN, bid, suffix=".scale"), scale_torch)) elif name.endswith("gate_proj.weight"): tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid), weight_torch)) - tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SCALE, bid), scale_torch)) + tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid, suffix=".scale"), scale_torch)) if len(tensors) == 0: tensors.append((self.map_tensor_name(name), data_torch)) diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 05df330c0846f..e3ebd660d6208 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -26,7 +26,7 @@ static const std::vector QUANT_OPTIONS = { { "IQ2_M", LLAMA_FTYPE_MOSTLY_IQ2_M, " 2.7 bpw quantization", }, { "IQ1_S", LLAMA_FTYPE_MOSTLY_IQ1_S, " 1.56 bpw quantization", }, { "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M, " 1.75 bpw quantization", }, - { "Q2_2", LLAMA_FTYPE_MOSTLY_Q2_2, " 2 bpw quantization", }, + { "Q2_2", LLAMA_FTYPE_MOSTLY_Q2_2, " 2.5 bpw quantization", }, { "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", }, { "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", }, { "IQ3_XXS",LLAMA_FTYPE_MOSTLY_IQ3_XXS," 3.06 bpw quantization", }, diff --git a/ggml-common.h b/ggml-common.h index be88daa369e9d..d3b6d1a948db8 100644 --- a/ggml-common.h +++ b/ggml-common.h @@ -142,7 +142,7 @@ typedef struct { ggml_half d; // delta uint8_t qs[QK2_2 / 4]; // nibbles / quants } block_q2_2; -static_assert(sizeof(block_q2_2) == sizeof(ggml_half) + QK2_2 / 4, "wrong q4_0 block size/padding"); +static_assert(sizeof(block_q2_2) == sizeof(ggml_half) + QK2_2 / 4, "wrong q2_2 block size/padding"); #define QK4_0 32 typedef struct { diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 7f2c10601f900..1fc8fcde5d80b 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -202,13 +202,6 @@ class MODEL_TENSOR(IntEnum): ATTN_KV_A_NORM = auto() FFN_SUB_NORM = auto() ATTN_SUB_NORM = auto() - ATTN_Q_SCALE = auto() - ATTN_K_SCALE = auto() - ATTN_V_SCALE = auto() - ATTN_OUT_SCALE = auto() - FFN_UP_SCALE = auto() - FFN_DOWN_SCALE = auto() - FFN_GATE_SCALE = auto() MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { @@ -300,13 +293,6 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm", MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm", MODEL_TENSOR.FFN_SUB_NORM: "blk.{bid}.ffn_sub_norm", - MODEL_TENSOR.ATTN_Q_SCALE: "blk.{bid}.attn_q_scale", - MODEL_TENSOR.ATTN_K_SCALE: "blk.{bid}.attn_k_scale", - MODEL_TENSOR.ATTN_V_SCALE: "blk.{bid}.attn_v_scale", - MODEL_TENSOR.ATTN_OUT_SCALE: "blk.{bid}.attn_output_scale", - MODEL_TENSOR.FFN_UP_SCALE: "blk.{bid}.ffn_up_scale", - MODEL_TENSOR.FFN_DOWN_SCALE: "blk.{bid}.ffn_down_scale", - MODEL_TENSOR.FFN_GATE_SCALE: "blk.{bid}.ffn_gate_scale", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -841,13 +827,6 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_UP, MODEL_TENSOR.ATTN_SUB_NORM, MODEL_TENSOR.FFN_SUB_NORM, - MODEL_TENSOR.ATTN_Q_SCALE, - MODEL_TENSOR.ATTN_K_SCALE, - MODEL_TENSOR.ATTN_V_SCALE, - MODEL_TENSOR.ATTN_OUT_SCALE, - MODEL_TENSOR.FFN_UP_SCALE, - MODEL_TENSOR.FFN_DOWN_SCALE, - MODEL_TENSOR.FFN_GATE_SCALE, ], # TODO } diff --git a/llama.cpp b/llama.cpp index 28854e8cfb326..c87dd9c3cdd28 100644 --- a/llama.cpp +++ b/llama.cpp @@ -498,13 +498,6 @@ enum llm_tensor { LLM_TENSOR_ATTN_KV_A_NORM, LLM_TENSOR_ATTN_SUB_NORM, LLM_TENSOR_FFN_SUB_NORM, - LLM_TENSOR_ATTN_Q_SCALE, - LLM_TENSOR_ATTN_K_SCALE, - LLM_TENSOR_ATTN_V_SCALE, - LLM_TENSOR_ATTN_OUTPUT_SCALE, - LLM_TENSOR_FFN_UP_SCALE, - LLM_TENSOR_FFN_DOWN_SCALE, - LLM_TENSOR_FFN_GATE_SCALE, }; static const std::map> LLM_TENSOR_NAMES = { @@ -1134,13 +1127,6 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, { LLM_TENSOR_FFN_SUB_NORM, "blk.%d.ffn_sub_norm" }, - { LLM_TENSOR_ATTN_Q_SCALE, "blk.%d.attn_q_scale" }, - { LLM_TENSOR_ATTN_K_SCALE, "blk.%d.attn_q_scale" }, - { LLM_TENSOR_ATTN_V_SCALE, "blk.%d.attn_q_scale" }, - { LLM_TENSOR_ATTN_OUTPUT_SCALE, "blk.%d.attn_output_scale" }, - { LLM_TENSOR_FFN_UP_SCALE, "blk.%d.ffn_up_scale" }, - { LLM_TENSOR_FFN_DOWN_SCALE, "blk.%d.ffn_down_scale" }, - { LLM_TENSOR_FFN_GATE_SCALE, "blk.%d.ffn_gate_scale" }, }, }, { @@ -6483,23 +6469,23 @@ static bool llm_load_tensors( layer.attn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}); layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.wq_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_SCALE, "weight", i), {1}); + layer.wq_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}); layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.wk_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K_SCALE, "weight", i), {1}); + layer.wk_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "scale", i), {1}); layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.wv_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V_SCALE, "weight", i), {1}); + layer.wv_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "scale", i), {1}); layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.wo_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUTPUT_SCALE, "weight", i), {1}); + layer.wo_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}); layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); layer.ffn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}); layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_gate_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SCALE, "weight", i), {1}); + layer.ffn_gate_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}); layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); - layer.ffn_down_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SCALE, "weight", i), {1}); + layer.ffn_down_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}); layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); - layer.ffn_up_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SCALE, "weight", i), {1}); + layer.ffn_up_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "scale", i), {1}); } } break; default: @@ -11624,14 +11610,9 @@ struct llm_build_context { const int64_t n_embd_head_v = hparams.n_embd_head_v; const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); - struct ggml_tensor * q_cur = Qcur; - struct ggml_tensor * kq_mask = KQ_mask; float kq_scale = 1.0f/sqrtf(float(n_embd_head)); - struct ggml_tensor * attn_sub_norm = model.layers[il].attn_sub_norm; - struct ggml_cgraph * graph = gf; - struct ggml_tensor * wo = model.layers[il].wo; struct ggml_tensor * cur_attn; - struct ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3); + struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); cb(q, "q", il); struct ggml_tensor * k = @@ -11653,14 +11634,14 @@ struct llm_build_context { 0); cb(v, "v", il); - cur_attn = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias); + cur_attn = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias); cur_attn = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens); } else { struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); cb(kq, "kq", il); - kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias); + kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias); cb(kq, "kq_soft_max_ext", il); GGML_ASSERT(kv_self.size == n_ctx); @@ -11685,13 +11666,13 @@ struct llm_build_context { } cur_attn = llm_build_norm(ctx0, cur_attn, hparams, - attn_sub_norm, NULL, + model.layers[il].attn_sub_norm, NULL, LLM_NORM_RMS, cb, il); cb(cur_attn, "attn_sub_norm", il); - ggml_build_forward_expand(graph, cur_attn); + ggml_build_forward_expand(gf, cur_attn); - cur = ggml_mul_mat(ctx0, wo, cur_attn); + cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur_attn); cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale); cb(cur, "kqv_out", il); From 89c7e4c1dd1b0a71ffd40fcee9143962fb85722d Mon Sep 17 00:00:00 2001 From: Eddie-Wang1120 Date: Tue, 18 Jun 2024 23:33:58 +0800 Subject: [PATCH 24/32] remove block scale --- examples/quantize/quantize.cpp | 2 +- ggml-common.h | 3 +-- ggml-quants.c | 9 +-------- 3 files changed, 3 insertions(+), 11 deletions(-) diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index e3ebd660d6208..05df330c0846f 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -26,7 +26,7 @@ static const std::vector QUANT_OPTIONS = { { "IQ2_M", LLAMA_FTYPE_MOSTLY_IQ2_M, " 2.7 bpw quantization", }, { "IQ1_S", LLAMA_FTYPE_MOSTLY_IQ1_S, " 1.56 bpw quantization", }, { "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M, " 1.75 bpw quantization", }, - { "Q2_2", LLAMA_FTYPE_MOSTLY_Q2_2, " 2.5 bpw quantization", }, + { "Q2_2", LLAMA_FTYPE_MOSTLY_Q2_2, " 2 bpw quantization", }, { "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", }, { "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", }, { "IQ3_XXS",LLAMA_FTYPE_MOSTLY_IQ3_XXS," 3.06 bpw quantization", }, diff --git a/ggml-common.h b/ggml-common.h index d3b6d1a948db8..a1a8246656cca 100644 --- a/ggml-common.h +++ b/ggml-common.h @@ -139,10 +139,9 @@ typedef sycl::half2 ggml_half2; #define QK2_2 32 typedef struct { - ggml_half d; // delta uint8_t qs[QK2_2 / 4]; // nibbles / quants } block_q2_2; -static_assert(sizeof(block_q2_2) == sizeof(ggml_half) + QK2_2 / 4, "wrong q2_2 block size/padding"); +static_assert(sizeof(block_q2_2) == QK2_2 / 4, "wrong q2_2 block size/padding"); #define QK4_0 32 typedef struct { diff --git a/ggml-quants.c b/ggml-quants.c index aebeb02170f0b..a3c8c67319557 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -668,10 +668,6 @@ void quantize_row_q2_2_reference(const float * restrict x, block_q2_2 * restrict for (int i = 0; i < nb; i++) { - const float d = 1.0f; - - y[i].d = GGML_FP32_TO_FP16(d); - for (int j = 0; j < qk/4; ++j) { int8_t x0 = (int8_t)x[i*qk + j*4 + 0]; int8_t x1 = (int8_t)x[i*qk + j*4 + 1]; @@ -14369,10 +14365,6 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte } } } break; - case GGML_TYPE_Q2_2: - { - VALIDATE_ROW_DATA_D_F16_IMPL(block_q2_2, data, nb); - } break; case GGML_TYPE_Q4_0: { VALIDATE_ROW_DATA_D_F16_IMPL(block_q4_0, data, nb); @@ -14467,6 +14459,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte { VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb); } break; + case GGML_TYPE_Q2_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: From fcf2da4621dc0f7079ca318a48ca961793ab9e4c Mon Sep 17 00:00:00 2001 From: Eddie-Wang1120 Date: Wed, 19 Jun 2024 21:48:04 +0800 Subject: [PATCH 25/32] add dequantize --- convert-hf-to-gguf.py | 49 +++++++++++++-------------------------- ggml-quants.c | 20 ++++++++++++++++ ggml-quants.h | 1 + ggml.c | 1 + gguf-py/gguf/constants.py | 3 +++ llama.cpp | 2 +- 6 files changed, 42 insertions(+), 34 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 0b19e470cfd7d..22456990792c3 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1420,40 +1420,23 @@ def weight_quant(self, weight): return weight.type(dtype), scale.type(torch.float32) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - # transform weight into 1/0/-1 (in fp32) - if name.endswith(("q_proj.weight", "k_proj.weight", "v_proj.weight", - "down_proj.weight", "up_proj.weight", "gate_proj.weight", - "o_proj.weight")): - weight_torch, scale_torch = self.weight_quant(data_torch) - - tensors: list[tuple[str, Tensor]] = [] - - if name.endswith("q_proj.weight"): - tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), weight_torch)) - tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid, suffix=".scale"), scale_torch)) - elif name.endswith("k_proj.weight"): - tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), weight_torch)) - tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid, suffix=".scale"), scale_torch)) - elif name.endswith("v_proj.weight"): - tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), weight_torch)) - tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid, suffix=".scale"), scale_torch)) - elif name.endswith("o_proj.weight"): - tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_OUT, bid), weight_torch)) - tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_OUT, bid, suffix=".scale"), scale_torch)) - elif name.endswith("up_proj.weight"): - tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid), weight_torch)) - tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid, suffix=".scale"), scale_torch)) - elif name.endswith("down_proj.weight"): - tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN, bid), weight_torch)) - tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN, bid, suffix=".scale"), scale_torch)) - elif name.endswith("gate_proj.weight"): - tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid), weight_torch)) - tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid, suffix=".scale"), scale_torch)) - - if len(tensors) == 0: - tensors.append((self.map_tensor_name(name), data_torch)) + new_name = self.map_tensor_name(name) - return tensors + if any(self.match_model_tensor_name(new_name, key, bid) for key in [ + gguf.MODEL_TENSOR.ATTN_Q, + gguf.MODEL_TENSOR.ATTN_K, + gguf.MODEL_TENSOR.ATTN_V, + gguf.MODEL_TENSOR.ATTN_OUT, + gguf.MODEL_TENSOR.FFN_UP, + gguf.MODEL_TENSOR.FFN_DOWN, + gguf.MODEL_TENSOR.FFN_GATE, + ]): + # transform weight into 1/0/-1 (in fp32) + weight_torch, scale_torch = self.weight_quant(data_torch) + yield (new_name, weight_torch) + yield (new_name.removesuffix(".weight") + ".scale", scale_torch) + else: + yield (new_name, data_torch) @Model.register("GrokForCausalLM") diff --git a/ggml-quants.c b/ggml-quants.c index a3c8c67319557..a3633fc53afee 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -1545,6 +1545,26 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) #endif } +void dequantize_row_q2_2(const block_q2_2 * restrict x, float * restrict y, int64_t k) { + static const int qk = QK2_2; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + + for (int j = 0; j < qk/4; ++j) { + const int8_t * q = (const int8_t *) (q22_grid + x[i].qs[j]); + + *y++ = (float) q[0]; + *y++ = (float) q[1]; + *y++ = (float) q[2]; + *y++ = (float) q[3]; + } + } +} + void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int64_t k) { static const int qk = QK4_0; diff --git a/ggml-quants.h b/ggml-quants.h index e5ef8a8ca1f06..e159cef5f71b7 100644 --- a/ggml-quants.h +++ b/ggml-quants.h @@ -55,6 +55,7 @@ void quantize_row_iq3_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, void quantize_row_iq2_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); // Dequantization +void dequantize_row_q2_2(const block_q2_2 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void dequantize_row_q5_0(const block_q5_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); diff --git a/ggml.c b/ggml.c index d714171f7dc2b..55effd7178323 100644 --- a/ggml.c +++ b/ggml.c @@ -620,6 +620,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .type_name = "q2_2", .blck_size = QK2_2, .type_size = sizeof(block_q2_2), + .to_float = (ggml_to_float_t) dequantize_row_q2_2, .is_quantized = true, .from_float = quantize_row_q2_2, .from_float_reference = (ggml_from_float_t) quantize_row_q2_2_reference, diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 1fc8fcde5d80b..301200869b6f0 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -923,6 +923,7 @@ class GGMLQuantizationType(IntEnum): F64 = 28 IQ1_M = 29 BF16 = 30 + Q2_2 = 31 # TODO: add GGMLFileType from ggml_ftype in ggml.h @@ -964,6 +965,7 @@ class LlamaFileType(IntEnum): MOSTLY_IQ4_XS = 30 # except 1d tensors MOSTLY_IQ1_M = 31 # except 1d tensors MOSTLY_BF16 = 32 # except 1d tensors + MOSTLY_Q2_2 = 33 # except 1d tensors GUESSED = 1024 # not specified in the model file @@ -1010,6 +1012,7 @@ def get_type(val: Any) -> GGUFValueType: GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = { GGMLQuantizationType.F32: (1, 4), GGMLQuantizationType.F16: (1, 2), + GGMLQuantizationType.Q2_2: (32, 8), GGMLQuantizationType.Q4_0: (32, 2 + 16), GGMLQuantizationType.Q4_1: (32, 2 + 2 + 16), GGMLQuantizationType.Q5_0: (32, 2 + 4 + 16), diff --git a/llama.cpp b/llama.cpp index c87dd9c3cdd28..85182f4bbeea1 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3885,6 +3885,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_ALL_F32: return "all F32"; case LLAMA_FTYPE_MOSTLY_F16: return "F16"; case LLAMA_FTYPE_MOSTLY_BF16: return "BF16"; + case LLAMA_FTYPE_MOSTLY_Q2_2: return "Q2_2"; case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0"; case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1"; case LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16: @@ -11705,7 +11706,6 @@ struct llm_build_context { cb(cur, "ffn_gate", il); - cur = ggml_silu(ctx0, cur); cb(cur, "ffn_silu", il); From fa9a742b46e127e1a23ceedcb97b33b330528234 Mon Sep 17 00:00:00 2001 From: Eddie-Wang1120 Date: Wed, 19 Jun 2024 21:49:13 +0800 Subject: [PATCH 26/32] fix seq --- ggml.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml.c b/ggml.c index 55effd7178323..303b2f5636147 100644 --- a/ggml.c +++ b/ggml.c @@ -620,8 +620,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .type_name = "q2_2", .blck_size = QK2_2, .type_size = sizeof(block_q2_2), - .to_float = (ggml_to_float_t) dequantize_row_q2_2, .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q2_2, .from_float = quantize_row_q2_2, .from_float_reference = (ggml_from_float_t) quantize_row_q2_2_reference, .vec_dot = ggml_vec_dot_q2_2_q8_0, From 230396bc5bb53c058b3b6d86af2bb401a3837c84 Mon Sep 17 00:00:00 2001 From: Eddie-Wang1120 Date: Thu, 20 Jun 2024 00:12:58 +0800 Subject: [PATCH 27/32] update avx2 --- ggml-quants.c | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/ggml-quants.c b/ggml-quants.c index a3633fc53afee..f45ece1f25836 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -3806,18 +3806,27 @@ void ggml_vec_dot_q2_2_q8_0(int n, float * restrict s, size_t bs, const void * r const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(y[i].d) ); - __m256i xq8 = _mm256_set_epi32( - (int)q22_grid[x[i].qs[7]], - (int)q22_grid[x[i].qs[6]], - (int)q22_grid[x[i].qs[5]], - (int)q22_grid[x[i].qs[4]], - (int)q22_grid[x[i].qs[3]], - (int)q22_grid[x[i].qs[2]], - (int)q22_grid[x[i].qs[1]], - (int)q22_grid[x[i].qs[0]] - ); - - __m256i yq8 = _mm256_loadu_si256((const __m256i*)(y[i].qs)); + __m128i xq8b = _mm_loadu_si64(x[i].qs); + __m256i xq8 = MM256_SET_M128I(xq8b, xq8b); + __m256i xq8l = _mm256_shuffle_epi8(xq8, _mm256_set_epi8(5, -1, 5, -1, 5, -1, 5, -1, + 4, -1, 4, -1, 4, -1, 4, -1, + 1, -1, 1, -1, 1, -1, 1, -1, + 0, -1, 0, -1, 0, -1, 0, -1)); + __m256i xq8h = _mm256_shuffle_epi8(xq8, _mm256_set_epi8(7, -1, 7, -1, 7, -1, 7, -1, + 6, -1, 6, -1, 6, -1, 6, -1, + 3, -1, 3, -1, 3, -1, 3, -1, + 2, -1, 2, -1, 2, -1, 2, -1)); + __m256i shift = _mm256_set_epi16(64, 16, 4, 1, + 64, 16, 4, 1, + 64, 16, 4, 1, + 64, 16, 4, 1); + xq8l = _mm256_mullo_epi16(xq8l, shift); + xq8h = _mm256_mullo_epi16(xq8h, shift); + xq8l = _mm256_srai_epi16(xq8l, 14); + xq8h = _mm256_srai_epi16(xq8h, 14); + xq8 = _mm256_packs_epi16(xq8l, xq8h); + + __m256i yq8 = _mm256_lddqu_si256((const __m256i*)(y[i].qs)); const __m256 q = mul_sum_i8_pairs_float(xq8, yq8); acc = _mm256_fmadd_ps( d, q, acc ); From 2b097682e07f4c7f830f197d1b803e7155a632a1 Mon Sep 17 00:00:00 2001 From: Eddie-Wang1120 Date: Thu, 20 Jun 2024 20:06:13 +0800 Subject: [PATCH 28/32] remove q2_2 --- examples/quantize/quantize.cpp | 1 - ggml-common.h | 6 -- ggml-quants.c | 127 --------------------------------- ggml-quants.h | 5 -- ggml.c | 14 ---- ggml.h | 1 - gguf-py/gguf/constants.py | 3 - llama.cpp | 2 - llama.h | 1 - 9 files changed, 160 deletions(-) diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 05df330c0846f..28584e14b788c 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -26,7 +26,6 @@ static const std::vector QUANT_OPTIONS = { { "IQ2_M", LLAMA_FTYPE_MOSTLY_IQ2_M, " 2.7 bpw quantization", }, { "IQ1_S", LLAMA_FTYPE_MOSTLY_IQ1_S, " 1.56 bpw quantization", }, { "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M, " 1.75 bpw quantization", }, - { "Q2_2", LLAMA_FTYPE_MOSTLY_Q2_2, " 2 bpw quantization", }, { "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", }, { "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", }, { "IQ3_XXS",LLAMA_FTYPE_MOSTLY_IQ3_XXS," 3.06 bpw quantization", }, diff --git a/ggml-common.h b/ggml-common.h index a1a8246656cca..d3fa51235c1f7 100644 --- a/ggml-common.h +++ b/ggml-common.h @@ -137,12 +137,6 @@ typedef sycl::half2 ggml_half2; #endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP -#define QK2_2 32 -typedef struct { - uint8_t qs[QK2_2 / 4]; // nibbles / quants -} block_q2_2; -static_assert(sizeof(block_q2_2) == QK2_2 / 4, "wrong q2_2 block size/padding"); - #define QK4_0 32 typedef struct { ggml_half d; // delta diff --git a/ggml-quants.c b/ggml-quants.c index f45ece1f25836..ee86fd6b99db7 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -659,40 +659,6 @@ static inline __m128i packNibbles( __m256i bytes ) { } #endif //__loongarch_asx -void quantize_row_q2_2_reference(const float * restrict x, block_q2_2 * restrict y, int64_t k) { - static const int qk = QK2_2; - - assert(k % qk == 0); - - const int nb = k / qk; - - for (int i = 0; i < nb; i++) { - - for (int j = 0; j < qk/4; ++j) { - int8_t x0 = (int8_t)x[i*qk + j*4 + 0]; - int8_t x1 = (int8_t)x[i*qk + j*4 + 1]; - int8_t x2 = (int8_t)x[i*qk + j*4 + 2]; - int8_t x3 = (int8_t)x[i*qk + j*4 + 3]; - - const uint8_t xi0 = x0 >= 0 ? x0 : 3; - const uint8_t xi1 = x1 >= 0 ? x1 : 3; - const uint8_t xi2 = x2 >= 0 ? x2 : 3; - const uint8_t xi3 = x3 >= 0 ? x3 : 3; - - y[i].qs[j] = 0; - y[i].qs[j] |= (xi0 << 6); - y[i].qs[j] |= (xi1 << 4); - y[i].qs[j] |= (xi2 << 2); - y[i].qs[j] |= (xi3 << 0); - } - } -} - -// reference implementation for deterministic creation of model files -void quantize_row_q2_2(const float * restrict x, void * restrict y, int64_t k) { - quantize_row_q2_2_reference(x, y, k); -} - void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) { static const int qk = QK4_0; @@ -1545,26 +1511,6 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) #endif } -void dequantize_row_q2_2(const block_q2_2 * restrict x, float * restrict y, int64_t k) { - static const int qk = QK2_2; - - assert(k % qk == 0); - - const int nb = k / qk; - - for (int i = 0; i < nb; i++) { - - for (int j = 0; j < qk/4; ++j) { - const int8_t * q = (const int8_t *) (q22_grid + x[i].qs[j]); - - *y++ = (float) q[0]; - *y++ = (float) q[1]; - *y++ = (float) q[2]; - *y++ = (float) q[3]; - } - } -} - void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int64_t k) { static const int qk = QK4_0; @@ -3359,13 +3305,6 @@ size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nr return nrow * row_size; } -size_t quantize_q2_2(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { - (void)quant_weights; // not used - const size_t row_size = ggml_row_size(GGML_TYPE_Q2_2, n_per_row); - quantize_row_q2_2_reference(src, dst, (int64_t)nrow*n_per_row); - return nrow * row_size; -} - // ====================== "True" 2-bit (de)-quantization void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int64_t k) { @@ -3786,71 +3725,6 @@ static inline __m128i get_scale_shuffle(int i) { } #endif -void ggml_vec_dot_q2_2_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - const int qk = QK8_0; - const int nb = n / qk; - - assert(n % qk == 0); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_q2_2 * restrict x = vx; - const block_q8_0 * restrict y = vy; - -#if defined(__AVX2__) - __m256 acc = _mm256_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - - const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(y[i].d) ); - - __m128i xq8b = _mm_loadu_si64(x[i].qs); - __m256i xq8 = MM256_SET_M128I(xq8b, xq8b); - __m256i xq8l = _mm256_shuffle_epi8(xq8, _mm256_set_epi8(5, -1, 5, -1, 5, -1, 5, -1, - 4, -1, 4, -1, 4, -1, 4, -1, - 1, -1, 1, -1, 1, -1, 1, -1, - 0, -1, 0, -1, 0, -1, 0, -1)); - __m256i xq8h = _mm256_shuffle_epi8(xq8, _mm256_set_epi8(7, -1, 7, -1, 7, -1, 7, -1, - 6, -1, 6, -1, 6, -1, 6, -1, - 3, -1, 3, -1, 3, -1, 3, -1, - 2, -1, 2, -1, 2, -1, 2, -1)); - __m256i shift = _mm256_set_epi16(64, 16, 4, 1, - 64, 16, 4, 1, - 64, 16, 4, 1, - 64, 16, 4, 1); - xq8l = _mm256_mullo_epi16(xq8l, shift); - xq8h = _mm256_mullo_epi16(xq8h, shift); - xq8l = _mm256_srai_epi16(xq8l, 14); - xq8h = _mm256_srai_epi16(xq8h, 14); - xq8 = _mm256_packs_epi16(xq8l, xq8h); - - __m256i yq8 = _mm256_lddqu_si256((const __m256i*)(y[i].qs)); - const __m256 q = mul_sum_i8_pairs_float(xq8, yq8); - - acc = _mm256_fmadd_ps( d, q, acc ); - } - - *s = hsum_float_8(acc); -#else - - float sumf = 0.0; - for (int i = 0; i < nb; i++) { - int sumi = 0; - for (int j = 0; j < qk / 4; j++) { - const int8_t* weight = (const int8_t *)(q22_grid + x[i].qs[j]); - sumi += (int)y[i].qs[4*j+0] * weight[0]; - sumi += (int)y[i].qs[4*j+1] * weight[1]; - sumi += (int)y[i].qs[4*j+2] * weight[2]; - sumi += (int)y[i].qs[4*j+3] * weight[3]; - } - sumf += (float)(sumi)*(GGML_FP16_TO_FP32(y[i].d)); - } - *s = sumf; -#endif -} - void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; @@ -14488,7 +14362,6 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte { VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb); } break; - case GGML_TYPE_Q2_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: diff --git a/ggml-quants.h b/ggml-quants.h index e159cef5f71b7..4d436a8f06b3e 100644 --- a/ggml-quants.h +++ b/ggml-quants.h @@ -12,7 +12,6 @@ extern "C" { #endif // Quantization -void quantize_row_q2_2_reference(const float * GGML_RESTRICT x, block_q2_2 * GGML_RESTRICT y, int64_t k); void quantize_row_q4_0_reference(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k); void quantize_row_q4_1_reference(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k); void quantize_row_q5_0_reference(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k); @@ -33,7 +32,6 @@ void quantize_row_iq4_xs_reference (const float * GGML_RESTRICT x, block_iq4_xs void quantize_row_iq3_s_reference (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k); void quantize_row_iq2_s_reference (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k); -void quantize_row_q2_2(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); @@ -55,7 +53,6 @@ void quantize_row_iq3_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, void quantize_row_iq2_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); // Dequantization -void dequantize_row_q2_2(const block_q2_2 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void dequantize_row_q5_0(const block_q5_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); @@ -81,7 +78,6 @@ void dequantize_row_iq4_xs (const block_iq4_xs * GGML_RESTRICT x, float * GGML_ void dequantize_row_iq3_s (const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); // Dot product -void ggml_vec_dot_q2_2_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); @@ -120,7 +116,6 @@ size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, size_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_q5_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_q6_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_q2_2(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); diff --git a/ggml.c b/ggml.c index 303b2f5636147..1fc77743bc7b9 100644 --- a/ggml.c +++ b/ggml.c @@ -616,18 +616,6 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_F16, .nrows = 1, }, - [GGML_TYPE_Q2_2] = { - .type_name = "q2_2", - .blck_size = QK2_2, - .type_size = sizeof(block_q2_2), - .is_quantized = true, - .to_float = (ggml_to_float_t) dequantize_row_q2_2, - .from_float = quantize_row_q2_2, - .from_float_reference = (ggml_from_float_t) quantize_row_q2_2_reference, - .vec_dot = ggml_vec_dot_q2_2_q8_0, - .vec_dot_type = GGML_TYPE_Q8_0, - .nrows = 1, - }, [GGML_TYPE_Q4_0] = { .type_name = "q4_0", .blck_size = QK4_0, @@ -14169,7 +14157,6 @@ static void ggml_compute_forward_clamp( } break; case GGML_TYPE_F16: case GGML_TYPE_BF16: - case GGML_TYPE_Q2_2: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -21319,7 +21306,6 @@ size_t ggml_quantize_chunk( size_t result = 0; switch (type) { - case GGML_TYPE_Q2_2: result = quantize_q2_2(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_0: result = quantize_q4_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_1: result = quantize_q4_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q5_0: result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git a/ggml.h b/ggml.h index 4ec555ccb39b2..13502a3622fc4 100644 --- a/ggml.h +++ b/ggml.h @@ -377,7 +377,6 @@ extern "C" { GGML_TYPE_F64 = 28, GGML_TYPE_IQ1_M = 29, GGML_TYPE_BF16 = 30, - GGML_TYPE_Q2_2 = 31, GGML_TYPE_COUNT, }; diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 301200869b6f0..1fc8fcde5d80b 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -923,7 +923,6 @@ class GGMLQuantizationType(IntEnum): F64 = 28 IQ1_M = 29 BF16 = 30 - Q2_2 = 31 # TODO: add GGMLFileType from ggml_ftype in ggml.h @@ -965,7 +964,6 @@ class LlamaFileType(IntEnum): MOSTLY_IQ4_XS = 30 # except 1d tensors MOSTLY_IQ1_M = 31 # except 1d tensors MOSTLY_BF16 = 32 # except 1d tensors - MOSTLY_Q2_2 = 33 # except 1d tensors GUESSED = 1024 # not specified in the model file @@ -1012,7 +1010,6 @@ def get_type(val: Any) -> GGUFValueType: GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = { GGMLQuantizationType.F32: (1, 4), GGMLQuantizationType.F16: (1, 2), - GGMLQuantizationType.Q2_2: (32, 8), GGMLQuantizationType.Q4_0: (32, 2 + 16), GGMLQuantizationType.Q4_1: (32, 2 + 2 + 16), GGMLQuantizationType.Q5_0: (32, 2 + 4 + 16), diff --git a/llama.cpp b/llama.cpp index 85182f4bbeea1..1622bc8d3f4f1 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3885,7 +3885,6 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_ALL_F32: return "all F32"; case LLAMA_FTYPE_MOSTLY_F16: return "F16"; case LLAMA_FTYPE_MOSTLY_BF16: return "BF16"; - case LLAMA_FTYPE_MOSTLY_Q2_2: return "Q2_2"; case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0"; case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1"; case LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16: @@ -15462,7 +15461,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s llama_ftype ftype = params->ftype; switch (params->ftype) { - case LLAMA_FTYPE_MOSTLY_Q2_2: default_type = GGML_TYPE_Q2_2; break; case LLAMA_FTYPE_MOSTLY_Q4_0: default_type = GGML_TYPE_Q4_0; break; case LLAMA_FTYPE_MOSTLY_Q4_1: default_type = GGML_TYPE_Q4_1; break; case LLAMA_FTYPE_MOSTLY_Q5_0: default_type = GGML_TYPE_Q5_0; break; diff --git a/llama.h b/llama.h index 7a2e0e31c1bb4..62908261f2791 100644 --- a/llama.h +++ b/llama.h @@ -156,7 +156,6 @@ extern "C" { LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q2_2 = 33, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file }; From a58cf0d61fd52674cd6d4b8e94df993d63fecaf4 Mon Sep 17 00:00:00 2001 From: Eddie-Wang1120 Date: Thu, 20 Jun 2024 20:08:10 +0800 Subject: [PATCH 29/32] remove q22_grid --- ggml-common.h | 67 --------------------------------------------------- ggml-quants.c | 1 + 2 files changed, 1 insertion(+), 67 deletions(-) diff --git a/ggml-common.h b/ggml-common.h index d3fa51235c1f7..e8efceb760d40 100644 --- a/ggml-common.h +++ b/ggml-common.h @@ -1022,73 +1022,6 @@ GGML_TABLE_BEGIN(uint32_t, iq3s_grid, 512) 0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101, GGML_TABLE_END() -GGML_TABLE_BEGIN(uint32_t, q22_grid, 256) - 0x00000000, 0x01000000, 0x00000000, 0xff000000, - 0x00010000, 0x01010000, 0x00010000, 0xff010000, - 0x00000000, 0x01000000, 0x00000000, 0xff000000, - 0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000, - 0x00000100, 0x01000100, 0x00000100, 0xff000100, - 0x00010100, 0x01010100, 0x00010100, 0xff010100, - 0x00000100, 0x01000100, 0x00000100, 0xff000100, - 0x00ff0100, 0x01ff0100, 0x00ff0100, 0xffff0100, - 0x00000000, 0x01000000, 0x00000000, 0xff000000, - 0x00010000, 0x01010000, 0x00010000, 0xff010000, - 0x00000000, 0x01000000, 0x00000000, 0xff000000, - 0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000, - 0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00, - 0x0001ff00, 0x0101ff00, 0x0001ff00, 0xff01ff00, - 0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00, - 0x00ffff00, 0x01ffff00, 0x00ffff00, 0xffffff00, - 0x00000001, 0x01000001, 0x00000001, 0xff000001, - 0x00010001, 0x01010001, 0x00010001, 0xff010001, - 0x00000001, 0x01000001, 0x00000001, 0xff000001, - 0x00ff0001, 0x01ff0001, 0x00ff0001, 0xffff0001, - 0x00000101, 0x01000101, 0x00000101, 0xff000101, - 0x00010101, 0x01010101, 0x00010101, 0xff010101, - 0x00000101, 0x01000101, 0x00000101, 0xff000101, - 0x00ff0101, 0x01ff0101, 0x00ff0101, 0xffff0101, - 0x00000001, 0x01000001, 0x00000001, 0xff000001, - 0x00010001, 0x01010001, 0x00010001, 0xff010001, - 0x00000001, 0x01000001, 0x00000001, 0xff000001, - 0x00ff0001, 0x01ff0001, 0x00ff0001, 0xffff0001, - 0x0000ff01, 0x0100ff01, 0x0000ff01, 0xff00ff01, - 0x0001ff01, 0x0101ff01, 0x0001ff01, 0xff01ff01, - 0x0000ff01, 0x0100ff01, 0x0000ff01, 0xff00ff01, - 0x00ffff01, 0x01ffff01, 0x00ffff01, 0xffffff01, - 0x00000000, 0x01000000, 0x00000000, 0xff000000, - 0x00010000, 0x01010000, 0x00010000, 0xff010000, - 0x00000000, 0x01000000, 0x00000000, 0xff000000, - 0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000, - 0x00000100, 0x01000100, 0x00000100, 0xff000100, - 0x00010100, 0x01010100, 0x00010100, 0xff010100, - 0x00000100, 0x01000100, 0x00000100, 0xff000100, - 0x00ff0100, 0x01ff0100, 0x00ff0100, 0xffff0100, - 0x00000000, 0x01000000, 0x00000000, 0xff000000, - 0x00010000, 0x01010000, 0x00010000, 0xff010000, - 0x00000000, 0x01000000, 0x00000000, 0xff000000, - 0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000, - 0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00, - 0x0001ff00, 0x0101ff00, 0x0001ff00, 0xff01ff00, - 0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00, - 0x00ffff00, 0x01ffff00, 0x00ffff00, 0xffffff00, - 0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff, - 0x000100ff, 0x010100ff, 0x000100ff, 0xff0100ff, - 0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff, - 0x00ff00ff, 0x01ff00ff, 0x00ff00ff, 0xffff00ff, - 0x000001ff, 0x010001ff, 0x000001ff, 0xff0001ff, - 0x000101ff, 0x010101ff, 0x000101ff, 0xff0101ff, - 0x000001ff, 0x010001ff, 0x000001ff, 0xff0001ff, - 0x00ff01ff, 0x01ff01ff, 0x00ff01ff, 0xffff01ff, - 0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff, - 0x000100ff, 0x010100ff, 0x000100ff, 0xff0100ff, - 0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff, - 0x00ff00ff, 0x01ff00ff, 0x00ff00ff, 0xffff00ff, - 0x0000ffff, 0x0100ffff, 0x0000ffff, 0xff00ffff, - 0x0001ffff, 0x0101ffff, 0x0001ffff, 0xff01ffff, - 0x0000ffff, 0x0100ffff, 0x0000ffff, 0xff00ffff, - 0x00ffffff, 0x01ffffff, 0x00ffffff, 0xffffffff, -GGML_TABLE_END() - #define NGRID_IQ1S 2048 #define IQ1S_DELTA 0.125f #define IQ1M_DELTA 0.125f diff --git a/ggml-quants.c b/ggml-quants.c index ee86fd6b99db7..9f864e5c479ea 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -659,6 +659,7 @@ static inline __m128i packNibbles( __m256i bytes ) { } #endif //__loongarch_asx +// reference implementation for deterministic creation of model files void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) { static const int qk = QK4_0; From c6ddfa7e37a707ba0502f0456427c91a79cfddf8 Mon Sep 17 00:00:00 2001 From: Eddie-Wang Date: Thu, 20 Jun 2024 22:41:29 +0800 Subject: [PATCH 30/32] fix whitespace --- llama.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/llama.cpp b/llama.cpp index 2453a29434e4f..5d0bedf6cd894 100644 --- a/llama.cpp +++ b/llama.cpp @@ -11932,12 +11932,10 @@ struct llm_build_context { struct ggml_tensor *tmp = ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur); tmp = ggml_mul(ctx0, tmp, model.layers[il].ffn_up_scale); - cb(tmp, "ffn_up", il); cur = ggml_mul_mat(ctx0, model.layers[il].ffn_gate, cur); cur = ggml_mul(ctx0, cur, model.layers[il].ffn_gate_scale); - cb(cur, "ffn_gate", il); cur = ggml_silu(ctx0, cur); From 55a57a5063e16c67721263a12fdb7497a4521aba Mon Sep 17 00:00:00 2001 From: Eddie-Wang1120 Date: Fri, 21 Jun 2024 16:12:48 +0800 Subject: [PATCH 31/32] reuse llm_build_kv --- llama.cpp | 100 ++++++++++-------------------------------------------- 1 file changed, 18 insertions(+), 82 deletions(-) diff --git a/llama.cpp b/llama.cpp index 5d0bedf6cd894..de55be65e4c8f 100644 --- a/llama.cpp +++ b/llama.cpp @@ -6701,9 +6701,6 @@ static bool llm_load_tensors( model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); } - const uint32_t n_ff = hparams.n_ff; - model.layers.resize(n_layer); - for (int i = 0; i < n_layer; ++i) { ggml_context * ctx_layer = ctx_for_layer(i); ggml_context * ctx_split = ctx_for_layer_split(i); @@ -6714,23 +6711,23 @@ static bool llm_load_tensors( layer.attn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}); layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.wq_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}); + layer.wq_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}); layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.wk_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "scale", i), {1}); + layer.wk_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "scale", i), {1}); layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.wv_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "scale", i), {1}); + layer.wv_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "scale", i), {1}); layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.wo_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}); + layer.wo_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}); layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); layer.ffn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}); layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_gate_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}); + layer.ffn_gate_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}); layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); - layer.ffn_down_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}); + layer.ffn_down_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}); layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); - layer.ffn_up_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "scale", i), {1}); + layer.ffn_up_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "scale", i), {1}); } } break; default: @@ -7373,7 +7370,10 @@ static struct ggml_tensor * llm_build_kqv( ggml_build_forward_expand(graph, cur); - cur = ggml_mul_mat(ctx, wo, cur); + if (wo) { + cur = ggml_mul_mat(ctx, wo, cur); + } + if (wo_b) { cb(cur, "kqv_wo", il); } @@ -11835,82 +11835,18 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il); - - const int64_t n_ctx = cparams.n_ctx; - const int64_t n_head = hparams.n_head; - const int64_t n_head_kv = hparams.n_head_kv; - const int64_t n_embd_head_k = hparams.n_embd_head_k; - const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); - const int64_t n_embd_head_v = hparams.n_embd_head_v; - const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); - - float kq_scale = 1.0f/sqrtf(float(n_embd_head)); - struct ggml_tensor * cur_attn; - struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); - cb(q, "q", il); - - struct ggml_tensor * k = - ggml_view_3d(ctx0, kv_self.k_l[il], - n_embd_head_k, n_kv, n_head_kv, - ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), - ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k), - 0); - cb(k, "k", il); - - if (cparams.flash_attn) { - - // split cached v into n_head heads (not transposed) - struct ggml_tensor * v = - ggml_view_3d(ctx0, kv_self.v_l[il], - n_embd_head_v, n_kv, n_head_kv, - ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa), - ggml_row_size(kv_self.v_l[il]->type, n_embd_head_v), - 0); - cb(v, "v", il); - - cur_attn = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias); - - cur_attn = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens); - } else { - struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); - cb(kq, "kq", il); - - kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias); - cb(kq, "kq_soft_max_ext", il); - - GGML_ASSERT(kv_self.size == n_ctx); - - // split cached v into n_head heads - struct ggml_tensor * v = - ggml_view_3d(ctx0, kv_self.v_l[il], - n_kv, n_embd_head_v, n_head_kv, - ggml_element_size(kv_self.v_l[il])*n_ctx, - ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v, - 0); - cb(v, "v", il); - - struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq); - cb(kqv, "kqv", il); - - struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); - cb(kqv_merged, "kqv_merged", il); - - cur_attn = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens); - cb(cur_attn, "kqv_merged_cont", il); - } + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + nullptr, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); - cur_attn = llm_build_norm(ctx0, cur_attn, hparams, + cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_sub_norm, NULL, LLM_NORM_RMS, cb, il); - cb(cur_attn, "attn_sub_norm", il); - - ggml_build_forward_expand(gf, cur_attn); + cb(cur, "attn_sub_norm", il); - cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur_attn); + cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur); cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale); - - cb(cur, "kqv_out", il); + cb(cur, "attn_o_out", il); } if (il == n_layer - 1) { From 226c5eed4e0a19943c52aaef924f203d2f9a08cc Mon Sep 17 00:00:00 2001 From: Eddie-Wang Date: Sun, 23 Jun 2024 15:58:30 +0000 Subject: [PATCH 32/32] fix bo --- llama.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index 049bc81dd79c9..c710ef82b746e 100644 --- a/llama.cpp +++ b/llama.cpp @@ -11861,7 +11861,7 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, - nullptr, model.layers[il].bo, + nullptr, nullptr, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); cur = llm_build_norm(ctx0, cur, hparams, @@ -11871,6 +11871,9 @@ struct llm_build_context { cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur); cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale); + if (model.layers[il].bo) { + cur = ggml_add(ctx0, cur, model.layers[il].bo); + } cb(cur, "attn_o_out", il); }