
Cache the KV projection history when generating #76

Closed · wants to merge 1 commit

Conversation

@dfyz commented Jan 22, 2023

This PR is a mostly failed attempt to fix issue #95 from the minGPT repo.

The idea is to save the results of key and value projections in each self-attention layer for previously generated tokens. With saved projections, you can essentially convert all matrix-matrix multiplications at every generation step into matrix-vector multiplications, since you now only need to apply linear transformations to the very last token. This is a pretty standard optimization technique for sequential Transformer generation. For example, I think Huggingface calls the cached projections past_key_values.
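
A minimal sketch of the idea, for illustration only (this is not the exact diff; the c_attn / c_proj / n_head / n_embd names follow model.py, and F.scaled_dot_product_attention stands in for whatever attention implementation is in place):

```python
import torch
import torch.nn.functional as F

def attention_forward_with_cache(self, x, past_kv=None):
    # x is (B, T, C); once a cache exists, T is 1 (only the newest token).
    B, T, C = x.size()
    q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
    k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
    q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
    v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
    if past_kv is not None:
        past_k, past_v = past_kv
        k = torch.cat((past_k, k), dim=2)  # (B, nh, history + T, hs)
        v = torch.cat((past_v, v), dim=2)
    present_kv = (k, v)  # handed back to the caller for the next step
    # With T == 1, q @ k^T is a matrix-vector product per head; causal masking
    # is only needed when the whole prompt is processed at once (no cache yet).
    y = F.scaled_dot_product_attention(q, k, v, is_causal=past_kv is None)
    y = y.transpose(1, 2).contiguous().view(B, T, C)
    return self.c_proj(y), present_kv
```

The per-layer (k, v) tuples then have to be collected into a list by the caller and passed back in at the next step, which is the packing/unpacking clutter mentioned under reason 1 below.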

The only positive impact of this PR is a tremendous speed-up of CPU generation. E.g., on my MacBook Air with Apple M1:

python3 sample.py [...] --device=cpu --dtype=float32 --num_samples=1

|        | GPT-2          | GPT-2-medium    |
|--------|----------------|-----------------|
| Before | 81.632 seconds | 235.111 seconds |
| After  | 10.966 seconds | 30.414 seconds  |

(I used a hacky patch to generate sentences directly from pretrained GPT models and print the generation times.)

Unfortunately, on an A100, the speed-up is a rounding error even for GPT-2-XL:

python sample.py [...] --device=cuda:7 --dtype=float32 --num_samples=1

|        | GPT-2         | GPT-2-medium   | GPT-2-large    | GPT-2-XL       |
|--------|---------------|----------------|----------------|----------------|
| Before | 6.118 seconds | 10.498 seconds | 15.180 seconds | 19.982 seconds |
| After  | 6.017 seconds | 10.076 seconds | 14.923 seconds | 19.149 seconds |

(for some reason, both bfloat16 and float16 are slower on A100 than float32, even if I don't make any code changes, so I didn't bother measuring them)

Even if the speed-up were more pronounced, I don't think cached generation is worth it, for two reasons:

  1. It clutters the forward pass quite a bit. Maybe there is a nicer way to implement this in PyTorch, but right now I resort to packing the new projections from each layer into a list, then unpacking them back again at the next step, which looks really ugly.
  2. More importantly, this optimization doesn't work when max_new_tokens exceeds block_size, since then the absolute positions of the previous tokens change, and the cached KV history is no longer valid. I guess you could sidestep this with something like rotary positional embeddings, but then you lose the ability to initialize from stock GPT models.

So, this PR is more of a proof of concept and should not be merged. Still, it might be a good idea to add a comment to GPT.generate() explaining why it recomputes the previous tokens from scratch at every step, to prevent anyone else from going down this particular rabbit hole. :)

@pixar0407

@karpathy

Thank you for the quality code.

  1. Can I ask why you don't use past_key_values [1]?

  2. Your current version and the past_key_values version are mathematically different, right? Your current version keeps updating the previous key and value vectors, while the past_key_values version keeps using fixed key and value vectors (it just concatenates them without updating the previous keys and values).

[1]https://github.com/huggingface/transformers/blob/98d40fed3a4515077163adab9dfd8fb2fccf1267/src/transformers/models/gpt2/modeling_gpt2.py#L317

@dfyz (author) commented Feb 1, 2023

> Your current version and the past_key_values version are mathematically different, right?

I think they are pretty much identical up to renaming (not intentionally, it's just that it's hard to implement it differently). Unless I've done something really stupid and fail to see it. :)

My version:

def forward(..., past_kv_proj=None):
    if past_kv_proj is not None:
        past_k_proj, past_v_proj = past_kv_proj
        ...
        k = torch.cat((past_k_proj, k), dim=2)
        v = torch.cat((past_v_proj, v), dim=2)
        ...
        present_kv_proj = (k, v) ...

past_{k,v}_proj both have shape (B, nh, history_size, hs), where history_size is the number of tokens generated so far.

The Huggingface version you linked to:

def forward(
    ...
    layer_past: Optional[Tuple[torch.Tensor]] = None,
) ... :
    ...
    if layer_past is not None:
        past_key, past_value = layer_past
        key = torch.cat((past_key, key), dim=-2)
        value = torch.cat((past_value, value), dim=-2)
    ...
        present = (key, value)

past_{key,value} both appear to have shape (batch_size, num_heads, sequence_length, embed_size_per_head), if I'm looking at the right comment, which is basically the same thing as my version.

So I don't really see where I keep updating previous keys and values. My intention was to take previous KV projections and concatenate them with the current projections, exactly as in the HuggingFace implementation.

@dfyz (author) commented Feb 1, 2023

I guess it's best for me to close this PR, given that the upstream version has diverged from mine (after integrating FlashAttention) and that I never intended this to be merged anyway. This PR is still linked to the original minGPT issue, so anyone interested in implementing cached generation can still find it.

@dfyz closed this Feb 1, 2023
@pixar0407

@dfyz , @karpathy

Thank you for the comment.

> Your current version and the past_key_values version are mathematically different, right?

"your current version" is actually referring to the one that @karpathy implemented in this main branch, not @dfyz 's PR version.

@dfyz (author) commented Feb 1, 2023

Whoops, sorry, I didn't notice you were referring to Andrej, not me. But if you don't mind me answering anyway: mathematically, the version in the main branch of this repo and the Huggingface version with past_key_values are equivalent, even though the former indeed "keeps updating the previous key and value vectors" and the latter "keeps using fixed key and value vectors". This only affects the time complexity of computing attention; the equations stay exactly the same.

I guess the easiest way to see this would be to run the version from the main branch and print key and value vectors for every layer after every generated token. You'll notice that, once the vectors are generated for a token at some step T, they stay constant for all steps T' > T (modulo floating-point indeterminism). This is exactly why we can cache them with past_key_values and get a speed-up.
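
To make that check concrete, here is a rough sketch (not part of the PR) that hooks each block's c_attn projection in the uncached main-branch model and verifies that the key rows for already-generated positions do not change at later steps. It assumes model is a nanoGPT GPT instance, e.g. GPT.from_pretrained('gpt2'), and uses greedy decoding purely for determinism:

```python
import torch

model.eval()
keys_per_step = []  # keys_per_step[step][layer] -> (B, t, n_embd) key slice of c_attn

def make_hook(step_keys):
    def hook(module, inputs, output):
        _, k, _ = output.split(model.config.n_embd, dim=2)  # c_attn packs q, k, v
        step_keys.append(k.detach().clone())
    return hook

idx = torch.tensor([[50256]])  # start from the GPT-2 end-of-text token
with torch.no_grad():
    for step in range(10):
        step_keys = []
        handles = [blk.attn.c_attn.register_forward_hook(make_hook(step_keys))
                   for blk in model.transformer.h]
        logits, _ = model(idx)  # full recomputation every step, as in the main branch
        for h in handles:
            h.remove()
        keys_per_step.append(step_keys)
        nxt = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        idx = torch.cat((idx, nxt), dim=1)

# The key row computed for the first token at step 0 matches the row the model
# recomputes for that same token at the last step (up to floating-point noise).
assert torch.allclose(keys_per_step[0][0][:, :1], keys_per_step[-1][0][:, :1], atol=1e-5)
```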

More theoretically, you could notice that key and value projections for time step T only depend on:

  • the token embedding of the token generated at step T
  • the position embedding of the token generated at step T
  • the key and value projections generated at steps T' < T

Since none of these dependencies change when we compute steps T' > T, the key and value projections for step T also do not change.

@jasonphillips commented Apr 15, 2023

Jumping in on this for a moment since I was tempted to implement this same optimization today, then found this PR.

> More importantly, this optimization doesn't work when max_new_tokens exceeds block_size, since then the absolute positions of the previous tokens change, and the cached KV history is no longer valid. I guess you could sidestep this with something like rotary positional embeddings, but then you lose the ability to initialize from stock GPT models.

Indeed this would seem to be a problem on one level, since position embeddings are part of the computation for the keys and values... but Huggingface's standard GPT-2 also uses absolute position embeddings (which is why they recommend padding short sequences on the right rather than the left), while still implementing past_key_values. So I'm left wondering how they get around this problem once the input exceeds the max window and tokens are effectively shifted one position between steps. I don't see any apparent handling for it in their code.

@dfyz (author) commented Apr 16, 2023

> So I'm left wondering how they get around this problem once the input exceeds the max window and tokens are effectively shifted one position between steps

This is a very good question. I thought the answer was "they don't" (i.e., they just throw an exception or something at that point), but it's slightly more complicated than that. This comment explores the possible options.

Option 1 is giving up on caching and generating the next tokens "the old way" once the input is too long. This way you only get a speedup for some prefix of the generation. This could work but IMO would be too complex for an educational codebase such as nanoGPT.

Option 2 is cutting a prefix of the prompt so that more tokens can be generated using the cache (this is what they provide as handle_long_generation='hole' in their pipeline). This doesn't really help if the prompt is much smaller than the number of tokens you want to generate.

Option 3 is the most interesting, and I don't think I fully understand it. Apparently you can keep generating tokens with "wrong" positions, and the generation quality will only degrade slightly? When I tried something like that, the model started outputting pure garbage pretty quickly.
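
(For completeness, option 2 in pipeline form looks roughly like this; the model name, prompt, and lengths are placeholders:)

```python
from transformers import pipeline

generator = pipeline("text-generation", model="gpt2")
# "hole" trims the left of the prompt so that prompt + max_new_tokens
# still fits into the model's maximum context length.
out = generator("some very long prompt ...", max_new_tokens=200,
                handle_long_generation="hole")
```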

@jasonphillips commented Apr 17, 2023

@dfyz

Excellent find! I was searching for a discussion like that one in their repo.

It seems to me that Option 3 is probably not desirable, as you already noted.

I feel like the optimal solution might be something like Option 2, but I think their naming ("hole") is a bit confusing.

If an option like that were offered as an argument to the model.generate() method, I'd name it differently. That method already silently crops your context during the loop once you exceed the max window, of course: https://github.com/karpathy/nanoGPT/blob/master/model.py#L352-L353

But instead of losing all speedup from cached keys/values once you exceed the maximum, there could be an option like "shift_by_tokens" that lets you select how far the model shifts the entire window at once. So if, e.g., the max context length is 100, then once it has reached n = 100 tokens, it would shift by that specified number of tokens instead of by 1 at a time. If shift_by_tokens is 20, then when generating the 101st token it would actually clear 20 tokens from the left of the context (leaving 80), and generate the next 20 using cached keys/values, before doing that again. That would let the user set the trade-off between speedup and loss of context.

Perhaps that makes sense? As you said, the educational nature of this project makes the choice tricky, but exposing the trade-offs of caching seems useful.

@dfyz (author) commented Apr 17, 2023

> Perhaps that makes sense?

This sounds promising, but I am a little unclear on how exactly this would work.

As far as I understand, once you compute an embedding for a token by combining a token embedding and a positional embedding, there is no straightforward way to change either without recalculating all layers of the decoder. So, in your hypothetical situation with 100 tokens generated, you have a KV history that is calculated based on the following (X_i is the token embedding for token #i, T_i is the positional embedding for position #i):

X_0 X_1 ... X_98 X_99
T_0 T_1 ... T_98 T_99

The already-existing cutting of the context essentially changes the history to look like this (since we don't have T_100 and can't incorporate it into the history):

X_1 X_2 ... X_99 X_100
T_0 T_1 ... T_98 T_99

After the cutting, we have to re-calculate everything from scratch, since all token embeddings have changed.

If I'm understanding correctly what you're proposing, after clearing the leftmost 20 tokens, we would have this:

X_20 X_21 ... X_98 X_99 X_100
T_20 T_21 ... T_98 T_99 T_???

Even though you got rid of the leftmost tokens, you still can't change any T_i without recalculating everything. So I don't think you can assign any meaningful T_i to X_100 here. You can't use T_100 because you don't have it, and using anything else breaks the contiguous numbering.

However, it's been a long time since I last looked at the codebase, so it is entirely possible that I'm missing something and/or misunderstanding what you are proposing. Perhaps it's better to just ignore what I'm saying and create a small proof-of-concept branch to see if your idea works.

(and it's ultimately up to Andrej whether he wants to see something like this in his codebase, I don't have any say in this :))

@jasonphillips commented Apr 17, 2023

@dfyz

I may not have written very clearly. I meant that when the cut happens, we would drop the cached keys/values and intentionally recalculate everything from scratch (for the newly cut range of tokens, e.g. the last 80 that are now our entire context in the example). But then the cached KVs from that step would still be usable for the next 20 steps (or whatever the shift size is).

So when we get to token 101, we'd have something like this:

X_20 X_21 ... X_98 X_99
T_0 T_1 ... T_78 T_79 

... which means recomputing all the KVs at this step, but then we get another 19 steps in a row with the new cached values, so we end up gaining a cache-based speed-up on about 95% of our steps past the token limit (in this contrived example of a 20-token shift with a 100-token max), while losing some of the context window size to make it work.

So the idea isn't to solve the need to recompute everything when generating longer text, but to let the user set a trade-off, so that we only recompute from scratch every n_shift steps rather than on every single step once we exceed the window, with the trade-off being that the context window is shrunk by that much each time we have to shift over.

But again I might have misinterpreted.
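
A minimal sketch of the loop described above, assuming the model(tokens, past_kv=...) -> (logits, new_past_kv) interface from this PR; the n_shift name and the sampling details are illustrative, not a definitive implementation:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def generate_with_shifts(model, idx, max_new_tokens, block_size, n_shift,
                         temperature=1.0):
    # Keep a KV cache between steps; only rebuild it (on a window shortened
    # by n_shift tokens) once the cache fills the block_size window.
    past_kv, cached_len = None, 0
    for _ in range(max_new_tokens):
        if past_kv is not None and cached_len < block_size:
            cond = idx[:, -1:]                       # incremental step, reuse the cache
        else:
            cond = idx[:, -(block_size - n_shift):]  # full recompute on a shorter window
            past_kv, cached_len = None, 0
        logits, past_kv = model(cond, past_kv=past_kv)
        cached_len += cond.size(1)
        probs = F.softmax(logits[:, -1, :] / temperature, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)
        idx = torch.cat((idx, idx_next), dim=1)      # idx keeps the full output
    return idx
```

With block_size=100 and n_shift=20, the cache is rebuilt once every 20 generated tokens instead of on every step, at the cost of conditioning on between 80 and 100 previous tokens.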

@dfyz (author) commented Apr 17, 2023

Oooooh, now I see, thanks! This is really clever and makes perfect sense. I bet there are a lot of real-life scenarios where you have a model with absolute position embeddings, do not care about the reduced context size all that much, but would appreciate getting a performance boost.

At this point I would ask @karpathy what he thinks about this idea (probably not here, since this is a closed PR, but in the old minGPT issue or a new one in this repo). If it turns out that Andrej doesn't want to introduce additional complexity after all, maybe it's even worth trying to propose this to Huggingface?

@efg001 commented Jul 15, 2024

Edit:
Never mind. Assuming max_new_tokens = 500 and n_embd = 768, the CPU inference speedup is significant because max_new_tokens < n_embd.

Previous comment:
Sorry for digging up this old issue, but I wanted to check whether you did any sanity check on the CPU runtime improvement.

For each transformer block where K and V need to be recalculated, attention is also calculated, and the attention dot product should dominate the runtime over the K and V calculation.

Even in the best-case scenario, where the model consists only of transformer blocks and the attention dot product costs the same amount of time as the KV calculation, the runtime reduction should be only 50%, but you are seeing a > 80% runtime reduction when running on CPU.

Maybe I am missing something here. I looked at the PR but haven't tried to reproduce your benchmark.
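
For what it's worth, a back-of-envelope count of per-layer multiply-accumulates is consistent with the edit above; the numbers are rough assumptions (GPT-2 small width, a 500-token context, the roughly 2x causal-mask saving ignored), not measurements:

```python
# Rough per-layer MAC counts for one generation step with a context of T tokens.
n_embd, T = 768, 500

# Without a cache, every step recomputes the projections and MLP for all T tokens.
no_cache_matmul = 12 * T * n_embd ** 2  # QKV (3) + output proj (1) + MLP (8), per token
no_cache_attn = 2 * T ** 2 * n_embd     # q @ k^T scores plus the weighted sum over v

# With a cache, only the newest token goes through the projections and MLP,
# and its single query attends to T cached keys/values.
cached_matmul = 12 * n_embd ** 2
cached_attn = 2 * T * n_embd

print(no_cache_attn / no_cache_matmul)   # ~0.11: attention is not the dominant term
print((cached_matmul + cached_attn) /
      (no_cache_matmul + no_cache_attn)) # ~0.002: almost all per-step work disappears
```

So as long as the context stays well below n_embd-scale, the linear layers (not the attention dot products) dominate the per-step cost, which is consistent with the > 80% CPU reduction reported above.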

Successfully merging this pull request may close these issues: Caching for generation

4 participants