Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

grammars: cache decoded token codepoints for faster sampling #6811

Closed
wants to merge 19 commits into from

Conversation

ochafik
Copy link
Collaborator

@ochafik ochafik commented Apr 21, 2024

Edit: superseded by #7587 & #7424, can't measure significant improvements from caching the decoded codepoints: closing this PR.

Cache token pieces & their decoded codepoints (for common case where there's no partial utf8 sequence prefix) in llama_context to speed up grammar sampling (#4218)

cc/ @ejones, @HanClinto

Update: while this PR still seems useful, #7424 is bringing the most performance in most cases. I've run some benchmarks that seems to show this PR will add some marginal gains upon it.

This simple example runs 1.9x faster (total time; sampling itself goes from ~16 ms per token to ~1.9 ms per token, a 9x sampling speedup):

# git remote add ochafik https://github.com/ochafik/llama.cpp.git
# git fetch ochafik
hyperfine --warmup 1 --runs 5 \
    -L branch ochafik/grammar-fast,master \
    --setup 'git checkout {branch} && make clean && make -j LLAMA_CURL=1 main' \
    'BRANCH={branch} \
        ./main --seed 12345 \
            -mu https://huggingface.co/TheBloke/phi-2-GGUF/resolve/main/phi-2.Q4_K_M.gguf \
            -j '"'"'{"items": {"type": "number"}, "minItems": 10, "maxItems": 100}'"'"' \
            -p "JSON list of first 50 integers"'
Show output
Benchmark 1: BRANCH=grammar-fast \
        ./main --seed 12345 \
            -m models/phi-2.Q4_K_M.gguf \
            -j '{"items": {"type": "number"}, "minItems": 10, "maxItems": 100}' \
            -p "JSON list of first 50 integers"
  Time (mean ± σ):      3.229 s ±  0.039 s    [User: 0.621 s, System: 0.191 s]
  Range (min … max):    3.193 s …  3.290 s    5 runs
 
Benchmark 2: BRANCH=master \
        ./main --seed 12345 \
            -m models/phi-2.Q4_K_M.gguf \
            -j '{"items": {"type": "number"}, "minItems": 10, "maxItems": 100}' \
            -p "JSON list of first 50 integers"
  Time (mean ± σ):      6.398 s ±  0.073 s    [User: 3.629 s, System: 0.258 s]
  Range (min … max):    6.345 s …  6.523 s    5 runs
 
Summary
  'BRANCH=grammar-fast \
        ./main --seed 12345 \
            -m models/phi-2.Q4_K_M.gguf \
            -j '{"items": {"type": "number"}, "minItems": 10, "maxItems": 100}' \
            -p "JSON list of first 50 integers"' ran
    1.98 ± 0.03 times faster than 'BRANCH=master \
        ./main --seed 12345 \
            -m models/phi-2.Q4_K_M.gguf \
            -j '{"items": {"type": "number"}, "minItems": 10, "maxItems": 100}' \
            -p "JSON list of first 50 integers"'

master:

llama_print_timings:        load time =    2416.78 ms
llama_print_timings:      sample time =    3274.38 ms /   203 runs   (   16.13 ms per token,    62.00 tokens per second)
llama_print_timings: prompt eval time =      49.09 ms /     7 tokens (    7.01 ms per token,   142.59 tokens per second)
llama_print_timings:        eval time =    2450.08 ms /   202 runs   (   12.13 ms per token,    82.45 tokens per second)
llama_print_timings:       total time =    6037.77 ms /   209 tokens

this PR:

llama_print_timings:        load time =     183.22 ms
llama_print_timings:      sample time =     377.09 ms /   203 runs   (    1.86 ms per token,   538.33 tokens per second)
llama_print_timings: prompt eval time =      49.03 ms /     7 tokens (    7.00 ms per token,   142.77 tokens per second)
llama_print_timings:        eval time =    2411.72 ms /   202 runs   (   11.94 ms per token,    83.76 tokens per second)
llama_print_timings:       total time =    2887.39 ms /   209 tokens

Note: removed the following changes from earlier version of this PR

llama.cpp Outdated Show resolved Hide resolved
@HanClinto
Copy link
Collaborator

I've been out of it for a few days, but finally digging back in. GREAT work on all of this, @ochafik !!

Cache token pieces & their decoded codepoints in llama_sample_grammar

I'm still trying to read through and understand this -- I can't claim to have a good handle on this yet. Are you able to summarize why this optimization works?

Non-quadratic llama_grammar_copy for speculative decoding

Because llama_grammar_copy is only used by speculative.cpp, I wonder if this change should be left out for now...? Unless the changes to retaining cached codepoints makes it easier to keep it in here, then let's keep it in here for now. Instead of profiling on speculative.cpp, might be easiest to write unit tests around this?

@ochafik
Copy link
Collaborator Author

ochafik commented Apr 24, 2024

I've been out of it for a few days, but finally digging back in

@HanClinto welcome back, hope the robotics competition went well!

Non-quadratic llama_grammar_copy for speculative decoding

Because llama_grammar_copy is only used by speculative.cpp, I wonder if this change should be left out for now...?

Good point, reverted / will send separately 👌 (one less thing to check for this PR :-D)

@ochafik
Copy link
Collaborator Author

ochafik commented Apr 24, 2024

Cache token pieces & their decoded codepoints in llama_sample_grammar

I'm still trying to read through and understand this -- I can't claim to have a good handle on this yet. Are you able to summarize why this optimization works?

@HanClinto The code used to decode each token in the context of a potentially incomplete UTF-8 sequence (say, if the first token was a byte token with the 110xxxxx bit mask). This is very rare in practice, so tokens end up being decoded to unicode codepoints w/o any partial utf bytes as prefix (i.e. always same codepoints for a given token). I'm caching said non-partial decoding, which saves a lot of time (profiling the code showed lots of time was spent here). Also caching the token piece although that bit is arguably less useful (only in case of partial utf decoding, maybe more frequent for non-latin languages?).

I'll add comments about this :-)

@HanClinto
Copy link
Collaborator

I've been out of it for a few days, but finally digging back in

@HanClinto welcome back, hope the robotics competition went well!

Thanks -- it went really well! The kids all did great! 10 kids with 9 robots in 3 different categories. We took home two special nominations, two gold trophies, a silver, and a bronze. Tried a new thing this year where our team jerseys were custom-printed lab coats for all of the kids, and that was a ton of fun. :D It's a pretty big competition (here's a highlight reel from last year), and the kids are excited for next year.

