Cache: new Cache format in decoder-only models #31421

Merged
Commits (44)
183cd66
draft bart with new cache
zucchini-nlp Jun 14, 2024
4578bca
add cache for decoder-only models
zucchini-nlp Jun 14, 2024
9505ca4
revert utils
zucchini-nlp Jun 14, 2024
2ab28f3
modify docstring
zucchini-nlp Jun 14, 2024
5fe4e9e
revert bart
zucchini-nlp Jun 14, 2024
09413c3
minor fixes
zucchini-nlp Jun 14, 2024
3c27604
fix copies (not related)
zucchini-nlp Jun 14, 2024
350acc5
revert tests
zucchini-nlp Jun 14, 2024
c0adf10
remove enc-dec related code
zucchini-nlp Jun 17, 2024
c18b177
remove bloom
zucchini-nlp Jun 17, 2024
582f289
remove opt (enc-dec)
zucchini-nlp Jun 17, 2024
3141a71
Merge remote-tracking branch 'upstream/main' into dynamic_cache_decod…
zucchini-nlp Jun 17, 2024
33d54b4
update docstring
zucchini-nlp Jun 18, 2024
dd05e6b
git, codegen, gpt_neo, gpt_neox, gpj
zucchini-nlp Jun 18, 2024
cb878d5
clean up
zucchini-nlp Jun 19, 2024
0588791
copied from statements
zucchini-nlp Jun 19, 2024
a27b47c
revert
zucchini-nlp Jun 19, 2024
1abcf30
tmp
zucchini-nlp Jun 19, 2024
00ed88c
update warning msg
zucchini-nlp Jun 20, 2024
6c3b3aa
forgot git
zucchini-nlp Jun 20, 2024
fd5eeab
add more flags
zucchini-nlp Jun 21, 2024
e233f29
run-slow git,codegen,gpt_neo,gpt_neox,gpj
zucchini-nlp Jun 21, 2024
356d578
add cache flag to VLMs
zucchini-nlp Jul 9, 2024
c906670
remove files
zucchini-nlp Jul 9, 2024
08d9e6f
Merge branch 'main' into dynamic_cache_decoder_only
zucchini-nlp Jul 9, 2024
56c05b2
style
zucchini-nlp Jul 9, 2024
8510810
video LLMs also need a flag
zucchini-nlp Jul 9, 2024
cebb55d
style
zucchini-nlp Jul 9, 2024
8fd9dd1
llava will go in another PR
zucchini-nlp Jul 26, 2024
4b9ced1
Merge branch 'main' into dynamic_cache_decoder_only
zucchini-nlp Jul 26, 2024
aea219b
style
zucchini-nlp Jul 26, 2024
4991863
[run-slow] codegen, falcon, git, gpt_neo, gpt_neox, gptj, idefics
zucchini-nlp Jul 26, 2024
ec306a2
Update src/transformers/models/gpt_neo/modeling_gpt_neo.py
zucchini-nlp Jul 30, 2024
cf793b7
copy from
zucchini-nlp Jul 30, 2024
c92409c
deprecate until v4.45 and warn if not training
zucchini-nlp Jul 30, 2024
c2b97e4
nit
zucchini-nlp Jul 30, 2024
35b60de
fix test
zucchini-nlp Jul 30, 2024
d2fca9a
test static cache
zucchini-nlp Aug 2, 2024
0933350
Merge branch 'main' into dynamic_cache_decoder_only
zucchini-nlp Aug 2, 2024
42349d4
add more tests and fix models
zucchini-nlp Aug 2, 2024
45c3a1b
fix copies
zucchini-nlp Aug 2, 2024
5f22616
return sliding window mask
zucchini-nlp Aug 2, 2024
f5af6a2
run slow tests & fix + codestyle
zucchini-nlp Aug 6, 2024
21b45c5
one more falcon fix for alibi
zucchini-nlp Aug 6, 2024
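The commits above migrate the listed decoder-only models (codegen, falcon, git, gpt_neo, gpt_neox, gptj, plus the VLMs that gained a flag) from the legacy tuple-of-(key, value) cache to the new Cache API. As a rough sketch of what usage looks like after the change (not code taken from this PR), assuming the `DynamicCache` class exported by `transformers` and an illustrative GPT-Neo checkpoint:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125m")  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125m")
inputs = tokenizer("The new cache format", return_tensors="pt")

# Pass a Cache object instead of the legacy tuple of (key, value) tensors
past_key_values = DynamicCache()
with torch.no_grad():
    out = model(**inputs, past_key_values=past_key_values, use_cache=True)

print(type(out.past_key_values))  # DynamicCache

# Legacy tuples can still be converted to and from the new format
legacy = out.past_key_values.to_legacy_cache()
restored = DynamicCache.from_legacy_cache(legacy)

The `test static cache` commit exercises the preallocated `StaticCache` variant of the same API.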
Files changed
16 changes: 16 additions & 0 deletions bart.py
@@ -0,0 +1,16 @@
from transformers import AutoTokenizer, BartForConditionalGeneration

model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn").to("cuda:0")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

ARTICLE_TO_SUMMARIZE = (
"PG&E stated it scheduled the blackouts in response to forecasts for high winds "
"amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
"scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
)
inputs = tokenizer(ARTICLE_TO_SUMMARIZE, return_tensors="pt").to("cuda:0")

# Generate a summary with greedy decoding (num_beams=1, do_sample=False) and caching disabled
summary_ids = model.generate(**inputs, num_beams=1, do_sample=False, max_new_tokens=30, use_cache=False)
out = tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
print(out)
318 changes: 225 additions & 93 deletions src/transformers/models/codegen/modeling_codegen.py


@@ -622,6 +622,7 @@ def _flash_attention_forward(
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.

zucchini-nlp (Member, Author) commented on Jun 14, 2024:

These kinds of changes come from fix copies and are not related to the PR at all. But let's leave it here, as it's related to code consistency in the library anyway.

Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
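For reference, `fix copies` refers to the repository's consistency tooling: functions marked with a `# Copied from ...` comment are rewritten by `make fix-copies` to match their source definition, which is how unrelated docstring tweaks like the blank line above can end up in a diff. A sketch of what a marked copy looks like (the function and source path are illustrative, borrowed from the Llama rotary helpers):

import torch

# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)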
369 changes: 236 additions & 133 deletions src/transformers/models/falcon/modeling_falcon.py

144 changes: 83 additions & 61 deletions src/transformers/models/git/modeling_git.py

318 changes: 224 additions & 94 deletions src/transformers/models/gpt_neo/modeling_gpt_neo.py


371 changes: 252 additions & 119 deletions src/transformers/models/gpt_neox/modeling_gpt_neox.py


357 changes: 247 additions & 110 deletions src/transformers/models/gptj/modeling_gptj.py


1 change: 1 addition & 0 deletions src/transformers/models/hubert/modeling_hubert.py
@@ -692,6 +692,7 @@ def _flash_attention_forward(
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.

Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
1 change: 1 addition & 0 deletions src/transformers/models/musicgen/modeling_musicgen.py
@@ -453,6 +453,7 @@ def _flash_attention_forward(
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.

Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
@@ -469,6 +469,7 @@ def _flash_attention_forward(
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.

Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
1 change: 1 addition & 0 deletions src/transformers/models/sew/modeling_sew.py
@@ -692,6 +692,7 @@ def _flash_attention_forward(
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.

Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
1 change: 1 addition & 0 deletions src/transformers/models/unispeech/modeling_unispeech.py
@@ -728,6 +728,7 @@ def _flash_attention_forward(
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.

Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
@@ -745,6 +745,7 @@ def _flash_attention_forward(
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.

Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
Empty file added update_llava.py
Empty file.