Introduce Q8_0 and Q4_0 with Bf16 delta values #7497

Open
Srihari-mcw wants to merge 12 commits into master from quants_q8_0_q4_0_with_b16_delta
Conversation

@Srihari-mcw Srihari-mcw commented May 23, 2024

  • This PR introduces Q4_0 and Q8_0 quantizations whose delta (scale) values are stored in BF16 instead of FP16.
  • The PR adds the corresponding quantization, dequantization and dot product functions.
  • The PR adds gemm4xN and gemmMx4 templated functions covering the GEMM shapes relevant to the new quantizations.
  • The new quantizations use the _mm_dpbf16_ps instruction to multiply the delta values and feed the products into the mat mul result computation; the loops are unrolled so the resulting delta products can be extracted and reused. The existing Q4_0 and Q8_0 quantizations with FP16 deltas instead use a combination of a lookup table and set instructions for delta multiplication. A scalar sketch of the BF16-delta layout is included at the end of this comment.
  • These changes give the new BF16-delta Q4_0 and Q8_0 formats better prompt processing performance than their FP16-delta counterparts.
  • Performance results on the default prompt-speedup path are attached below:

GCC Linux :

Q8_0 Model :

| model | size | params | backend | threads | test | t/s | speedup | Commit id |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CPU | 6 | pp 512 | 55.61 ± 0.07 | | d0af2a19 |
| llama 7B Q8_0_B16 | 6.67 GiB | 6.74 B | CPU | 6 | pp 512 | 65.36 ± 0.08 | 17.53% | d0af2a19 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CPU | 6 | tg 128 | 8.08 ± 0.07 | | d0af2a19 |
| llama 7B Q8_0_B16 | 6.67 GiB | 6.74 B | CPU | 6 | tg 128 | 8.1 ± 0.01 | 0.25% | d0af2a19 |

Q4_0 Model :

| model | size | params | backend | threads | test | t/s | speedup | Commit id |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | CPU | 6 | pp 512 | 43.17 ± 0.10 | | d0af2a19 |
| llama 7B Q4_0_B16 | 3.56 GiB | 6.74 B | CPU | 6 | pp 512 | 57.46 ± 0.09 | 33.10% | d0af2a19 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | CPU | 6 | tg 128 | 14.57 ± 0.02 | | d0af2a19 |
| llama 7B Q4_0_B16 | 3.56 GiB | 6.74 B | CPU | 6 | tg 128 | 14.59 ± 0.01 | 0.14% | d0af2a19 |

MSVC Windows :

Q8_0 Model :

| model | size | params | backend | threads | test | t/s | speedup | Commit id |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CPU | 6 | pp 512 | 35.02 ± 0.41 | | d0af2a19 |
| llama 7B Q8_0_B16 | 6.67 GiB | 6.74 B | CPU | 6 | pp 512 | 45.82 ± 0.59 | 30.83% | d0af2a19 |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CPU | 6 | tg 128 | 8.01 ± 0.02 | | d0af2a19 |
| llama 7B Q8_0_B16 | 6.67 GiB | 6.74 B | CPU | 6 | tg 128 | 8.02 ± 0.02 | 0.12% | d0af2a19 |

Q4_0 Model :

| model | size | params | backend | threads | test | t/s | speedup | Commit id |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | CPU | 6 | pp 512 | 26.4 ± 0.44 | | d0af2a19 |
| llama 7B Q4_0_B16 | 3.56 GiB | 6.74 B | CPU | 6 | pp 512 | 42.57 ± 0.48 | 61.25% | d0af2a19 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | CPU | 6 | tg 128 | 14.12 ± 0.17 | | d0af2a19 |
| llama 7B Q4_0_B16 | 3.56 GiB | 6.74 B | CPU | 6 | tg 128 | 14.15 ± 0.09 | 0.21% | d0af2a19 |

The PR was tested on an AMD Raphael 7600X, which supports AVX512_BF16. On Windows, AVX512_BF16 was enabled with `cmake .. -DLLAMA_AVX512_BF16=ON`.

The models were quantized from the Meta Llama 2 7B model - https://huggingface.co/meta-llama/Llama-2-7b
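
For reference, here is a minimal scalar sketch of the BF16-delta block idea. The struct and function names are illustrative only (not the PR's actual code), and the real kernels use AVX512_BF16 intrinsics such as _mm_dpbf16_ps rather than this scalar reference:

```c
#include <stdint.h>

#define QK8_0 32

/* Hypothetical Q8_0-style block whose per-block scale is stored as raw BF16 bits. */
typedef struct {
    uint16_t d;          /* scale as a BF16 bit pattern (instead of FP16) */
    int8_t   qs[QK8_0];  /* quantized values */
} block_q8_0_b16;

/* BF16 -> FP32 is just placing the 16 bits into the high half of the float. */
static inline float bf16_to_f32(uint16_t bits) {
    union { uint32_t u; float f; } v;
    v.u = (uint32_t)bits << 16;
    return v.f;
}

/* Scalar reference for one block pair: integer dot product scaled by d_a * d_b. */
static float vec_dot_one_block(const block_q8_0_b16 *a, const block_q8_0_b16 *b) {
    int32_t isum = 0;
    for (int i = 0; i < QK8_0; ++i) {
        isum += (int32_t)a->qs[i] * (int32_t)b->qs[i];
    }
    return bf16_to_f32(a->d) * bf16_to_f32(b->d) * (float)isum;
}
```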

@github-actions github-actions bot added examples python python script changes ggml changes relating to the ggml tensor library for machine learning labels May 23, 2024
@mofosyne mofosyne added the Review Complexity : High Generally require indepth knowledge of LLMs or GPUs label May 23, 2024
github-actions bot commented May 23, 2024

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 550 iterations 🚀

Details (posted for performance-related PRs only):
  • Concurrent users: 8, duration: 10m
  • HTTP request : avg=8513.51ms p(95)=20846.42ms fails=, finish reason: stop=498 truncated=52
  • Prompt processing (pp): avg=94.11tk/s p(95)=408.86tk/s
  • Token generation (tg): avg=47.75tk/s p(95)=48.07tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=quants_q8_0_q4_0_with_b16_delta commit=46c0cd78ef103d83cbec3d8d0dbd36574f0dd889

[Benchmark charts omitted: llama.cpp bench-server-baseline on Standard_NC4as_T4_v3, duration=10m, 550 iterations - prompt_tokens_seconds, predicted_tokens_seconds, kv_cache_usage_ratio, requests_processing]

@Srihari-mcw Srihari-mcw force-pushed the quants_q8_0_q4_0_with_b16_delta branch from b9a5d91 to a9eaa9e on May 24, 2024 09:57
@Srihari-mcw (Contributor Author) commented:

Additional note: ggml_vec_dot_q4_0_q8_0 does not contain any changes. An additional function, ggml_vec_dot_q4_0_b16_q8_0_b16, was added for the new Q4_0_B16 type just after ggml_vec_dot_q4_0_q8_0. GitHub nevertheless shows a difference for the ggml_vec_dot_q4_0_q8_0 function in the "Files modified" section for ggml-quants.c.

@@ -17,6 +17,7 @@ struct quant_option {

static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "Q4_0", LLAMA_FTYPE_MOSTLY_Q4_0, " 3.56G, +0.2166 ppl @ LLaMA-v1-7B", },
{ "Q4_0_B16", LLAMA_FTYPE_MOSTLY_Q4_0_B16, " 3.56G, 5.9624 +/- 0.03348 ppl @ LLaMA-v2-7B", },
@Srihari-mcw (Contributor Author) commented:

The perplexity score mentioned here was obtained by running perplexity.exe with the Q4_0_B16-quantized Meta Llama 2 7B model. I was unsure of the methodology behind the other ppl scores listed here; kindly share feedback if this score needs to be modified. Thanks.

@sorasoras commented:

That's great. Could this be applied to GPUs as well?

@mofosyne mofosyne added the Tensor Encoding Scheme https://github.com/ggerganov/llama.cpp/wiki/Tensor-Encoding-Schemes label May 25, 2024
@mofosyne (Collaborator) commented:

You may want to rebase this on top of master so the CI can pass completely (a CI fault has been bypassed for now).

@Srihari-mcw Srihari-mcw force-pushed the quants_q8_0_q4_0_with_b16_delta branch from 135cec2 to 46c0cd7 on May 27, 2024 12:09
@Srihari-mcw (Contributor Author) commented:

@sorasoras, the optimization changes primarily use CPU SIMD instructions and were tested on the CPU backend. Thanks.

@Srihari-mcw (Contributor Author) commented:

@mofosyne, the branch has been rebased on top of the current master branch. Thanks.

@Srihari-mcw Srihari-mcw force-pushed the quants_q8_0_q4_0_with_b16_delta branch from 46c0cd7 to 138cd22 on July 7, 2024 02:15
@Srihari-mcw Srihari-mcw force-pushed the quants_q8_0_q4_0_with_b16_delta branch from 2fd0a10 to eb1116a on July 15, 2024 09:17
@mofosyne mofosyne requested a review from ngxson July 30, 2024 09:22
@ngxson (Collaborator) commented Jul 30, 2024:

Nice idea. I don't have much time to look into the details right now, but overall the implementation looks good. I'll need more time to test it.

Also, I'm quite interested in the performance test on AVX256 because AFAIK _mm_dpbf16_ps is only available with AVX512.

CC @jart for the tinyblas part

@ngxson ngxson requested review from ggerganov and slaren July 30, 2024 09:34
Comment on lines 1097 to 1098
Q4_0_B16 = 31
Q8_0_B16 = 32
Collaborator review comment:

Suggested change:
-    Q4_0_B16 = 31
-    Q8_0_B16 = 32
+    Q4_0_B16 = 34
+    Q8_0_B16 = 35

These should be the same as in ggml.h.

jart added a commit to Mozilla-Ocho/llamafile that referenced this pull request Jul 30, 2024
This change was written by Srihari-mcw. These new quants are the same as
Q8_0 and Q4_0 except BF16 is used (instead of F16) as the scaling scalar

See ggerganov/llama.cpp#7497
@Srihari-mcw Srihari-mcw force-pushed the quants_q8_0_q4_0_with_b16_delta branch from eb1116a to e9305da on July 30, 2024 17:57
@slaren (Collaborator) commented Jul 30, 2024:

It may be worth checking if there is some optimization issue with the implementation of these quants, because it is hard to imagine that a single fp16 to fp32 conversion per block could be so expensive. Currently, F16C is not used to convert between fp16 and fp32 because tests showed it to be slower than a lookup table. I suspect this is because the instruction has a significant latency, but it should be possible to hide most of this latency by reordering the instructions and unrolling the loops.
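
(As an illustration of the latency-hiding idea described above - this is not code from the PR or from ggml, and the block layout and helper names are assumptions - an unrolled FP16-delta dot product built on the F16C `_cvtsh_ss` conversion, compiled with `-mf16c`, could look roughly like this:)

```c
#include <immintrin.h>
#include <stdint.h>

#define QK8_0 32

/* Illustrative block: FP16 scale (raw bits) plus 32 int8 quants. */
typedef struct {
    uint16_t d;
    int8_t   qs[QK8_0];
} blk_fp16;

/* Plain int8 dot product of one block pair. */
static int32_t blk_isum(const blk_fp16 *a, const blk_fp16 *b) {
    int32_t s = 0;
    for (int i = 0; i < QK8_0; ++i) {
        s += (int32_t)a->qs[i] * (int32_t)b->qs[i];
    }
    return s;
}

/* Unrolled by two blocks: both F16C conversions are issued before either
 * result is consumed, so their latency can overlap the integer sums. */
float vec_dot_unrolled(const blk_fp16 *x, const blk_fp16 *y, int nb) {
    float acc = 0.0f;
    int i = 0;
    for (; i + 1 < nb; i += 2) {
        float d0 = _cvtsh_ss(x[i].d)     * _cvtsh_ss(y[i].d);
        float d1 = _cvtsh_ss(x[i + 1].d) * _cvtsh_ss(y[i + 1].d);
        acc += d0 * (float)blk_isum(&x[i],     &y[i]);
        acc += d1 * (float)blk_isum(&x[i + 1], &y[i + 1]);
    }
    for (; i < nb; ++i) {
        acc += _cvtsh_ss(x[i].d) * _cvtsh_ss(y[i].d) * (float)blk_isum(&x[i], &y[i]);
    }
    return acc;
}
```

The point is only that the two scale conversions are independent of the integer sums, so an out-of-order core can overlap their latency; real kernels would of course also vectorize the inner loop.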

My general view is that we already have more quant types than we should. Each quant type is a maintenance burden for us and for the backend developers, it makes the choice harder for users, and it adds to the work of the people quantizing models. We should avoid adding new types unless strictly necessary, and we should look into removing some of the outdated formats that have effectively been replaced by more efficient alternatives (which would likely include Q4_0).

@jart (Contributor) commented Jul 30, 2024:

I've created a branch on the llamafile repository where I've imported this pull request.

I renamed your quantization formats to Q8_B and Q4_B, which are much more manageable names.

Here are your benchmarks on a CPU that supports AVX512F BF16.

| cpu_info | model_filename | size | test | t/s |
| --- | --- | --- | --- | --- |
| AMD Ryzen Threadripper PRO 7995WX (znver4) | Mistral-7B-Instruct-v0.3.BF16 | 13.50 GiB | pp512 | 322.26 |
| AMD Ryzen Threadripper PRO 7995WX (znver4) | Mistral-7B-Instruct-v0.3.BF16 | 13.50 GiB | tg16 | 15.22 |
| AMD Ryzen Threadripper PRO 7995WX (znver4) | Mistral-7B-Instruct-v0.3.F16 | 13.50 GiB | pp512 | 230.85 |
| AMD Ryzen Threadripper PRO 7995WX (znver4) | Mistral-7B-Instruct-v0.3.F16 | 13.50 GiB | tg16 | 15.10 |
| AMD Ryzen Threadripper PRO 7995WX (znver4) | Mistral-7B-Instruct-v0.3.Q8_0 | 7.17 GiB | pp512 | 216.80 |
| AMD Ryzen Threadripper PRO 7995WX (znver4) | Mistral-7B-Instruct-v0.3.Q8_0 | 7.17 GiB | tg16 | 22.07 |
| AMD Ryzen Threadripper PRO 7995WX (znver4) | Mistral-7B-Instruct-v0.3.Q8_B | 7.17 GiB | pp512 | 266.40 |
| AMD Ryzen Threadripper PRO 7995WX (znver4) | Mistral-7B-Instruct-v0.3.Q8_B | 7.17 GiB | tg16 | 25.58 |
| AMD Ryzen Threadripper PRO 7995WX (znver4) | Mistral-7B-Instruct-v0.3.Q6_K | 5.54 GiB | pp512 | 347.76 |
| AMD Ryzen Threadripper PRO 7995WX (znver4) | Mistral-7B-Instruct-v0.3.Q6_K | 5.54 GiB | tg16 | 30.05 |
| AMD Ryzen Threadripper PRO 7995WX (znver4) | Mistral-7B-Instruct-v0.3.Q4_K_M | 4.07 GiB | pp512 | 362.83 |
| AMD Ryzen Threadripper PRO 7995WX (znver4) | Mistral-7B-Instruct-v0.3.Q4_K_M | 4.07 GiB | tg16 | 37.51 |
| AMD Ryzen Threadripper PRO 7995WX (znver4) | Mistral-7B-Instruct-v0.3.Q4_0 | 3.83 GiB | pp512 | 253.11 |
| AMD Ryzen Threadripper PRO 7995WX (znver4) | Mistral-7B-Instruct-v0.3.Q4_0 | 3.83 GiB | tg16 | 39.01 |
| AMD Ryzen Threadripper PRO 7995WX (znver4) | Mistral-7B-Instruct-v0.3.Q4_B | 3.83 GiB | pp512 | 243.45 |
| AMD Ryzen Threadripper PRO 7995WX (znver4) | Mistral-7B-Instruct-v0.3.Q4_B | 3.83 GiB | tg16 | 33.82 |

While Q8_B seems to go a little faster than Q8_0, there was a decline in performance for Q4_B. Prompt processing speed for Q4/Q8 is much slower than BF16 on Zen4. Q4 is a legacy quant; the K quants do more math and manage to go much faster.

I don't think we really have much to gain in terms of performance here. Each q4/q8 block has a single f16 scalar. SIMD doesn't help when you're dealing with scalars. This change goes too far out of its way to call _mm_dpbf16_ps(). Even with optimal use the instruction isn't that much faster than the alternative (simply left shifting by 16).
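
(For context on the "left shifting by 16" alternative - a sketch, not code from either project - widening eight BF16 scales to FP32 needs only AVX2:)

```c
#include <immintrin.h>
#include <stdint.h>

/* Widen 8 packed BF16 values to FP32: zero-extend each to 32 bits, then shift
 * the BF16 bit pattern into the high half of the float. Requires only AVX2. */
static inline __m256 bf16x8_to_f32(const uint16_t *p) {
    __m128i h = _mm_loadu_si128((const __m128i *)p);        /* 8 x bf16 bits       */
    __m256i w = _mm256_cvtepu16_epi32(h);                    /* zero-extend to u32  */
    return _mm256_castsi256_ps(_mm256_slli_epi32(w, 16));    /* bits -> fp32 lanes  */
}
```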

@Srihari-mcw (Contributor Author) commented Aug 7, 2024:

Hi @slaren, @jart - recently, similar changes (parallel delta multiplication combined with loop unrolling for the 4xN and Mx4 dimensions) were tried with the existing quantization types that use FP16 delta values, and we observed performance gains on our platforms. The corresponding PR #8908 and its performance details are attached here for your reference. Please have a look. Thanks.

GCC Linux :

Meta Llama2 7B model:

Q4_0 Model :

| model | size | params | backend | threads | test | t/s | speedup | Commit id | Notes |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | CPU | 6 | pp 512 | 43.79 ± 0.08 | | 7e72aa74 | Base commit |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | CPU | 6 | pp 512 | 59.37 ± 0.08 | 35.58% | cdf3a251 | Commit with PR changes |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | CPU | 6 | tg 128 | 14.65 ± 0.01 | | 7e72aa74 | Base commit |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | CPU | 6 | tg 128 | 14.51 ± 0.00 | -0.96% | cdf3a251 | Commit with PR changes |

Q8_0 Model :

| model | size | params | backend | threads | test | t/s | speedup | Commit id | Notes |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CPU | 6 | pp 512 | 56.87 ± 0.06 | | 7e72aa74 | Base commit |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CPU | 6 | pp 512 | 68.03 ± 0.13 | 19.69% | cdf3a251 | Commit with PR changes |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CPU | 6 | tg 128 | 8.12 ± 0.00 | | 7e72aa74 | Base commit |
| llama 7B Q8_0 | 6.67 GiB | 6.74 B | CPU | 6 | tg 128 | 8.12 ± 0.00 | 0.00% | cdf3a251 | Commit with PR changes |

Mistral-7B-Instruct-v0.3 model:

Q4_0 Model :

| model | size | params | backend | threads | test | t/s | speedup | Commit id | Notes |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| llama 7B Q4_0 | 3.83 GiB | 7.25 B | CPU | 6 | pp 512 | 40.96 ± 0.05 | | 7e72aa74 | Base commit |
| llama 7B Q4_0 | 3.83 GiB | 7.25 B | CPU | 6 | pp 512 | 55.71 ± 0.11 | 36.01% | cdf3a251 | Commit with PR changes |
| llama 7B Q4_0 | 3.83 GiB | 7.25 B | CPU | 6 | tg 128 | 13.81 ± 0.01 | | 7e72aa74 | Base commit |
| llama 7B Q4_0 | 3.83 GiB | 7.25 B | CPU | 6 | tg 128 | 13.66 ± 0.00 | -1.09% | cdf3a251 | Commit with PR changes |

Q8_0 Model :

| model | size | params | backend | threads | test | t/s | speedup | Commit id | Notes |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| llama 7B Q8_0 | 7.17 GiB | 7.25 B | CPU | 6 | pp 512 | 53.34 ± 0.04 | | 7e72aa74 | Base commit |
| llama 7B Q8_0 | 7.17 GiB | 7.25 B | CPU | 6 | pp 512 | 63.64 ± 0.07 | 19.31% | cdf3a251 | Commit with PR changes |
| llama 7B Q8_0 | 7.17 GiB | 7.25 B | CPU | 6 | tg 128 | 7.59 ± 0.00 | | 7e72aa74 | Base commit |
| llama 7B Q8_0 | 7.17 GiB | 7.25 B | CPU | 6 | tg 128 | 7.60 ± 0.00 | 0.13% | cdf3a251 | Commit with PR changes |

GCC Version = 12.3

The PR was tested on an AMD Raphael 7600X, which supports the following flags by default:

| AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 1 | AVX512_VBMI = 1 | AVX512_VNNI = 1 | AVX512_BF16 = 1 | FMA = 1 | NEON = 0 | SVE = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 |

Original Unquantized Models :

Llama2 7B : https://huggingface.co/meta-llama/Llama-2-7b
Mistral 7B Instruct v0.3 : https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3

@Srihari-mcw (Contributor Author) commented Aug 8, 2024:

PR #8908 was also tested on an AMD Ryzen Threadripper PRO 5995WX machine. Test results are attached below, along with the supported flags and other details.

Performance Results in AMD Ryzen Threadripper PRO 5995WX

GCC Linux :

Mistral-7B-Instruct-v0.3 model:

Q4_0 Model :

| model | size | params | backend | threads | test | t/s | speedup | Commit id | Notes |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| llama 7B Q4_0 | 3.83 GiB | 7.25 B | CPU | 64 | pp 512 | 189.30 ± 0.31 | | 7e72aa74 | Base commit |
| llama 7B Q4_0 | 3.83 GiB | 7.25 B | CPU | 64 | pp 512 | 210.26 ± 0.32 | 11.07% | cdf3a251 | Commit with PR changes |
| llama 7B Q4_0 | 3.83 GiB | 7.25 B | CPU | 64 | tg 128 | 33.74 ± 0.04 | | 7e72aa74 | Base commit |
| llama 7B Q4_0 | 3.83 GiB | 7.25 B | CPU | 64 | tg 128 | 33.77 ± 0.05 | 0.09% | cdf3a251 | Commit with PR changes |

Q8_0 Model :

| model | size | params | backend | threads | test | t/s | speedup | Commit id | Notes |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| llama 7B Q8_0 | 7.17 GiB | 7.25 B | CPU | 64 | pp 512 | 214.93 ± 0.25 | | 7e72aa74 | Base commit |
| llama 7B Q8_0 | 7.17 GiB | 7.25 B | CPU | 64 | pp 512 | 241.85 ± 0.47 | 12.53% | cdf3a251 | Commit with PR changes |
| llama 7B Q8_0 | 7.17 GiB | 7.25 B | CPU | 64 | tg 128 | 19.83 ± 0.01 | | 7e72aa74 | Base commit |
| llama 7B Q8_0 | 7.17 GiB | 7.25 B | CPU | 64 | tg 128 | 19.74 ± 0.00 | 0.13% | cdf3a251 | Commit with PR changes |

GCC Version = 12.3

The machine supports the following flags by default:

| AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | AVX512_BF16 = 0 | FMA = 1 | NEON = 0 | SVE = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 |

Original Unquantized Models :

Llama2 7B : https://huggingface.co/meta-llama/Llama-2-7b
Mistral 7B Instruct v0.3 : https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3

@Srihari-mcw Srihari-mcw force-pushed the quants_q8_0_q4_0_with_b16_delta branch from e9305da to 43c7be5 on August 20, 2024 05:20
@jart (Contributor) commented Aug 20, 2024:

That's an impressive achievement. Congratulations, and thanks for showing us the numbers. Now we just have to decide whether adding and maintaining a new quantization format forever is worth it to make znver4 go 13% faster :-)

@Srihari-mcw (Contributor Author) commented Aug 20, 2024:

> That's an impressive achievement. Congratulations, and thanks for showing us the numbers. Now we just have to decide whether adding and maintaining a new quantization format forever is worth it to make znver4 go 13% faster :-)

We want to reiterate that PR #8908 retains the original quantization types with FP16 delta (it does not introduce a new quantization format). On Zen 4 (Raphael 7600X), the prompt processing gains are approximately 35% for Q4_0 and 20% for Q8_0. On the Threadripper PRO 5995WX, the gains are approximately 11% for Q4_0 and 12.5% for Q8_0. For more info, refer to PR #8908. Thanks.

Prompt processing test results with PR #8908 for the Mistral-7B-Instruct-v0.3 model, Q4_0 and Q8_0 quantizations, GCC 12.3 on Linux:

AMD Raphael 7600X (Zen 4)

| model | size | params | backend | threads | test | t/s | speedup | Commit id | Notes |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| llama 7B Q4_0 | 3.83 GiB | 7.25 B | CPU | 6 | pp 512 | 40.96 ± 0.05 | | 7e72aa74 | Base commit |
| llama 7B Q4_0 | 3.83 GiB | 7.25 B | CPU | 6 | pp 512 | 55.71 ± 0.11 | 36.01% | cdf3a251 | Commit with PR changes |
| llama 7B Q8_0 | 7.17 GiB | 7.25 B | CPU | 6 | pp 512 | 53.34 ± 0.04 | | 7e72aa74 | Base commit |
| llama 7B Q8_0 | 7.17 GiB | 7.25 B | CPU | 6 | pp 512 | 63.64 ± 0.07 | 19.31% | cdf3a251 | Commit with PR changes |

AMD Ryzen Threadripper PRO 5995WX

| model | size | params | backend | threads | test | t/s | speedup | Commit id | Notes |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| llama 7B Q4_0 | 3.83 GiB | 7.25 B | CPU | 64 | pp 512 | 189.30 ± 0.31 | | 7e72aa74 | Base commit |
| llama 7B Q4_0 | 3.83 GiB | 7.25 B | CPU | 64 | pp 512 | 210.26 ± 0.32 | 11.07% | cdf3a251 | Commit with PR changes |
| llama 7B Q8_0 | 7.17 GiB | 7.25 B | CPU | 64 | pp 512 | 214.93 ± 0.25 | | 7e72aa74 | Base commit |
| llama 7B Q8_0 | 7.17 GiB | 7.25 B | CPU | 64 | pp 512 | 241.85 ± 0.47 | 12.53% | cdf3a251 | Commit with PR changes |

Notable difference in flags between the Threadripper PRO 5995WX and the Raphael 7600X: the Raphael 7600X supports AVX512, AVX512_VNNI, AVX512_VBMI and AVX512_BF16, whereas the Threadripper 5995WX does not.