Anyways, enough fluff. That's what happens when a project doesn't have a Discord server -- we clutter up the PRs with such conversations. :D

Non-quadratic llama_grammar_copy for speculative decoding

Because llama_grammar_copy is only used by speculative.cpp, I wonder if this change should be left out for now...?

Good point, reverted / will send separately 👌 (one less thing to check for this PR :-D)

Sounds great! 👍

I know I've mentioned it a few times, but I have a rather complex SQL grammar that I need to dig out and dust off that would help stretch this stuff a bit more.

@ochafik ochafik changed the title grammars: cache decoded token codepoints & early exit in candidates rejection (up to 10x faster sampling) grammars: cache decoded token codepoints & early exit in candidates rejection (faster sampling) Apr 25, 2024
@ejones
Copy link
Collaborator

ejones commented Apr 26, 2024

Hey, wanted to drop in quick - love this work @ochafik! Have been following along the discussions from the notifications. Hope to have some time in the next few days to take a deeper look.

I took a super quick look and one surface level thought I had was: should the token decoding cache be on the llama_context (although I know it's already overloaded)? Decoding tokens is driven by the model vocabulary rather than anything about the grammar. And technically I think as written the same grammar struct could be applied to different models.

Other than that, I need to have a think about UTF-8 decoding and that multibyte sequences with this.

@ochafik
Copy link
Collaborator Author

ochafik commented Apr 27, 2024

Thanks @ejones !

should the token decoding cache be on the llama_context (although I know it's already overloaded)? Decoding tokens is driven by the model vocabulary rather than anything about the grammar.

Ahh, I wondered about this, guess I took the lazy route. If we moved this to the context we'd probably want to build it upfront to avoid concurrency issues in server arising from building the cache lazily. It would have to be done regardless of whether a grammar will later be used, which may be slightly wasteful, but then it would make using a new grammar marginally faster every time, so probably worth it. Will look into it!

And technically I think as written the same grammar struct could be applied to different models.

Oh indeed, for speculative decoding I suppose the two models could have (ever so slightly) different tokenizers 😓

llama_context, hold on tight, more things coming your way :-D

Copy link
Contributor

github-actions bot commented Apr 28, 2024

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 562 iterations 🚀

Expand details for performance related PR only
  • Concurrent users: 8, duration: 10m
  • HTTP request : avg=8282.7ms p(95)=20313.53ms fails=, finish reason: stop=513 truncated=49
  • Prompt processing (pp): avg=98.38tk/s p(95)=411.48tk/s
  • Token generation (tg): avg=32.36tk/s p(95)=47.34tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=grammar-fast commit=6a9b626ba5ef3edb60edffd68ff0970a5d236798

prompt_tokens_seconds

More
---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3
 duration=10m 562 iterations"
    y-axis "llamacpp:prompt_tokens_seconds"
    x-axis "llamacpp:prompt_tokens_seconds" 1717984815 --> 1717985435
    line [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 493.68, 493.68, 493.68, 493.68, 493.68, 729.33, 729.33, 729.33, 729.33, 729.33, 657.05, 657.05, 657.05, 657.05, 657.05, 679.91, 679.91, 679.91, 679.91, 679.91, 717.57, 717.57, 717.57, 717.57, 717.57, 719.19, 719.19, 719.19, 719.19, 719.19, 744.2, 744.2, 744.2, 744.2, 744.2, 747.92, 747.92, 747.92, 747.92, 747.92, 768.46, 768.46, 768.46, 768.46, 768.46, 771.99, 771.99, 771.99, 771.99, 771.99, 795.27, 795.27, 795.27, 795.27, 795.27, 807.0, 807.0, 807.0, 807.0, 807.0, 834.92, 834.92, 834.92, 834.92, 834.92, 838.82, 838.82, 838.82, 838.82, 838.82, 843.26, 843.26, 843.26, 843.26, 843.26, 848.02, 848.02, 848.02, 848.02, 848.02, 845.85, 845.85, 845.85, 845.85, 845.85, 863.65, 863.65, 863.65, 863.65, 863.65, 857.81, 857.81, 857.81, 857.81, 857.81, 864.41, 864.41, 864.41, 864.41, 864.41, 868.21, 868.21, 868.21, 868.21, 868.21, 869.06, 869.06, 869.06, 869.06, 869.06, 875.67, 875.67, 875.67, 875.67, 875.67, 874.16, 874.16, 874.16, 874.16, 874.16, 875.24, 875.24, 875.24, 875.24, 875.24, 873.09, 873.09, 873.09, 873.09, 873.09, 872.91, 872.91, 872.91, 872.91, 872.91, 872.06, 872.06, 872.06, 872.06, 872.06, 873.13, 873.13, 873.13, 873.13, 873.13, 872.71, 872.71, 872.71, 872.71, 872.71, 870.66, 870.66, 870.66, 870.66, 870.66, 874.75, 874.75, 874.75, 874.75, 874.75, 885.74, 885.74, 885.74, 885.74, 885.74, 890.52, 890.52, 890.52, 890.52, 890.52, 891.76, 891.76, 891.76, 891.76, 891.76, 898.69, 898.69, 898.69, 898.69, 898.69, 896.0, 896.0, 896.0, 896.0, 896.0, 894.77, 894.77, 894.77, 894.77, 894.77, 897.1, 897.1, 897.1, 897.1, 897.1, 898.98, 898.98, 898.98, 898.98, 898.98, 908.08, 908.08, 908.08, 908.08, 908.08, 882.87, 882.87, 882.87, 882.87, 882.87, 880.86, 880.86, 880.86, 880.86, 880.86, 879.44, 879.44, 879.44, 879.44, 879.44, 879.51, 879.51, 879.51, 879.51, 879.51, 873.89, 873.89, 873.89, 873.89, 873.89, 872.29, 872.29, 872.29, 872.29, 872.29, 872.38, 872.38, 872.38, 872.38, 872.38, 875.61, 875.61, 875.61, 875.61, 875.61, 879.34, 879.34, 879.34, 879.34, 879.34, 880.98, 880.98, 880.98, 880.98, 880.98, 880.48, 880.48, 880.48, 880.48, 880.48, 886.38, 886.38, 886.38, 886.38, 886.38, 885.46, 885.46, 885.46, 885.46, 885.46, 886.54, 886.54, 886.54, 886.54, 886.54, 887.67, 887.67, 887.67, 887.67, 887.67, 887.2, 887.2, 887.2, 887.2, 887.2, 887.82, 887.82, 887.82, 887.82, 887.82, 890.7, 890.7, 890.7, 890.7, 890.7, 890.23, 890.23, 890.23, 890.23, 890.23, 888.75]
                    
Loading
predicted_tokens_seconds
More
---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3
 duration=10m 562 iterations"
    y-axis "llamacpp:predicted_tokens_seconds"
    x-axis "llamacpp:predicted_tokens_seconds" 1717984815 --> 1717985435
    line [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 41.95, 41.95, 41.95, 41.95, 41.95, 44.95, 44.95, 44.95, 44.95, 44.95, 31.11, 31.11, 31.11, 31.11, 31.11, 31.39, 31.39, 31.39, 31.39, 31.39, 32.18, 32.18, 32.18, 32.18, 32.18, 32.68, 32.68, 32.68, 32.68, 32.68, 33.04, 33.04, 33.04, 33.04, 33.04, 33.69, 33.69, 33.69, 33.69, 33.69, 34.09, 34.09, 34.09, 34.09, 34.09, 34.45, 34.45, 34.45, 34.45, 34.45, 34.37, 34.37, 34.37, 34.37, 34.37, 33.88, 33.88, 33.88, 33.88, 33.88, 33.55, 33.55, 33.55, 33.55, 33.55, 32.92, 32.92, 32.92, 32.92, 32.92, 31.1, 31.1, 31.1, 31.1, 31.1, 30.11, 30.11, 30.11, 30.11, 30.11, 30.21, 30.21, 30.21, 30.21, 30.21, 30.39, 30.39, 30.39, 30.39, 30.39, 30.47, 30.47, 30.47, 30.47, 30.47, 30.63, 30.63, 30.63, 30.63, 30.63, 30.8, 30.8, 30.8, 30.8, 30.8, 31.03, 31.03, 31.03, 31.03, 31.03, 30.98, 30.98, 30.98, 30.98, 30.98, 30.75, 30.75, 30.75, 30.75, 30.75, 30.97, 30.97, 30.97, 30.97, 30.97, 31.12, 31.12, 31.12, 31.12, 31.12, 30.95, 30.95, 30.95, 30.95, 30.95, 31.1, 31.1, 31.1, 31.1, 31.1, 31.34, 31.34, 31.34, 31.34, 31.34, 31.45, 31.45, 31.45, 31.45, 31.45, 31.53, 31.53, 31.53, 31.53, 31.53, 31.66, 31.66, 31.66, 31.66, 31.66, 31.73, 31.73, 31.73, 31.73, 31.73, 31.46, 31.46, 31.46, 31.46, 31.46, 31.36, 31.36, 31.36, 31.36, 31.36, 30.91, 30.91, 30.91, 30.91, 30.91, 30.9, 30.9, 30.9, 30.9, 30.9, 30.94, 30.94, 30.94, 30.94, 30.94, 30.98, 30.98, 30.98, 30.98, 30.98, 31.09, 31.09, 31.09, 31.09, 31.09, 31.3, 31.3, 31.3, 31.3, 31.3, 31.17, 31.17, 31.17, 31.17, 31.17, 30.74, 30.74, 30.74, 30.74, 30.74, 30.45, 30.45, 30.45, 30.45, 30.45, 29.74, 29.74, 29.74, 29.74, 29.74, 29.6, 29.6, 29.6, 29.6, 29.6, 29.52, 29.52, 29.52, 29.52, 29.52, 29.51, 29.51, 29.51, 29.51, 29.51, 29.51, 29.51, 29.51, 29.51, 29.51, 29.65, 29.65, 29.65, 29.65, 29.65, 29.67, 29.67, 29.67, 29.67, 29.67, 29.62, 29.62, 29.62, 29.62, 29.62, 29.55, 29.55, 29.55, 29.55, 29.55, 29.51, 29.51, 29.51, 29.51, 29.51, 29.68, 29.68, 29.68, 29.68, 29.68, 29.8, 29.8, 29.8, 29.8, 29.8, 29.89, 29.89, 29.89, 29.89, 29.89, 29.97, 29.97, 29.97, 29.97, 29.97, 30.01, 30.01, 30.01, 30.01, 30.01, 30.04, 30.04, 30.04, 30.04, 30.04, 30.13]
                    
Loading

Details

kv_cache_usage_ratio

More
---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3
 duration=10m 562 iterations"
    y-axis "llamacpp:kv_cache_usage_ratio"
    x-axis "llamacpp:kv_cache_usage_ratio" 1717984815 --> 1717985435
    line [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.14, 0.14, 0.14, 0.14, 0.14, 0.38, 0.38, 0.38, 0.38, 0.38, 0.23, 0.23, 0.23, 0.23, 0.23, 0.14, 0.14, 0.14, 0.14, 0.14, 0.23, 0.23, 0.23, 0.23, 0.23, 0.18, 0.18, 0.18, 0.18, 0.18, 0.15, 0.15, 0.15, 0.15, 0.15, 0.17, 0.17, 0.17, 0.17, 0.17, 0.2, 0.2, 0.2, 0.2, 0.2, 0.15, 0.15, 0.15, 0.15, 0.15, 0.12, 0.12, 0.12, 0.12, 0.12, 0.31, 0.31, 0.31, 0.31, 0.31, 0.33, 0.33, 0.33, 0.33, 0.33, 0.44, 0.44, 0.44, 0.44, 0.44, 0.34, 0.34, 0.34, 0.34, 0.34, 0.2, 0.2, 0.2, 0.2, 0.2, 0.11, 0.11, 0.11, 0.11, 0.11, 0.23, 0.23, 0.23, 0.23, 0.23, 0.16, 0.16, 0.16, 0.16, 0.16, 0.2, 0.2, 0.2, 0.2, 0.2, 0.13, 0.13, 0.13, 0.13, 0.13, 0.22, 0.22, 0.22, 0.22, 0.22, 0.19, 0.19, 0.19, 0.19, 0.19, 0.22, 0.22, 0.22, 0.22, 0.22, 0.11, 0.11, 0.11, 0.11, 0.11, 0.18, 0.18, 0.18, 0.18, 0.18, 0.12, 0.12, 0.12, 0.12, 0.12, 0.1, 0.1, 0.1, 0.1, 0.1, 0.17, 0.17, 0.17, 0.17, 0.17, 0.17, 0.17, 0.17, 0.17, 0.17, 0.16, 0.16, 0.16, 0.16, 0.16, 0.17, 0.17, 0.17, 0.17, 0.17, 0.29, 0.29, 0.29, 0.29, 0.29, 0.22, 0.22, 0.22, 0.22, 0.22, 0.17, 0.17, 0.17, 0.17, 0.17, 0.2, 0.2, 0.2, 0.2, 0.2, 0.23, 0.23, 0.23, 0.23, 0.23, 0.13, 0.13, 0.13, 0.13, 0.13, 0.16, 0.16, 0.16, 0.16, 0.16, 0.13, 0.13, 0.13, 0.13, 0.13, 0.3, 0.3, 0.3, 0.3, 0.3, 0.59, 0.59, 0.59, 0.59, 0.59, 0.51, 0.51, 0.51, 0.51, 0.51, 0.39, 0.39, 0.39, 0.39, 0.39, 0.29, 0.29, 0.29, 0.29, 0.29, 0.28, 0.28, 0.28, 0.28, 0.28, 0.34, 0.34, 0.34, 0.34, 0.34, 0.11, 0.11, 0.11, 0.11, 0.11, 0.17, 0.17, 0.17, 0.17, 0.17, 0.18, 0.18, 0.18, 0.18, 0.18, 0.27, 0.27, 0.27, 0.27, 0.27, 0.19, 0.19, 0.19, 0.19, 0.19, 0.25, 0.25, 0.25, 0.25, 0.25, 0.1, 0.1, 0.1, 0.1, 0.1, 0.11, 0.11, 0.11, 0.11, 0.11, 0.14, 0.14, 0.14, 0.14, 0.14, 0.17, 0.17, 0.17, 0.17, 0.17, 0.14, 0.14, 0.14, 0.14, 0.14, 0.2, 0.2, 0.2, 0.2, 0.2, 0.19, 0.19, 0.19, 0.19, 0.19, 0.1]
                    
Loading
requests_processing
More
---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3
 duration=10m 562 iterations"
    y-axis "llamacpp:requests_processing"
    x-axis "llamacpp:requests_processing" 1717984815 --> 1717985435
    line [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 5.0, 5.0, 5.0, 5.0, 6.0, 6.0, 6.0, 6.0, 6.0, 3.0, 3.0, 3.0, 3.0, 3.0, 5.0, 5.0, 5.0, 5.0, 5.0, 6.0, 6.0, 6.0, 6.0, 6.0, 3.0, 3.0, 3.0, 3.0, 3.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 6.0, 6.0, 6.0, 6.0, 6.0, 1.0, 1.0, 1.0, 1.0, 1.0, 4.0, 4.0, 4.0, 4.0, 4.0, 2.0, 2.0, 2.0, 2.0, 2.0, 7.0, 7.0, 7.0, 7.0, 7.0, 8.0, 8.0, 8.0, 8.0, 8.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 8.0, 8.0, 8.0, 8.0, 8.0, 3.0, 3.0, 3.0, 3.0, 3.0, 5.0, 5.0, 5.0, 5.0, 5.0, 3.0, 3.0, 3.0, 3.0, 3.0, 6.0, 6.0, 6.0, 6.0, 6.0, 3.0, 3.0, 3.0, 3.0, 3.0, 7.0, 7.0, 7.0, 7.0, 7.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.0, 3.0, 3.0, 3.0, 3.0, 7.0, 7.0, 7.0, 7.0, 7.0, 6.0, 6.0, 6.0, 6.0, 6.0, 7.0, 7.0, 7.0, 7.0, 7.0, 4.0, 4.0, 4.0, 4.0, 4.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 4.0, 4.0, 4.0, 4.0, 4.0, 8.0, 8.0, 8.0, 8.0, 8.0, 7.0, 7.0, 7.0, 7.0, 7.0, 5.0, 5.0, 5.0, 5.0, 5.0, 4.0, 4.0, 4.0, 4.0, 4.0, 5.0, 5.0, 5.0, 5.0, 5.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.0, 3.0, 3.0, 3.0, 3.0, 7.0, 7.0, 7.0, 7.0, 7.0, 4.0, 4.0, 4.0, 4.0, 4.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 6.0, 6.0, 6.0, 6.0, 6.0, 7.0, 7.0, 7.0, 7.0, 7.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 4.0, 4.0, 4.0, 4.0, 4.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 8.0, 8.0, 8.0, 8.0, 8.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 1.0, 1.0, 1.0, 1.0, 1.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 5.0, 5.0, 5.0, 5.0, 5.0, 3.0, 3.0, 3.0, 3.0, 3.0, 1.0]
                    
Loading

@ochafik ochafik marked this pull request as ready for review April 28, 2024 14:35
@ochafik
Copy link
Collaborator Author

ochafik commented Apr 28, 2024

should the token decoding cache be on the llama_context

@ejones done :-)

llama.cpp Show resolved Hide resolved
llama.cpp Outdated Show resolved Hide resolved
Co-authored-by: Clint Herron <[email protected]>
llama.cpp Outdated
@@ -13037,6 +13042,10 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_
}
}

