-
Notifications
You must be signed in to change notification settings - Fork 10.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
grammars
: cache decoded token codepoints for faster sampling
#6811
Conversation
I've been out of it for a few days, but finally digging back in. GREAT work on all of this, @ochafik !!
I'm still trying to read through and understand this -- I can't claim to have a good handle on this yet. Are you able to summarize why this optimization works?
Because |
@HanClinto welcome back, hope the robotics competition went well!
Good point, reverted / will send separately 👌 (one less thing to check for this PR :-D) |
@HanClinto The code used to decode each token in the context of a potentially incomplete UTF-8 sequence (say, if the first token was a byte token with the I'll add comments about this :-) |
Thanks -- it went really well! The kids all did great! 10 kids with 9 robots in 3 different categories. We took home two special nominations, two gold trophies, a silver, and a bronze. Tried a new thing this year where our team jerseys were custom-printed lab coats for all of the kids, and that was a ton of fun. :D It's a pretty big competition (here's a highlight reel from last year), and the kids are excited for next year. Anyways, enough fluff. That's what happens when a project doesn't have a Discord server -- we clutter up the PRs with such conversations. :D
Sounds great! 👍 I know I've mentioned it a few times, but I have a rather complex SQL grammar that I need to dig out and dust off that would help stretch this stuff a bit more. |
grammars
: cache decoded token codepoints & early exit in candidates rejection (up to 10x faster sampling)grammars
: cache decoded token codepoints & early exit in candidates rejection (faster sampling)
Hey, wanted to drop in quick - love this work @ochafik! Have been following along the discussions from the notifications. Hope to have some time in the next few days to take a deeper look. I took a super quick look and one surface level thought I had was: should the token decoding cache be on the Other than that, I need to have a think about UTF-8 decoding and that multibyte sequences with this. |
Thanks @ejones !
Ahh, I wondered about this, guess I took the lazy route. If we moved this to the context we'd probably want to build it upfront to avoid concurrency issues in
Oh indeed, for speculative decoding I suppose the two models could have (ever so slightly) different tokenizers 😓 llama_context, hold on tight, more things coming your way :-D |
@ejones done :-) |
Co-authored-by: Clint Herron <[email protected]>
llama.cpp
Outdated
@@ -13037,6 +13042,10 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_ | |||
} | |||
} | |||
|
|||
if (next_candidates.empty()) { | |||
return rejects; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is such a small and isolated change, I almost wonder if it shouldn't be pulled out into its own PR so that we can evaluate this performance improvement separate from the other one. As it is, it's difficult to know how much to attribute to each change...?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It happens to be the first commit on the branch so you can git checkout before / after it and compare performance as follows:
( export COMMON_ARGS=(
-mu https://huggingface.co/NousResearch/Hermes-2-Pro-Mistral-7B-GGUF/resolve/main/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf
-m models/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf
--prompt-cache issue4218.bin
--grammar-file issue4218.gbnf
-f issue4218.txt
-c 3400
) && \
hyperfine --warmup 1 --runs 5 \
-L branch 98f33bae767dd19e213ef663b22ad99979ca71d7^,98f33bae767dd19e213ef663b22ad99979ca71d7 \
--setup "\
git checkout {branch} && \
make clean && make -j LLAMA_CURL=1 main && \
rm -f issue4218.bin && \
./main ${COMMON_ARGS[*]} -n 1" \
"BRANCH={branch} \
./main ${COMMON_ARGS[*]} -n 128 --prompt-cache-ro --seed 12345 --no-display-prompt" )
show output
Benchmark 1: BRANCH=98f33bae767dd19e213ef663b22ad99979ca71d7^ ./main -mu https://huggingface.co/NousResearch/Hermes-2-Pro-Mistral-7B-GGUF/resolve/main/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf -m models/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf --prompt-cache issue4218.bin --grammar-file issue4218.gbnf -f issue4218.txt -c 3400 -n 128 --prompt-cache-ro --seed 12345 --no-display-prompt
Time (mean ± σ): 7.970 s ± 0.060 s [User: 3.829 s, System: 0.292 s]
Range (min … max): 7.877 s … 8.025 s 5 runs
Benchmark 2: BRANCH=98f33bae767dd19e213ef663b22ad99979ca71d7 ./main -mu https://huggingface.co/NousResearch/Hermes-2-Pro-Mistral-7B-GGUF/resolve/main/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf -m models/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf --prompt-cache issue4218.bin --grammar-file issue4218.gbnf -f issue4218.txt -c 3400 -n 128 --prompt-cache-ro --seed 12345 --no-display-prompt
Time (mean ± σ): 5.814 s ± 0.037 s [User: 1.674 s, System: 0.277 s]
Range (min … max): 5.758 s … 5.857 s 5 runs
Summary
'BRANCH=98f33bae767dd19e213ef663b22ad99979ca71d7 ./main -mu https://huggingface.co/NousResearch/Hermes-2-Pro-Mistral-7B-GGUF/resolve/main/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf -m models/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf --prompt-cache issue4218.bin --grammar-file issue4218.gbnf -f issue4218.txt -c 3400 -n 128 --prompt-cache-ro --seed 12345 --no-display-prompt' ran
1.37 ± 0.01 times faster than 'BRANCH=98f33bae767dd19e213ef663b22ad99979ca71d7^ ./main -mu https://huggingface.co/NousResearch/Hermes-2-Pro-Mistral-7B-GGUF/resolve/main/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf -m models/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf --prompt-cache issue4218.bin --grammar-file issue4218.gbnf -f issue4218.txt -c 3400 -n 128 --prompt-cache-ro --seed 12345 --no-display-prompt'
It doesn't help w/ all grammars, though.
Overall this looks really good! Nice, clean work! Do we have any reservations about increasing the memory footprint of the context? The caching code increases memory usage roughly by the vocabulary * 2, I think -- is that a problem? The early exit in |
llama.cpp
Outdated
@@ -15714,6 +15730,18 @@ struct llama_context * llama_new_context_with_model( | |||
} | |||
} | |||
|
|||
// cache tokens & their decoded codepoints (for common case where there's no partial utf8 prefix bytes) for grammar-constrained sampling. | |||
{ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Only reservation that I have is that we're caching these data structures whether there is a grammar or not. These data structures are only used in llama_sample_grammar
and llama_grammar_accept_token
, so if there is no grammar, then this is wasted time (and perhaps more importantly, memory).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fair point! A “grammars_enabled” context param might make sense, with a flag to turn it off in server (+ explode when grammar requested), and turned on in main only when grammar or schema flag set. Would save a couple of MBs and a tiny bit of preprocessing. Will add tonight 👌
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could also perhaps restructure this cache building that happens as a memorization step that happens inside of the first calls to decode_utf8 / llama_token_to_piece, rather than when the context is built.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That’s what I initially did (and was storing it in the grammar itself, which was pointed out as a not ideal by @ejones ). Lazy init comes with big concurrency concerns (say two slots start working with grammars at the same time in the server: can’t let them both populate the data lazily if it’s in the context)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(Oh but then, no sure anymore whether context is slot-specific? I’ll look again tonight 😅)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(Oh but then, no sure anymore whether context is slot-specific? I’ll look again tonight 😅)
Yeah context is shared. And as @ejones pointed out grammar itself technically could be shared in other contexts.
I played with the idea of allowing to disable the preprocessing (here) but I haven't found a good way to articulate the flag wrt/ all the ways the API can be used. I think it'd be simpler to leave it as is and keep it as an area where to potentially squeeze a couple of MB when times are scarce. wdyt?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Apologies in advance for this comment -- this is very long, and very stream-of-consciousness, so please take everything with a grain of salt. However, I want to keep moving forward on this PR, so I figure at least some reply is better than nothing -- even if it's a bit rambly. :)
I played with the idea of allowing to disable the preprocessing (here) but I haven't found a good way to articulate the flag wrt/ all the ways the API can be used.
Yeah, that's a tricky question. I wonder if instead of a command-line flag, if the precache step at the end of llama_new_context_with_model
should maybe be broken out into a separate call -- something like llama_context_precache_grammar_token_pieces
? Then, whoever is initializing the context can also add precaching to it...? calling the API can enable this pre-caching feature if they want to or not, but it takes an extra call...?
So I'm back to wondering if there is a way where we can still break this precache functionality into a separate function. Can we check for the need of it in llama_sample_grammar
/ llama_grammar_accept_token
and call the precaching there as-needed?
if (ctx->token_pieces.empty()) {
llama_context_precache_grammar_token_pieces(ctx);
}
const auto & piece = ctx->token_pieces.at(token);
...
It's annoying to run that check on every token, but the performance improvements when using a grammar should still win out. And we also remove the extra memory usage in the case of no grammar being used.
But as I'm thinking through this now, I think what you said earlier is the problem of ctx being shared across multiple threads / processes, so at the very least, we would need to add a blocking mutex here. Is that feasible / reasonable? We do a number of other little mutexes scattered throughout the code in other places, so it feels like it wouldn't be awful here...?
So yeah, I'm still on the fence about making it as a separate call that needs to be called manually if one wants the performance boost (maybe with an automatic fallback mechanism in the sampler), or else transparent caching on first call at runtime with a mutex (provided we can implement this without it being too annoying to run this check inside such a tight loop).
I think it'd be simpler to leave it as is and keep it as an area where to potentially squeeze a couple of MB when times are scarce. wdyt?
This also sounds not unreasonable, but I don't know how to weigh such things. I know I really like grammar-constrained sampling, but I don't know how popular the feature is overall, and is it worth negatively impacting hyper-resource-constrained usages (such as Raspberry Pis or whatnot) vs. grammars? That's what I'm unable to weigh -- I feel like that's a strategic decision that's a bit above my level.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't know how popular the feature is overall
User here: I've been finding the grammar feature extremely useful and can't live without it. There are also no substitutes anywhere, the competitor's solutions are worse than llama.cpp's implementation IMO
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@HanClinto thanks++ for your reply & sorry it took me so long to get back to it, the Spring has been full of distractions 😅
I had somehow written concurrent synchronization off but as you mention it, seems worth exploring, looking now!
@gabriel-peracio same vibe here, realized I couldn't live without it and then that I couldn't stand how slow it was when pushed to its limits haha! It's nearly there :-)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@HanClinto I've moved the caching back to llama_sample_grammar in lazy & mutex-protected form (thanks for the suggestion!). And sent the early exit change separately -> #7370
grammars
: cache decoded token codepoints & early exit in candidates rejection (faster sampling)grammars
: cache decoded token codepoints for faster sampling
Edit: superseded by #7587 & #7424, can't measure significant improvements from caching the decoded codepoints: closing this PR.
Cache token pieces & their decoded codepoints (for common case where there's no partial utf8 sequence prefix) in
llama_context
to speed up grammar sampling (#4218)cc/ @ejones, @HanClinto
Update: while this PR still seems useful, #7424 is bringing the most performance in most cases. I've run some benchmarks that seems to show this PR will add some marginal gains upon it.
This simple example runs 1.9x faster (total time; sampling itself goes from ~16 ms per token to ~1.9 ms per token, a 9x sampling speedup):
Show output
master
:this PR:
Note: removed the following changes from earlier version of this PR
llama_grammar_copy
for speculative decoding (will send this separately)