
Quantize: specify each major tensor quant in CLI for common LLMs #8917

Draft
wants to merge 10 commits into base: master

Conversation

@Nexesenex (Contributor) commented Aug 7, 2024

This PR simply replicates the per-tensor custom quantization CLI feature introduced by Ikawrakow for the token embeddings and output tensors in #6239 to:

  • attn_q.weight
  • attn_k.weight
  • attn_v.weight
  • attn_qkv.weight
  • attn_output.weight
  • ffn_gate
  • ffn_down
  • ffn_up

The goal is to let llama.cpp users easily tailor their chosen quant strategy to their needs, but ALSO to let GPU users easily requantize a quant that is "a bit too big" for their VRAM.

For example, a nice Miqu 70B Q5_K_M (which has no FP16 weights available beyond dequants of the Q5_K_M) can run short of VRAM at high context on a pair of RTX 3090s. And if one is French, like me, then Miqu is one of one's main local models.

Requanting the Q5_K_M into... Q5_K_M, BUT with all the ffn_down and attn_v.weight tensors set to Q5_K, and attn_q.weight set to Q4_K, might save you approximately 1.5 GB without degrading quality too much. That means 1.3-1.4 GB of additional context (yummy with FA and a quantized KV cache) and, say, 100-200 MB of additional compute buffer with a reasonable BLAS batch size in MMQ.

Also: the unspecified tensors won't be requantized, because llama.cpp simply copies a tensor instead of requantizing it when the target type for that tensor is the same as the source. So one keeps the original Miqu quant of those tensors rather than a dequant/requant.
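A minimal sketch of what such a requant command could look like (the file names are placeholders and the type choices simply mirror the description above, so treat this as illustrative rather than the exact command I ran; note that after the CQS edit at the end of this description, these per-tensor flags only take effect with the CQS scheme): `llama-quantize --allow-requantize --attn-v-type q5_K --ffn-down-type q5_K --attn-q-type q4_K miqu-1-70b.Q5_K_M.gguf miqu-1-70b.Q5_K_M-trimmed.gguf Q5_K_M`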

And that's just an example.

I think that many LCPP users could enjoy this feature for their own needs.

That said, it remains quite basic:
This PR doesn't support hybrid quantization of a tensor, e.g. putting only a fraction of the layers (from layer 0 onwards) in the higher quant, or the "use_more_bits" calculation devised by Ikawrakow to create intervals of different quants (e.g. 1 layer out of every 3 quantized with the higher type).
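
For reference, the "use_more_bits" interval pattern mentioned above can be reproduced with the formula from llama_tensor_get_type (quoted later in this thread); a small self-contained illustration for a hypothetical 32-layer model:

```cpp
// Illustration only: which layers use_more_bits() would bump to the higher quant
// for a 32-layer model. The formula is copied from llama_tensor_get_type().
#include <cstdio>

static bool use_more_bits(int i_layer, int n_layers) {
    return i_layer < n_layers/8 || i_layer >= 7*n_layers/8 || (i_layer - n_layers/8)%3 == 2;
}

int main() {
    const int n_layers = 32;
    for (int i = 0; i < n_layers; ++i) {
        if (use_more_bits(i, n_layers)) {
            printf("%d ", i); // prints: 0 1 2 3 6 9 12 15 18 21 24 27 28 29 30 31
        }
    }
    printf("\n");
    return 0;
}
```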

CLI example: `llama-quantize --allow-requantize --imatrix Q:\iMatrix\Sheared\princeton-nlp_Sheared-LLaMA-2.7B-AR-b1924-Q8_0.iMatrix_Wiki_c32_ch500.dat --output-tensor-type q4_0 --token-embedding-type q4_0 --attn-q-type q4_0 --attn-k-type q4_0 --attn-v-type q4_0 --attn-output-type q4_0 --ffn-gate-type q4_0 --ffn-down-type q4_0 --ffn-up-type q4_0 D:\text-generation-webui\models\Q8_0\princeton-nlp_Sheared-LLaMA-2.7B-AR-b1924-Q8_0.gguf D:\text-generation-webui\models\princeton-nlp_Sheared-LLaMA-2.7B-AR-b228N.iMatrix_Wiki_c32_ch500-Q5_K_M.gguf Q5_K_M` for a full q4_0 quant equivalent to a pure quant, but specified tensor by tensor.
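
Under the hood, each new flag is meant to follow the same override pattern already used for --output-tensor-type and --token-embedding-type inside llama_tensor_get_type (quoted further down this thread): if the user passed a type, it wins; otherwise the existing ftype-based rules apply. A standalone toy sketch of that precedence logic (names and types here are illustrative, not the PR's verbatim code):

```cpp
// Toy illustration of the per-tensor override precedence, using a "COUNT" sentinel
// for "not set", as the existing output/token-embedding overrides do.
#include <cstdio>
#include <string>

enum my_type { TYPE_Q4_K, TYPE_Q5_K, TYPE_Q6_K, TYPE_COUNT }; // TYPE_COUNT = not set

struct tensor_overrides {
    my_type attn_v_type   = TYPE_COUNT; // e.g. set from --attn-v-type
    my_type ffn_down_type = TYPE_COUNT; // e.g. set from --ffn-down-type
};

static my_type pick_type(const std::string & name, my_type rule_based, const tensor_overrides & ov) {
    if (name.find("attn_v.weight") != std::string::npos && ov.attn_v_type < TYPE_COUNT) {
        return ov.attn_v_type;   // explicit CLI override wins
    }
    if (name.find("ffn_down") != std::string::npos && ov.ffn_down_type < TYPE_COUNT) {
        return ov.ffn_down_type;
    }
    return rule_based;           // otherwise fall through to the ftype-based rules
}

int main() {
    tensor_overrides ov;
    ov.attn_v_type = TYPE_Q5_K;
    printf("%d\n", pick_type("blk.0.attn_v.weight", TYPE_Q6_K, ov)); // 1 -> overridden to Q5_K
    printf("%d\n", pick_type("blk.0.ffn_up.weight",  TYPE_Q6_K, ov)); // 2 -> rule-based Q6_K kept
    return 0;
}
```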

Edit: in accordance with Slaren's remarks, I created a Custom Quantization Scheme (CQS) FTYPE, which is now required for the per-tensor quantization options to take effect, so that the regular quantization strategies are not altered and users who unknowingly download such edited quants are not confused.

  • I have read the contributing guidelines
  • Self-reported review complexity:
    • Low
    • Medium -> Test all the arguments on the command line on a small model by quantizing it with llama-quantize, and combine with other relevant commands to check that nothing gets broken.
    • High

@CISC (Contributor) commented Aug 7, 2024

Doesn't this sort of supersede #6844?

@Nexesenex (Contributor, Author) commented Aug 7, 2024

> Doesn't this sort of supersede #6844?

I wouldn't claim to supersede anything. #6844 is more complex and uses external config files, while my PR simply replicates the CLI mechanism already used for output.weight and token_embd.weight to the 7 other usual major tensors found in common LLMs, plus attn_qkv.weight for models with a monolithic QKV head.

And my PR might be ready to merge, if no one finds a glitch. It works well for me.

@CISC (Contributor) commented Aug 7, 2024

Sure, just fix the EditorConfig failures first. :)

But yeah, this one is also less likely to cause issues.

@Nexesenex (Contributor, Author)

> Sure, just fix the EditorConfig failures first. :)

I just saw that, and I fixed the bad indentation and trailing whitespace. :)

@mofosyne added the "Review Complexity : Medium" label Aug 8, 2024
@jukofyork (Contributor) commented Aug 8, 2024

The deepseek-v2 models have a couple of other tensor names for the low-rank tensors:

    "model.layers.0.self_attn.q_a_proj.weight": "model-00001-of-000055.safetensors",
    "model.layers.0.self_attn.q_b_proj.weight": "model-00001-of-000055.safetensors",
    "model.layers.0.self_attn.kv_a_proj_with_mqa.weight": "model-00001-of-000055.safetensors",
    "model.layers.0.self_attn.kv_b_proj.weight": "model-00001-of-000055.safetensors",

see: https://huggingface.co/deepseek-ai/DeepSeek-V2-Chat-0628/raw/main/model.safetensors.index.json

The "lite" models look to just use kv_a_proj_with_mqa and kv_b_proj:

    "model.layers.0.self_attn.q_proj.weight": "model-00001-of-000004.safetensors",
    "model.layers.0.self_attn.kv_a_proj_with_mqa.weight": "model-00001-of-000004.safetensors",
    "model.layers.0.self_attn.kv_b_proj.weight": "model-00001-of-000004.safetensors",

see: https://huggingface.co/deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct/raw/main/model.safetensors.index.json

@jukofyork (Contributor) commented Aug 8, 2024

I'm not sure if llama.cpp renames them, but the mixtral-moe architecture also uses w1, w2, and w3:

    "model.embed_tokens.weight": "model-00001-of-00019.safetensors",
    "model.layers.0.block_sparse_moe.experts.0.w1.weight": "model-00001-of-00019.safetensors",
    "model.layers.0.block_sparse_moe.experts.0.w2.weight": "model-00001-of-00019.safetensors",
    "model.layers.0.block_sparse_moe.experts.0.w3.weight": "model-00001-of-00019.safetensors",
    "model.layers.0.block_sparse_moe.experts.1.w1.weight": "model-00001-of-00019.safetensors",
    "model.layers.0.block_sparse_moe.experts.1.w2.weight": "model-00001-of-00019.safetensors",
    "model.layers.0.block_sparse_moe.experts.1.w3.weight": "model-00001-of-00019.safetensors",
    "model.layers.0.block_sparse_moe.experts.2.w1.weight": "model-00001-of-00019.safetensors",
    "model.layers.0.block_sparse_moe.experts.2.w2.weight": "model-00001-of-00019.safetensors",
    "model.layers.0.block_sparse_moe.experts.2.w3.weight": "model-00001-of-00019.safetensors",
    "model.layers.0.block_sparse_moe.experts.3.w1.weight": "model-00001-of-00019.safetensors",
    "model.layers.0.block_sparse_moe.experts.3.w2.weight": "model-00001-of-00019.safetensors",
    "model.layers.0.block_sparse_moe.experts.3.w3.weight": "model-00001-of-00019.safetensors",
    "model.layers.0.block_sparse_moe.experts.4.w1.weight": "model-00001-of-00019.safetensors",
    "model.layers.0.block_sparse_moe.experts.4.w2.weight": "model-00001-of-00019.safetensors",
    "model.layers.0.block_sparse_moe.experts.4.w3.weight": "model-00001-of-00019.safetensors",
    "model.layers.0.block_sparse_moe.experts.5.w1.weight": "model-00001-of-00019.safetensors",
    "model.layers.0.block_sparse_moe.experts.5.w2.weight": "model-00001-of-00019.safetensors",
    "model.layers.0.block_sparse_moe.experts.5.w3.weight": "model-00001-of-00019.safetensors",
    "model.layers.0.block_sparse_moe.experts.6.w1.weight": "model-00001-of-00019.safetensors",
    "model.layers.0.block_sparse_moe.experts.6.w2.weight": "model-00001-of-00019.safetensors",
    "model.layers.0.block_sparse_moe.experts.6.w3.weight": "model-00001-of-00019.safetensors",
    "model.layers.0.block_sparse_moe.experts.7.w1.weight": "model-00001-of-00019.safetensors",
    "model.layers.0.block_sparse_moe.experts.7.w2.weight": "model-00001-of-00019.safetensors",
    "model.layers.0.block_sparse_moe.experts.7.w3.weight": "model-00001-of-00019.safetensors",

see: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/model.safetensors.index.json

IIRC, the block_sparse_moe.gate.weight is left unquantised in llama.cpp though.


It might also be worth checking the qwen-moe and dbrx models, as they too use different names, and the qwen-moe models (and deepseek-v2 models) also use an mlp.shared_expert that fires on every token and would therefore make sense to quantize with a higher type.

@Nexesenex (Contributor, Author)

@jukofyork: this is quite an undertaking for me, because if the tensors of the models you mention are not affected by my PR, it means they are also outside the regular "tensor by tensor" quantization rules set by Ikawrakow to deviate from the base GGML_TYPE (e.g. Q5_K) of the FTYPE/quant strategy (e.g. Q5_K_M).

So they are also currently out of the scope of my PR, which only adds an override condition for the tensors that already have specific quantization rules, as was done for token_embd and output.weight, without introducing a new set of tensor quantization rules to which such a condition could also be applied: the tensors you mention most probably fall back to the base GGML_TYPE used for a given quantization strategy, Ikawrakow not having set specific rules for them beyond that.

So I think the tensors you mention should indeed be added both to the regular "tensor by tensor" quant strategy rules, if they are not already, and to the per-tensor quantization via CLI argument (as in this PR), because in my opinion all LLMs should benefit from both "features".

But that belongs in a second PR, after the llama.cpp maintainers give their input and a potential green light to such an evolution of Ikawrakow's work, and of course to this present PR itself.

@jukofyork (Contributor) commented Aug 8, 2024

Yeah, the problem is related to the llama_tensor_get_type function not being updated for a long time:

static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {
    const std::string name = ggml_get_name(tensor);

    // TODO: avoid hardcoded tensor names - use the TN_* constants
    const llm_arch arch = qs.model.arch;
    const auto       tn = LLM_TN(arch);

    auto use_more_bits = [](int i_layer, int n_layers) -> bool {
        return i_layer < n_layers/8 || i_layer >= 7*n_layers/8 || (i_layer - n_layers/8)%3 == 2;
    };
    const int n_expert = std::max(1, (int)qs.model.hparams.n_expert);
    auto layer_info = [n_expert] (int i_layer, int n_layer, const char * name) {
        if (n_expert > 1) {
            // Believe it or not, "experts" in the FFN of Mixtral-8x7B are not consecutive, but iccasionally randomly
            // sprinkled in the model. Hence, simply dividing i_ffn_down by n_expert does not work
            // for getting the current layer as I initially thought, and we need to resort to parsing the
            // tensor name.
            if (sscanf(name, "blk.%d.", &i_layer) != 1) {
                throw std::runtime_error(format("Failed to determine layer for tensor %s", name));
            }
            if (i_layer < 0 || i_layer >= n_layer) {
                throw std::runtime_error(format("Bad layer %d for tensor %s. Must be in [0, %d)", i_layer, name, n_layer));
            }
        }
        return std::make_pair(i_layer, n_layer);
    };

    // for arches that share the same tensor between the token embeddings and the output, we quantize the token embeddings
    // with the quantization of the output tensor
    if (name == tn(LLM_TENSOR_OUTPUT, "weight") || (!qs.has_output && name == tn(LLM_TENSOR_TOKEN_EMBD, "weight"))) {
        if (qs.params->output_tensor_type < GGML_TYPE_COUNT) {
            new_type = qs.params->output_tensor_type;
        } else {
            int nx = tensor->ne[0];
            if (arch == LLM_ARCH_FALCON || nx % QK_K != 0) {
                new_type = GGML_TYPE_Q8_0;
            }
            else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS ||
                     ftype == LLAMA_FTYPE_MOSTLY_IQ1_S   || ftype == LLAMA_FTYPE_MOSTLY_IQ2_S  || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M   ||
                     ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
                new_type = GGML_TYPE_Q5_K;
            }
            else if (new_type != GGML_TYPE_Q8_0) {
                new_type = GGML_TYPE_Q6_K;
            }
        }
    } else if (name == "token_embd.weight") {
        if (qs.params->token_embedding_type < GGML_TYPE_COUNT) {
            new_type = qs.params->token_embedding_type;
        } else {
            if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS ||
                ftype == LLAMA_FTYPE_MOSTLY_IQ1_S   || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
                new_type = GGML_TYPE_Q2_K;
            }
            else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) {
                new_type = GGML_TYPE_IQ3_S;
            }
            else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
                new_type = GGML_TYPE_IQ3_S;
            }
            else if (new_type == GGML_TYPE_Q4_0_4_4 || new_type == GGML_TYPE_Q4_0_4_8 ||
                     new_type == GGML_TYPE_Q4_0_8_8) {
                new_type = GGML_TYPE_Q4_0;
            }
        }
    } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S ||
               ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M    || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
        if (name.find("attn_v.weight") != std::string::npos) {
            if (qs.model.hparams.n_gqa() >= 4 || qs.model.hparams.n_expert >= 4) new_type = GGML_TYPE_Q4_K;
            else new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
            ++qs.i_attention_wv;
        }
        else if (qs.model.hparams.n_expert == 8 && name.find("attn_k.weight") != std::string::npos) {
            new_type = GGML_TYPE_Q4_K;
        }
        else if (name.find("ffn_down") != std::string::npos) {
            if (qs.i_ffn_down < qs.n_ffn_down/8) {
                new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
            }
            ++qs.i_ffn_down;
        }
        else if (name.find("attn_output.weight") != std::string::npos) {
            if (qs.model.hparams.n_expert == 8) {
                new_type = GGML_TYPE_Q5_K;
            } else {
                if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) new_type = GGML_TYPE_IQ2_XXS;
                else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) new_type = GGML_TYPE_IQ3_S;
            }
        }
    } else if (name.find("attn_v.weight") != std::string::npos) {
        if      (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) {
            new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && qs.model.hparams.n_gqa() >= 4) {
            new_type = GGML_TYPE_Q4_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
            new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : !qs.has_imatrix ? GGML_TYPE_IQ3_S : GGML_TYPE_IQ3_XXS;
        }
        else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S) && qs.model.hparams.n_gqa() >= 4) {
            new_type = GGML_TYPE_Q4_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
            new_type = GGML_TYPE_Q4_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
            new_type = qs.i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
        else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) && qs.model.hparams.n_gqa() >= 4) {
            new_type = GGML_TYPE_Q5_K;
        }
        else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
                use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) new_type = GGML_TYPE_Q6_K;
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && qs.i_attention_wv < 4) new_type = GGML_TYPE_Q5_K;
        if (qs.model.type == MODEL_70B) {
            // In the 70B model we have 8 heads sharing the same attn_v weights. As a result, the attn_v.weight tensor is
            // 8x smaller compared to attn_q.weight. Hence, we can get a nice boost in quantization accuracy with
            // nearly negligible increase in model size by quantizing this tensor with more bits:
            if (new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K) new_type = GGML_TYPE_Q5_K;
        }
        if (qs.model.hparams.n_expert == 8) {
            // for the 8-expert model, bumping this to Q8_0 trades just ~128MB
            // TODO: explore better strategies
            new_type = GGML_TYPE_Q8_0;
        }
        ++qs.i_attention_wv;
    } else if (name.find("attn_k.weight") != std::string::npos) {
        if (qs.model.hparams.n_expert == 8) {
            // for the 8-expert model, bumping this to Q8_0 trades just ~128MB
            // TODO: explore better strategies
            new_type = GGML_TYPE_Q8_0;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) {
            new_type = GGML_TYPE_IQ3_XXS;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
            new_type = GGML_TYPE_IQ2_S;
        }
    } else if (name.find("attn_q.weight") != std::string::npos) {
        if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) {
            new_type = GGML_TYPE_IQ3_XXS;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
            new_type = GGML_TYPE_IQ2_S;
        }
    } else if (name.find("ffn_down") != std::string::npos) {
        auto info = layer_info(qs.i_ffn_down, qs.n_ffn_down, name.c_str());
        int i_layer = info.first, n_layer = info.second;
        if      (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S) {
            if (i_layer < n_layer/8) new_type = GGML_TYPE_Q4_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS && !qs.has_imatrix) {
            new_type = i_layer < n_layer/8 ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
            new_type = i_layer < n_layer/16 ? GGML_TYPE_Q5_K
                     : arch != LLM_ARCH_FALCON || use_more_bits(i_layer, n_layer) ? GGML_TYPE_Q4_K
                     : GGML_TYPE_Q3_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M && (i_layer < n_layer/8 ||
                    (qs.model.hparams.n_expert == 8 && use_more_bits(i_layer, n_layer)))) {
            new_type = GGML_TYPE_Q4_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) {
            new_type = arch == LLM_ARCH_FALCON ? GGML_TYPE_Q4_K : GGML_TYPE_Q5_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) {
            if (arch == LLM_ARCH_FALCON) {
                new_type = i_layer < n_layer/16 ? GGML_TYPE_Q6_K :
                           use_more_bits(i_layer, n_layer) ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
            } else {
                if (use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K;
            }
        }
        else if (i_layer < n_layer/8 && (ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) && !qs.has_imatrix) {
            new_type = GGML_TYPE_Q5_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M && use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K;
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && arch != LLM_ARCH_FALCON && i_layer < n_layer/8) {
            new_type = GGML_TYPE_Q5_K;
        }
        else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_0 || ftype == LLAMA_FTYPE_MOSTLY_Q5_0)
                && qs.has_imatrix && i_layer < n_layer/8) {
            // Guard against craziness in the first few ffn_down layers that can happen even with imatrix for Q4_0/Q5_0.
            // We only do it when an imatrix is provided because a) we want to make sure that one can always get the
            // same quantization as before imatrix stuff, and b) Q4_1/Q5_1 do go crazy on ffn_down without an imatrix.
            new_type = ftype == LLAMA_FTYPE_MOSTLY_Q4_0 ? GGML_TYPE_Q4_1 : GGML_TYPE_Q5_1;
        }
        ++qs.i_ffn_down;
    } else if (name.find("attn_output.weight") != std::string::npos) {
        if (arch != LLM_ARCH_FALCON) {
            if (qs.model.hparams.n_expert == 8) {
                if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K   || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS ||
                    ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M  || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL  ||
                    ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M  || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S  ||
                    ftype == LLAMA_FTYPE_MOSTLY_IQ3_M  || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) {
                    new_type = GGML_TYPE_Q5_K;
                }
            } else {
                if      (ftype == LLAMA_FTYPE_MOSTLY_Q2_K   ) new_type = GGML_TYPE_Q3_K;
                else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) new_type = GGML_TYPE_IQ3_S;
                else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M ) new_type = GGML_TYPE_Q4_K;
                else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L ) new_type = GGML_TYPE_Q5_K;
                else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M  ) new_type = GGML_TYPE_Q4_K;
            }
        } else {
            if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q4_K;
        }
    }
    else if (name.find("attn_qkv.weight") != std::string::npos) {
        if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
            new_type = GGML_TYPE_Q4_K;
        }
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) new_type = GGML_TYPE_Q5_K;
        else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K;
    }
    else if (name.find("ffn_gate") != std::string::npos) {
        auto info = layer_info(qs.i_ffn_gate, qs.n_ffn_gate, name.c_str());
        int i_layer = info.first, n_layer = info.second;
        if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) {
            new_type = GGML_TYPE_IQ3_XXS;
        }
        ++qs.i_ffn_gate;
    }
    else if (name.find("ffn_up") != std::string::npos) {
        auto info = layer_info(qs.i_ffn_up, qs.n_ffn_up, name.c_str());
        int i_layer = info.first, n_layer = info.second;
        if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) {
            new_type = GGML_TYPE_IQ3_XXS;
        }
        ++qs.i_ffn_up;
    }

    //    if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
    //}
    // IK: let's remove this, else Q2_K is almost the same as Q3_K_S
    //else if (name.find("ffn_gate") != std::string::npos || name.find("ffn_up") != std::string::npos) {
    //    if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
    //}
    // This can be used to reduce the size of the Q5_K_S model.
    // The associated PPL increase is fully in line with the size reduction
    //else {
    //    if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K;
    //}
    bool convert_incompatible_tensor = false;
    if (new_type == GGML_TYPE_Q2_K    || new_type == GGML_TYPE_Q3_K    || new_type == GGML_TYPE_Q4_K   ||
        new_type == GGML_TYPE_Q5_K    || new_type == GGML_TYPE_Q6_K    || new_type == GGML_TYPE_IQ4_XS ||
        new_type == GGML_TYPE_IQ2_XS  || new_type == GGML_TYPE_IQ2_XXS || new_type == GGML_TYPE_IQ2_S  ||
        new_type == GGML_TYPE_IQ3_XXS || new_type == GGML_TYPE_IQ1_S   || new_type == GGML_TYPE_IQ3_S  ||
        new_type == GGML_TYPE_IQ1_M) {
        int nx = tensor->ne[0];
        int ny = tensor->ne[1];
        if (nx % QK_K != 0) {
            LLAMA_LOG_WARN("\n\n%s : tensor cols %d x %d are not divisible by %d, required for %s", __func__, nx, ny, QK_K, ggml_type_name(new_type));
            convert_incompatible_tensor = true;
        } else {
            ++qs.n_k_quantized;
        }
    }
    if (convert_incompatible_tensor) {
        switch (new_type) {
            case GGML_TYPE_IQ2_XXS:
            case GGML_TYPE_IQ2_XS:
            case GGML_TYPE_IQ2_S:
            case GGML_TYPE_IQ3_XXS:
            case GGML_TYPE_IQ3_S:
            case GGML_TYPE_IQ1_S:
            case GGML_TYPE_IQ1_M:
            case GGML_TYPE_Q2_K:
            case GGML_TYPE_Q3_K:
            case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break;
            case GGML_TYPE_Q4_K:   new_type = GGML_TYPE_Q5_0;   break;
            case GGML_TYPE_Q5_K:   new_type = GGML_TYPE_Q5_1;   break;
            case GGML_TYPE_Q6_K:   new_type = GGML_TYPE_Q8_0;   break;
            default: throw std::runtime_error("\nUnsupported tensor size encountered\n");
        }
        LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type));
        ++qs.n_fallback;
    }

    return new_type;
}

and a lot of new tensor names and MoE architectures are now falling through the original logic of this function as a result.

There is another ongoing PR (or issue) that proposes this function be refactored.

EDIT: Found it #8736.

@compilade (Collaborator) commented Aug 8, 2024

Note that MoE tensors are handled by this PR in the same way FFN tensors are handled:

`name.find("ffn_up") != std::string::npos`

also matches the stacked expert tensors

`{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },`

And yes, the tensor names are renamed by the convert script to fit what src/llama.cpp expects.
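
In other words (a tiny self-contained check, with tensor names following the mapping quoted above):

```cpp
// The substring match used in llama_tensor_get_type() catches both the dense FFN
// tensor and the stacked MoE expert tensor, since "ffn_up" is a substring of "ffn_up_exps".
#include <cassert>
#include <string>

int main() {
    const std::string dense   = "blk.0.ffn_up.weight";
    const std::string stacked = "blk.0.ffn_up_exps.weight";
    assert(dense.find("ffn_up")   != std::string::npos);
    assert(stacked.find("ffn_up") != std::string::npos); // expert tensors are matched too
    return 0;
}
```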

@Nexesenex (Contributor, Author) commented Aug 8, 2024

Considering the information given by Compilade, I don't think there's anything more to do for the present PR, Jukofyork.
The discussions in #8736 will likely end up in a PR refactoring the overall quantization system, so I see little point in digging further until such a refactor is decided on and done.

Thus, I consider my PR ready for review.

@slaren (Collaborator) commented Aug 8, 2024

I am not sure about these changes, because at best this is a stopgap measure until we have a way to fully define custom quantization schemes. Maybe it could remain as a PR or a fork for people who really want to do this right now, rather than merging code that will become obsolete (hopefully soon), but will still add to the maintenance overhead. I am also somewhat wary about distributing quants with wildly different quantization schemes under the same ftype, because it may create confusion. It might be better to add a new ftype to use with custom schemes.

@Nexesenex marked this pull request as draft August 9, 2024 12:05
@Nexesenex (Contributor, Author)

@slaren: I agree with your reasoning.

This is indeed a stopgap measure, and I neglected to consider the mess it could create with people sharing heavily modified quants without proper labels, as well as the maintenance overhead.

I converted the PR to a draft for those who want to use it, and will try to create the relevant FTYPE according to your input.

Otherwise, I'll wait for a potential revamp of the quantization schemes and customization options.

@ggerganov added the "demo" label Aug 9, 2024
@Nexesenex (Contributor, Author)

@slaren: I created the required FTYPE. The custom tensor quants work only with the "CQS" FTYPE, and are simply bypassed if the user quantizing a model selects a regular quantization scheme.

Also, reverting the extra level of indentation of the tensor-per-tensor quantization rules drastically reduces the number of replacements made in the code.
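
Presumably, usage would then look something like the following, with CQS selected as the target scheme so that the per-tensor flags take effect (the exact spelling of the scheme name and the file paths are assumptions of this sketch, not taken from the PR): `llama-quantize --attn-v-type q5_K --ffn-down-type q5_K --attn-q-type q4_K model-f16.gguf model-CQS.gguf CQS`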

Labels: demo, examples, Review Complexity : Medium