if (next_candidates.empty()) {
return rejects;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is such a small and isolated change, I almost wonder if it shouldn't be pulled out into its own PR so that we can evaluate this performance improvement separate from the other one. As it is, it's difficult to know how much to attribute to each change...?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It happens to be the first commit on the branch so you can git checkout before / after it and compare performance as follows:

( export COMMON_ARGS=(
    -mu https://huggingface.co/NousResearch/Hermes-2-Pro-Mistral-7B-GGUF/resolve/main/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf
    -m models/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf
    --prompt-cache issue4218.bin
    --grammar-file issue4218.gbnf
    -f issue4218.txt
    -c 3400
  ) && \
  hyperfine --warmup 1 --runs 5 \
    -L branch 98f33bae767dd19e213ef663b22ad99979ca71d7^,98f33bae767dd19e213ef663b22ad99979ca71d7 \
    --setup "\
      git checkout {branch} && \
      make clean && make -j LLAMA_CURL=1 main && \
      rm -f issue4218.bin && \
      ./main ${COMMON_ARGS[*]} -n 1" \
    "BRANCH={branch} \
      ./main ${COMMON_ARGS[*]} -n 128 --prompt-cache-ro --seed 12345 --no-display-prompt" )
show output
Benchmark 1: BRANCH=98f33bae767dd19e213ef663b22ad99979ca71d7^       ./main -mu https://huggingface.co/NousResearch/Hermes-2-Pro-Mistral-7B-GGUF/resolve/main/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf -m models/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf --prompt-cache issue4218.bin --grammar-file issue4218.gbnf -f issue4218.txt -c 3400 -n 128 --prompt-cache-ro --seed 12345 --no-display-prompt
  Time (mean ± σ):      7.970 s ±  0.060 s    [User: 3.829 s, System: 0.292 s]
  Range (min … max):    7.877 s …  8.025 s    5 runs
 
Benchmark 2: BRANCH=98f33bae767dd19e213ef663b22ad99979ca71d7       ./main -mu https://huggingface.co/NousResearch/Hermes-2-Pro-Mistral-7B-GGUF/resolve/main/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf -m models/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf --prompt-cache issue4218.bin --grammar-file issue4218.gbnf -f issue4218.txt -c 3400 -n 128 --prompt-cache-ro --seed 12345 --no-display-prompt
  Time (mean ± σ):      5.814 s ±  0.037 s    [User: 1.674 s, System: 0.277 s]
  Range (min … max):    5.758 s …  5.857 s    5 runs
 
Summary
  'BRANCH=98f33bae767dd19e213ef663b22ad99979ca71d7       ./main -mu https://huggingface.co/NousResearch/Hermes-2-Pro-Mistral-7B-GGUF/resolve/main/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf -m models/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf --prompt-cache issue4218.bin --grammar-file issue4218.gbnf -f issue4218.txt -c 3400 -n 128 --prompt-cache-ro --seed 12345 --no-display-prompt' ran
    1.37 ± 0.01 times faster than 'BRANCH=98f33bae767dd19e213ef663b22ad99979ca71d7^       ./main -mu https://huggingface.co/NousResearch/Hermes-2-Pro-Mistral-7B-GGUF/resolve/main/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf -m models/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf --prompt-cache issue4218.bin --grammar-file issue4218.gbnf -f issue4218.txt -c 3400 -n 128 --prompt-cache-ro --seed 12345 --no-display-prompt'

It doesn't help w/ all grammars, though.

@HanClinto
Copy link
Collaborator

Overall this looks really good! Nice, clean work!

Do we have any reservations about increasing the memory footprint of the context? The caching code increases memory usage roughly by the vocabulary * 2, I think -- is that a problem?

The early exit in llama_grammar_reject_candidates_for_stack feels like a no-brainer easy win, but I don't know enough about the context and the memory usage goals to know if the caching is acceptable or not.

llama.cpp Outdated
@@ -15714,6 +15730,18 @@ struct llama_context * llama_new_context_with_model(
}
}

