Commit 4edc958

fix code

Eddie-Wang1120 committed Jun 18, 2024
1 parent a03eff3 commit 4edc958

Showing 5 changed files with 22 additions and 62 deletions.
14 changes: 7 additions & 7 deletions convert-hf-to-gguf.py

@@ -1430,25 +1430,25 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter

         if name.endswith("q_proj.weight"):
             tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), weight_torch))
-            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q_SCALE, bid), scale_torch))
+            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid, suffix=".scale"), scale_torch))
         elif name.endswith("k_proj.weight"):
             tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), weight_torch))
-            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K_SCALE, bid), scale_torch))
+            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid, suffix=".scale"), scale_torch))
         elif name.endswith("v_proj.weight"):
             tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), weight_torch))
-            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V_SCALE, bid), scale_torch))
+            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid, suffix=".scale"), scale_torch))
         elif name.endswith("o_proj.weight"):
             tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_OUT, bid), weight_torch))
-            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_OUT_SCALE, bid), scale_torch))
+            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_OUT, bid, suffix=".scale"), scale_torch))
         elif name.endswith("up_proj.weight"):
             tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid), weight_torch))
-            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SCALE, bid), scale_torch))
+            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid, suffix=".scale"), scale_torch))
         elif name.endswith("down_proj.weight"):
             tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN, bid), weight_torch))
-            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN_SCALE, bid), scale_torch))
+            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN, bid, suffix=".scale"), scale_torch))
         elif name.endswith("gate_proj.weight"):
             tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid), weight_torch))
-            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SCALE, bid), scale_torch))
+            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid, suffix=".scale"), scale_torch))

         if len(tensors) == 0:
             tensors.append((self.map_tensor_name(name), data_torch))
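The change above drops the dedicated *_SCALE tensor enums and instead emits each scale tensor under its base tensor's name plus a ".scale" suffix, so a quantized weight and its per-tensor scale share one base name. A minimal Python sketch of the resulting naming, assuming TENSOR_NAMES maps MODEL_TENSOR.ATTN_Q to "blk.{bid}.attn_q" (by analogy with the entries visible in the gguf-py/gguf/constants.py hunk below) and that format_tensor_name() defaults to a ".weight" suffix:

    # Sketch only, not the actual gguf-py implementation.
    TENSOR_NAMES = {
        "ATTN_Q": "blk.{bid}.attn_q",  # assumed pattern, see constants.py below
    }

    def format_tensor_name(key: str, bid: int, suffix: str = ".weight") -> str:
        # Fill in the block index, then append the requested suffix.
        return TENSOR_NAMES[key].format(bid=bid) + suffix

    print(format_tensor_name("ATTN_Q", 0))                   # blk.0.attn_q.weight
    print(format_tensor_name("ATTN_Q", 0, suffix=".scale"))  # blk.0.attn_q.scale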
2 changes: 1 addition & 1 deletion examples/quantize/quantize.cpp

@@ -26,7 +26,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
     { "IQ2_M",   LLAMA_FTYPE_MOSTLY_IQ2_M,   " 2.7 bpw quantization",             },
     { "IQ1_S",   LLAMA_FTYPE_MOSTLY_IQ1_S,   " 1.56 bpw quantization",            },
     { "IQ1_M",   LLAMA_FTYPE_MOSTLY_IQ1_M,   " 1.75 bpw quantization",            },
-    { "Q2_2",    LLAMA_FTYPE_MOSTLY_Q2_2,    " 2 bpw quantization",               },
+    { "Q2_2",    LLAMA_FTYPE_MOSTLY_Q2_2,    " 2.5 bpw quantization",             },
     { "Q2_K",    LLAMA_FTYPE_MOSTLY_Q2_K,    " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", },
     { "Q2_K_S",  LLAMA_FTYPE_MOSTLY_Q2_K_S,  " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", },
     { "IQ3_XXS", LLAMA_FTYPE_MOSTLY_IQ3_XXS, " 3.06 bpw quantization",            },
2 changes: 1 addition & 1 deletion ggml-common.h

@@ -142,7 +142,7 @@ typedef struct {
     ggml_half d;           // delta
     uint8_t qs[QK2_2 / 4]; // nibbles / quants
 } block_q2_2;
-static_assert(sizeof(block_q2_2) == sizeof(ggml_half) + QK2_2 / 4, "wrong q4_0 block size/padding");
+static_assert(sizeof(block_q2_2) == sizeof(ggml_half) + QK2_2 / 4, "wrong q2_2 block size/padding");

 #define QK4_0 32
 typedef struct {
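The two Q2_2 fixes are consistent with each other: each block_q2_2 stores one fp16 delta plus two-bit quants packed four per byte, which works out to 2.5 bits per weight rather than 2. A quick arithmetic check, assuming QK2_2 == 32 (the define is not visible in this diff; 32 matches the neighbouring QK4_0 and reproduces the corrected figure):

    # Bits-per-weight check for block_q2_2; QK2_2 = 32 is an assumption here.
    QK2_2 = 32
    block_bytes = 2 + QK2_2 // 4    # sizeof(ggml_half) + qs[QK2_2 / 4]
    bpw = block_bytes * 8 / QK2_2   # 10 bytes * 8 bits / 32 weights
    print(bpw)                      # 2.5 -> " 2.5 bpw quantization"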
21 changes: 0 additions & 21 deletions gguf-py/gguf/constants.py

@@ -202,13 +202,6 @@ class MODEL_TENSOR(IntEnum):
     ATTN_KV_A_NORM = auto()
     FFN_SUB_NORM = auto()
     ATTN_SUB_NORM = auto()
-    ATTN_Q_SCALE = auto()
-    ATTN_K_SCALE = auto()
-    ATTN_V_SCALE = auto()
-    ATTN_OUT_SCALE = auto()
-    FFN_UP_SCALE = auto()
-    FFN_DOWN_SCALE = auto()
-    FFN_GATE_SCALE = auto()


 MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {

@@ -300,13 +293,6 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
     MODEL_TENSOR.ATTN_SUB_NORM:  "blk.{bid}.attn_sub_norm",
     MODEL_TENSOR.FFN_SUB_NORM:   "blk.{bid}.ffn_sub_norm",
-    MODEL_TENSOR.ATTN_Q_SCALE:   "blk.{bid}.attn_q_scale",
-    MODEL_TENSOR.ATTN_K_SCALE:   "blk.{bid}.attn_k_scale",
-    MODEL_TENSOR.ATTN_V_SCALE:   "blk.{bid}.attn_v_scale",
-    MODEL_TENSOR.ATTN_OUT_SCALE: "blk.{bid}.attn_output_scale",
-    MODEL_TENSOR.FFN_UP_SCALE:   "blk.{bid}.ffn_up_scale",
-    MODEL_TENSOR.FFN_DOWN_SCALE: "blk.{bid}.ffn_down_scale",
-    MODEL_TENSOR.FFN_GATE_SCALE: "blk.{bid}.ffn_gate_scale",
 }

 MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {

@@ -841,13 +827,6 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_UP,
         MODEL_TENSOR.ATTN_SUB_NORM,
         MODEL_TENSOR.FFN_SUB_NORM,
-        MODEL_TENSOR.ATTN_Q_SCALE,
-        MODEL_TENSOR.ATTN_K_SCALE,
-        MODEL_TENSOR.ATTN_V_SCALE,
-        MODEL_TENSOR.ATTN_OUT_SCALE,
-        MODEL_TENSOR.FFN_UP_SCALE,
-        MODEL_TENSOR.FFN_DOWN_SCALE,
-        MODEL_TENSOR.FFN_GATE_SCALE,
     ],
     # TODO
 }
45 changes: 13 additions & 32 deletions llama.cpp

@@ -498,13 +498,6 @@ enum llm_tensor {
     LLM_TENSOR_ATTN_KV_A_NORM,
     LLM_TENSOR_ATTN_SUB_NORM,
     LLM_TENSOR_FFN_SUB_NORM,
-    LLM_TENSOR_ATTN_Q_SCALE,
-    LLM_TENSOR_ATTN_K_SCALE,
-    LLM_TENSOR_ATTN_V_SCALE,
-    LLM_TENSOR_ATTN_OUTPUT_SCALE,
-    LLM_TENSOR_FFN_UP_SCALE,
-    LLM_TENSOR_FFN_DOWN_SCALE,
-    LLM_TENSOR_FFN_GATE_SCALE,
 };

 static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {

@@ -1134,13 +1127,6 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES
             { LLM_TENSOR_FFN_UP,            "blk.%d.ffn_up" },
             { LLM_TENSOR_FFN_NORM,          "blk.%d.ffn_norm" },
             { LLM_TENSOR_FFN_SUB_NORM,      "blk.%d.ffn_sub_norm" },
-            { LLM_TENSOR_ATTN_Q_SCALE,      "blk.%d.attn_q_scale" },
-            { LLM_TENSOR_ATTN_K_SCALE,      "blk.%d.attn_q_scale" },
-            { LLM_TENSOR_ATTN_V_SCALE,      "blk.%d.attn_q_scale" },
-            { LLM_TENSOR_ATTN_OUTPUT_SCALE, "blk.%d.attn_output_scale" },
-            { LLM_TENSOR_FFN_UP_SCALE,      "blk.%d.ffn_up_scale" },
-            { LLM_TENSOR_FFN_DOWN_SCALE,    "blk.%d.ffn_down_scale" },
-            { LLM_TENSOR_FFN_GATE_SCALE,    "blk.%d.ffn_gate_scale" },
         },
     },
     {

@@ -6483,23 +6469,23 @@ static bool llm_load_tensors(
             layer.attn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd});

             layer.wq       = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
-            layer.wq_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_SCALE, "weight", i), {1});
+            layer.wq_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "scale", i), {1});
             layer.wk       = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
-            layer.wk_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K_SCALE, "weight", i), {1});
+            layer.wk_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "scale", i), {1});
             layer.wv       = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
-            layer.wv_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V_SCALE, "weight", i), {1});
+            layer.wv_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "scale", i), {1});
             layer.wo       = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-            layer.wo_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUTPUT_SCALE, "weight", i), {1});
+            layer.wo_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1});

             layer.ffn_norm     = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
             layer.ffn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff});

             layer.ffn_gate       = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
-            layer.ffn_gate_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SCALE, "weight", i), {1});
+            layer.ffn_gate_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "scale", i), {1});
             layer.ffn_down       = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
-            layer.ffn_down_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SCALE, "weight", i), {1});
+            layer.ffn_down_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1});
             layer.ffn_up         = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
-            layer.ffn_up_scale   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SCALE, "weight", i), {1});
+            layer.ffn_up_scale   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "scale", i), {1});
         }
     } break;
 default:
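With the dedicated *_SCALE entries gone, tn(LLM_TENSOR_ATTN_Q, "scale", i) has to produce exactly the name the converter writes. Note that the removed name table mapped LLM_TENSOR_ATTN_K_SCALE and LLM_TENSOR_ATTN_V_SCALE to "blk.%d.attn_q_scale", a copy-paste bug the suffix scheme eliminates. A sketch of the round trip, assuming tn() expands the base pattern ("blk.%d.attn_q", by analogy with the LLM_TENSOR_NAMES entries above) and appends the suffix:

    # Sketch of the name round trip; the "blk.%d.attn_q" pattern is assumed
    # by analogy with the LLM_TENSOR_NAMES entries visible in this diff.
    def tn(pattern: str, suffix: str, i: int) -> str:
        return (pattern % i) + "." + suffix

    loader_name    = tn("blk.%d.attn_q", "scale", 0)              # llama.cpp side
    converter_name = "blk.{bid}.attn_q".format(bid=0) + ".scale"  # gguf-py side
    assert loader_name == converter_name                          # "blk.0.attn_q.scale"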
@@ -11624,14 +11610,9 @@ struct llm_build_context {
     const int64_t n_embd_head_v = hparams.n_embd_head_v;
     const int64_t n_embd_v_gqa  = hparams.n_embd_v_gqa();

-    struct ggml_tensor * q_cur = Qcur;
-    struct ggml_tensor * kq_mask = KQ_mask;
     float kq_scale = 1.0f/sqrtf(float(n_embd_head));
-    struct ggml_tensor * attn_sub_norm = model.layers[il].attn_sub_norm;
-    struct ggml_cgraph * graph = gf;
-    struct ggml_tensor * wo = model.layers[il].wo;
     struct ggml_tensor * cur_attn;
-    struct ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
+    struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
     cb(q, "q", il);

     struct ggml_tensor * k =

@@ -11653,14 +11634,14 @@ struct llm_build_context {
             0);
         cb(v, "v", il);

-        cur_attn = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias);
+        cur_attn = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias);

         cur_attn = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
     } else {
         struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
         cb(kq, "kq", il);

-        kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
+        kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
         cb(kq, "kq_soft_max_ext", il);

         GGML_ASSERT(kv_self.size == n_ctx);

@@ -11685,13 +11666,13 @@ struct llm_build_context {
     }

     cur_attn = llm_build_norm(ctx0, cur_attn, hparams,
-            attn_sub_norm, NULL,
+            model.layers[il].attn_sub_norm, NULL,
             LLM_NORM_RMS, cb, il);
     cb(cur_attn, "attn_sub_norm", il);

-    ggml_build_forward_expand(graph, cur_attn);
+    ggml_build_forward_expand(gf, cur_attn);

-    cur = ggml_mul_mat(ctx0, wo, cur_attn);
+    cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur_attn);
     cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale);

     cb(cur, "kqv_out", il);