// cache tokens & their decoded codepoints (for common case where there's no partial utf8 prefix bytes) for grammar-constrained sampling.
{
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only reservation that I have is that we're caching these data structures whether there is a grammar or not. These data structures are only used in llama_sample_grammar and llama_grammar_accept_token, so if there is no grammar, then this is wasted time (and perhaps more importantly, memory).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair point! A “grammars_enabled” context param might make sense, with a flag to turn it off in server (+ explode when grammar requested), and turned on in main only when grammar or schema flag set. Would save a couple of MBs and a tiny bit of preprocessing. Will add tonight 👌

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could also perhaps restructure this cache building that happens as a memorization step that happens inside of the first calls to decode_utf8 / llama_token_to_piece, rather than when the context is built.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That’s what I initially did (and was storing it in the grammar itself, which was pointed out as a not ideal by @ejones ). Lazy init comes with big concurrency concerns (say two slots start working with grammars at the same time in the server: can’t let them both populate the data lazily if it’s in the context)

Copy link
Collaborator Author

@ochafik ochafik May 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Oh but then, no sure anymore whether context is slot-specific? I’ll look again tonight 😅)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Oh but then, no sure anymore whether context is slot-specific? I’ll look again tonight 😅)

Yeah context is shared. And as @ejones pointed out grammar itself technically could be shared in other contexts.

I played with the idea of allowing to disable the preprocessing (here) but I haven't found a good way to articulate the flag wrt/ all the ways the API can be used. I think it'd be simpler to leave it as is and keep it as an area where to potentially squeeze a couple of MB when times are scarce. wdyt?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apologies in advance for this comment -- this is very long, and very stream-of-consciousness, so please take everything with a grain of salt. However, I want to keep moving forward on this PR, so I figure at least some reply is better than nothing -- even if it's a bit rambly. :)

I played with the idea of allowing to disable the preprocessing (here) but I haven't found a good way to articulate the flag wrt/ all the ways the API can be used.

Yeah, that's a tricky question. I wonder if instead of a command-line flag, if the precache step at the end of llama_new_context_with_model should maybe be broken out into a separate call -- something like llama_context_precache_grammar_token_pieces? Then, whoever is initializing the context can also add precaching to it...? calling the API can enable this pre-caching feature if they want to or not, but it takes an extra call...?

So I'm back to wondering if there is a way where we can still break this precache functionality into a separate function. Can we check for the need of it in llama_sample_grammar / llama_grammar_accept_token and call the precaching there as-needed?

if (ctx->token_pieces.empty()) {
    llama_context_precache_grammar_token_pieces(ctx);
}
const auto & piece = ctx->token_pieces.at(token);
...

It's annoying to run that check on every token, but the performance improvements when using a grammar should still win out. And we also remove the extra memory usage in the case of no grammar being used.

But as I'm thinking through this now, I think what you said earlier is the problem of ctx being shared across multiple threads / processes, so at the very least, we would need to add a blocking mutex here. Is that feasible / reasonable? We do a number of other little mutexes scattered throughout the code in other places, so it feels like it wouldn't be awful here...?

So yeah, I'm still on the fence about making it as a separate call that needs to be called manually if one wants the performance boost (maybe with an automatic fallback mechanism in the sampler), or else transparent caching on first call at runtime with a mutex (provided we can implement this without it being too annoying to run this check inside such a tight loop).

I think it'd be simpler to leave it as is and keep it as an area where to potentially squeeze a couple of MB when times are scarce. wdyt?

This also sounds not unreasonable, but I don't know how to weigh such things. I know I really like grammar-constrained sampling, but I don't know how popular the feature is overall, and is it worth negatively impacting hyper-resource-constrained usages (such as Raspberry Pis or whatnot) vs. grammars? That's what I'm unable to weigh -- I feel like that's a strategic decision that's a bit above my level.

Copy link

@gabriel-peracio gabriel-peracio May 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know how popular the feature is overall

User here: I've been finding the grammar feature extremely useful and can't live without it. There are also no substitutes anywhere, the competitor's solutions are worse than llama.cpp's implementation IMO

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@HanClinto thanks++ for your reply & sorry it took me so long to get back to it, the Spring has been full of distractions 😅

I had somehow written concurrent synchronization off but as you mention it, seems worth exploring, looking now!

@gabriel-peracio same vibe here, realized I couldn't live without it and then that I couldn't stand how slow it was when pushed to its limits haha! It's nearly there :-)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@HanClinto I've moved the caching back to llama_sample_grammar in lazy & mutex-protected form (thanks for the suggestion!). And sent the early exit change separately -> #7370

@mofosyne mofosyne added Review Complexity : Medium Generally require more time to grok but manageable by beginner to medium expertise level performance Speed related topics labels May 9, 2024
@ochafik ochafik changed the title grammars: cache decoded token codepoints & early exit in candidates rejection (faster sampling) grammars: cache decoded token codepoints for faster sampling May 18, 2024
@ochafik
Copy link
Collaborator Author

ochafik commented Jun 10, 2024

Confirming this is superseded by #7587 & #7424, no incremental gains from caching the decoded codepoints

@ochafik ochafik closed this Jun 10, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
performance Speed related topics Review Complexity : Medium Generally require more time to grok but manageable by beginner to medium expertise level
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants