Cache: new Cache format in decoder-only models #31421

Merged

44 commits
183cd66
draft bart with new cache
zucchini-nlp Jun 14, 2024
4578bca
add cache for decoder-only models
zucchini-nlp Jun 14, 2024
9505ca4
revert utils
zucchini-nlp Jun 14, 2024
2ab28f3
modify docstring
zucchini-nlp Jun 14, 2024
5fe4e9e
revert bart
zucchini-nlp Jun 14, 2024
09413c3
minor fixes
zucchini-nlp Jun 14, 2024
3c27604
fix copies (not related)
zucchini-nlp Jun 14, 2024
350acc5
revert tests
zucchini-nlp Jun 14, 2024
c0adf10
remove enc-dec related code
zucchini-nlp Jun 17, 2024
c18b177
remove bloom
zucchini-nlp Jun 17, 2024
582f289
remove opt (enc-dec)
zucchini-nlp Jun 17, 2024
3141a71
Merge remote-tracking branch 'upstream/main' into dynamic_cache_decod…
zucchini-nlp Jun 17, 2024
33d54b4
update docstring
zucchini-nlp Jun 18, 2024
dd05e6b
git, codegen, gpt_neo, gpt_neox, gpj
zucchini-nlp Jun 18, 2024
cb878d5
clean up
zucchini-nlp Jun 19, 2024
0588791
copied from statements
zucchini-nlp Jun 19, 2024
a27b47c
revert
zucchini-nlp Jun 19, 2024
1abcf30
tmp
zucchini-nlp Jun 19, 2024
00ed88c
update warning msg
zucchini-nlp Jun 20, 2024
6c3b3aa
forgot git
zucchini-nlp Jun 20, 2024
fd5eeab
add more flags
zucchini-nlp Jun 21, 2024
e233f29
run-slow git,codegen,gpt_neo,gpt_neox,gpj
zucchini-nlp Jun 21, 2024
356d578
add cache flag to VLMs
zucchini-nlp Jul 9, 2024
c906670
remove files
zucchini-nlp Jul 9, 2024
08d9e6f
Merge branch 'main' into dynamic_cache_decoder_only
zucchini-nlp Jul 9, 2024
56c05b2
style
zucchini-nlp Jul 9, 2024
8510810
video LLMs also need a flag
zucchini-nlp Jul 9, 2024
cebb55d
style
zucchini-nlp Jul 9, 2024
8fd9dd1
llava will go in another PR
zucchini-nlp Jul 26, 2024
4b9ced1
Merge branch 'main' into dynamic_cache_decoder_only
zucchini-nlp Jul 26, 2024
aea219b
style
zucchini-nlp Jul 26, 2024
4991863
[run-slow] codegen, falcon, git, gpt_neo, gpt_neox, gptj, idefics
zucchini-nlp Jul 26, 2024
ec306a2
Update src/transformers/models/gpt_neo/modeling_gpt_neo.py
zucchini-nlp Jul 30, 2024
cf793b7
copy from
zucchini-nlp Jul 30, 2024
c92409c
deprecate until v4.45 and warn if not training
zucchini-nlp Jul 30, 2024
c2b97e4
nit
zucchini-nlp Jul 30, 2024
35b60de
fix test
zucchini-nlp Jul 30, 2024
d2fca9a
test static cache
zucchini-nlp Aug 2, 2024
0933350
Merge branch 'main' into dynamic_cache_decoder_only
zucchini-nlp Aug 2, 2024
42349d4
add more tests and fix models
zucchini-nlp Aug 2, 2024
45c3a1b
fix copies
zucchini-nlp Aug 2, 2024
5f22616
return sliding window mask
zucchini-nlp Aug 2, 2024
f5af6a2
run slow tests & fix + codestyle
zucchini-nlp Aug 6, 2024
21b45c5
one more falcon fix for alibi
zucchini-nlp Aug 6, 2024
40 changes: 15 additions & 25 deletions src/transformers/generation/candidate_generator.py
@@ -367,29 +367,20 @@ def _crop_past_key_values(model, past_key_values, maximum_length):
"""Crops the past key values up to a certain maximum length."""
new_past = []
if model.config.is_encoder_decoder:
for idx in range(len(past_key_values)):
new_past.append(
(
past_key_values[idx][0][:, :, :maximum_length, :],
past_key_values[idx][1][:, :, :maximum_length, :],
past_key_values[idx][2],
past_key_values[idx][3],
)
)
past_key_values = tuple(new_past)
# bloom is special
elif "bloom" in model.__class__.__name__.lower() or (
model.config.architectures is not None and "bloom" in model.config.architectures[0].lower()
):
for idx in range(len(past_key_values)):
new_past.append(
(
past_key_values[idx][0][:, :, :maximum_length],
past_key_values[idx][1][:, :maximum_length, :],
if isinstance(past_key_values[0], DynamicCache):
past_key_values[0].crop(maximum_length)
else:
for idx in range(len(past_key_values)):
new_past.append(
(
past_key_values[idx][0][:, :, :maximum_length, :],
past_key_values[idx][1][:, :, :maximum_length, :],
past_key_values[idx][2],
past_key_values[idx][3],
)
)
)
past_key_values = tuple(new_past)
# gptbigcode is too
past_key_values = tuple(new_past)
# gptbigcode is special
elif "gptbigcode" in model.__class__.__name__.lower() or (
model.config.architectures is not None and "gptbigcode" in model.config.architectures[0].lower()
):
@@ -401,13 +392,12 @@ def _crop_past_key_values(model, past_key_values, maximum_length):
past_key_values[idx] = past_key_values[idx][:, :, :maximum_length, :]
elif isinstance(past_key_values, DynamicCache):
past_key_values.crop(maximum_length)

elif past_key_values is not None:
for idx in range(len(past_key_values)):
new_past.append(
(
past_key_values[idx][0][:, :, :maximum_length, :],
past_key_values[idx][1][:, :, :maximum_length, :],
past_key_values[idx][0][..., :maximum_length, :],
past_key_values[idx][1][..., :maximum_length, :],
)
)
past_key_values = tuple(new_past)
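For context, here is a minimal standalone sketch (not taken from this diff) of the DynamicCache methods the new crop path relies on; the tensor shapes are illustrative only.

```python
import torch
from transformers import DynamicCache

cache = DynamicCache()

# one layer's key/value states: [batch, num_heads, seq_len, head_dim]
key = torch.randn(1, 4, 10, 8)
value = torch.randn(1, 4, 10, 8)
cache.update(key, value, layer_idx=0)

print(cache.get_seq_length())  # 10
cache.crop(6)                  # in place: keep only the first 6 cached positions
print(cache.get_seq_length())  # 6
```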
178 changes: 108 additions & 70 deletions src/transformers/models/bloom/modeling_bloom.py

Large diffs are not rendered by default.

91 changes: 57 additions & 34 deletions src/transformers/models/codegen/modeling_codegen.py
@@ -22,6 +22,7 @@
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
@@ -57,7 +58,7 @@ def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Ten


class CodeGenAttention(nn.Module):
def __init__(self, config):
def __init__(self, config, layer_idx=None):
super().__init__()

max_positions = config.max_position_embeddings
@@ -71,6 +72,13 @@ def __init__(self, config):

self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)

self.embed_dim = config.hidden_size
self.num_attention_heads = config.num_attention_heads
@@ -150,7 +158,7 @@ def _attn(
def forward(
self,
hidden_states: Optional[torch.FloatTensor],
layer_past: Optional[Tuple[torch.Tensor]] = None,
layer_past: Optional[Cache] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
@@ -200,18 +208,11 @@ def forward(
key = key.permute(0, 2, 1, 3)
query = query.permute(0, 2, 1, 3)

# Note that this cast is quite ugly, but is not implemented before ROPE as k_rot in the original codebase is always in fp32.
# Reference: https://github.com/salesforce/CodeGen/blob/f210c3bb1216c975ad858cd4132c0fdeabf4bfc2/codegen1/jaxformer/hf/codegen/modeling_codegen.py#L38
if layer_past is not None:
past_key = layer_past[0]
past_value = layer_past[1]
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)

if use_cache is True:
# Note that this cast is quite ugly, but is not implemented before ROPE as k_rot in the original codebase is always in fp32.
# Reference: https://github.com/salesforce/CodeGen/blob/f210c3bb1216c975ad858cd4132c0fdeabf4bfc2/codegen1/jaxformer/hf/codegen/modeling_codegen.py#L38
present = (key.to(hidden_states.dtype), value)
else:
present = None
cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_dim}
key, value = layer_past.update(key.to(hidden_states.dtype), value, self.layer_idx, cache_kwargs)

# compute self-attention: V x Softmax(QK^T)
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
@@ -220,7 +221,7 @@
attn_output = self.out_proj(attn_output)
attn_output = self.resid_dropout(attn_output)

outputs = (attn_output, present)
outputs = (attn_output, layer_past)
if output_attentions:
outputs += (attn_weights,)
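As a rough illustration of the layer_past.update pattern above (an editorial sketch, not code from this PR): each attention layer pushes its new key/value states into the shared cache under its layer_idx and gets back the full concatenated states. DynamicCache ignores the extra cache_kwargs such as sin, cos and partial_rotation_size; those are consumed by cache classes like SinkCache that need to re-apply rotary embeddings.

```python
import torch
from transformers import DynamicCache

cache = DynamicCache()
layer_idx = 0

# prefill: key/value states for 5 tokens ([batch, heads, seq, head_dim])
k, v = cache.update(torch.randn(1, 4, 5, 8), torch.randn(1, 4, 5, 8), layer_idx)

# one decoding step: a single new token; update() returns the full history
k, v = cache.update(torch.randn(1, 4, 1, 8), torch.randn(1, 4, 1, 8), layer_idx)
print(k.shape)  # torch.Size([1, 4, 6, 8])
```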

@@ -250,17 +251,17 @@ def forward(self, hidden_states: Optional[torch.FloatTensor]) -> torch.FloatTens
# Copied from transformers.models.gptj.modeling_gptj.GPTJBlock with GPTJ->CodeGen
class CodeGenBlock(nn.Module):
# Ignore copy
def __init__(self, config):
def __init__(self, config, layer_idx=None):
super().__init__()
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.attn = CodeGenAttention(config)
self.attn = CodeGenAttention(config, layer_idx)
self.mlp = CodeGenMLP(inner_dim, config)

def forward(
self,
hidden_states: Optional[torch.FloatTensor],
layer_past: Optional[Tuple[torch.Tensor]] = None,
layer_past: Optional[Cache] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
@@ -303,6 +304,7 @@ class CodeGenPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["CodeGenBlock"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
Member:

Don't forget

_supports_quantized_cache = True
_supports_static_cache = True

(if appropriate, on this and other models)

zucchini-nlp (Member Author):

Oh yes, I forgot about the _supports_quantized_cache flag. But I am not sure about the _supports_static_cache flag: will it imply that models are fullgraph-compilable? We didn't test that.

In any case, I'll make a follow-up PR to check fullgraph compilation and, if it works as-is, add the tests to each modeling file.

gante (Member), Jun 20, 2024:

👍

We should add a mixin test for fullgraph compilation when _supports_static_cache = True.

Moreover, after this PR, I think _supports_static_cache and _supports_cache_class will mean the same thing; that is another property to check and rectify (i.e. remove _supports_static_cache) in a follow-up PR.

zucchini-nlp (Member Author), Jun 21, 2024:

Yes, I will test and add a mixin test (one lightweight and one slow, maybe) after this PR is merged. For now I added the flags and tested by running the generation tests.

> I think _supports_static_cache and _supports_cache_class will mean the same thing

By the way, GIT will be an exception: it supports the cache class but not the static cache, since it has some special attention-mask preparation steps.

Collaborator:

BTW, we could also just add the class to a set like _support_quantized_cache = {""} in cache_utils, so we don't pollute the modeling code here and can directly get all classes that support quantized/static caches. That is better for auto-building the docs, and better in general to separate cache concerns from model-specific code.
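To make the discussion concrete, here is a rough, hypothetical sketch (not code from this PR) of how such support flags are declared on a model class and how downstream code might gate on them; MyModelPreTrainedModel and pick_cache_kind are made-up names.

```python
from transformers import PreTrainedModel


class MyModelPreTrainedModel(PreTrainedModel):  # hypothetical model, for illustration only
    _supports_cache_class = True      # forward() accepts Cache objects such as DynamicCache
    _supports_quantized_cache = True  # generation may build a quantized KV cache
    _supports_static_cache = True     # generation may build a pre-allocated StaticCache


def pick_cache_kind(model) -> str:
    """Hypothetical helper: choose a cache flavour based on the flags above."""
    if getattr(model, "_supports_static_cache", False):
        return "static"
    if getattr(model, "_supports_cache_class", False):
        return "dynamic"
    return "legacy tuple of tuples"
```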


def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
@@ -374,6 +376,10 @@
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
model's internal embedding lookup matrix.
past_key_values (`Cache` or `Tuple[Tuple[torch.Tensor]]` of length `config.num_layers`):
Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
`past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
their past given to this model should not be passed as `input_ids` as they have already been computed.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
@@ -397,7 +403,7 @@ def __init__(self, config):
self.vocab_size = config.vocab_size
self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
self.drop = nn.Dropout(config.embd_pdrop)
self.h = nn.ModuleList([CodeGenBlock(config) for _ in range(config.n_layer)])
self.h = nn.ModuleList([CodeGenBlock(config, layer_idx=i) for i in range(config.n_layer)])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.rotary_dim = min(config.rotary_dim, config.n_ctx // config.num_attention_heads)

@@ -421,7 +427,7 @@ def set_input_embeddings(self, new_embeddings):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor]]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
@@ -457,11 +463,12 @@ def forward(
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])

if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)
past_length = 0
if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_length = past_key_values.get_seq_length()
Collaborator:

This is super ugly, IMO. I would really love for us not to add this code, which ended up being copy-pasted everywhere because we were lazy to update it when porting new models.

Let's use cache positions. And let's already deprecate the casting and legacy cache etc., to make sure we only do this for one release!

zucchini-nlp (Member Author):

I see. I wanted to add cache_position as part of StaticCache support, but it can be added in this PR as well. I'm wondering about the deprecation cycle for the Cache class: AFAIK there was no deprecation warning about it before, so we would still have to keep the ugly hack and add a warning message?

I'm in favor of getting rid of the old cache entirely, if it doesn't break BC.

zucchini-nlp (Member Author):

Ah, I see that Llama on main already had its input preparation modified under the assumption that the input is always a Cache object. OK, then it makes sense to get rid of the legacy cache in forward as well.

Collaborator:

Yeah, it was not done before, but this PR is IMO a good time to do it! So keep this code, but add a warning saying that we are automatically converting the cache from a tuple to the DynamicCache class (it should only be triggered outside generate, because generate should already pass a cache class!).

Yep, it makes sense, but let's not be too brutal in case some people still use it! We give them one release until we totally remove it.

zucchini-nlp (Member Author):

OK, added cache positions and deprecation warnings to all models in this PR. I'll add the same deprecation warning to all models that already support the cache class in another PR. This one is ready for review; tests are passing on my end!
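A minimal sketch of the deprecation pattern agreed on above (a hypothetical helper, not the exact code added by the PR): accept the legacy tuple-of-tuples cache, convert it to DynamicCache, and warn so the conversion can be dropped after one release.

```python
import warnings

from transformers import DynamicCache


def normalize_cache(past_key_values, use_cache: bool):
    """Hypothetical helper: convert a legacy cache to DynamicCache and report the past length."""
    past_length = 0
    use_legacy_cache = False
    if use_cache:
        if past_key_values is not None and not hasattr(past_key_values, "get_seq_length"):
            # legacy tuple-of-tuples format
            use_legacy_cache = True
            warnings.warn(
                "Passing a tuple of tuples as `past_key_values` is deprecated; "
                "pass a `Cache` instance (e.g. `DynamicCache`) instead.",
                FutureWarning,
            )
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
        elif past_key_values is None:
            past_key_values = DynamicCache()
        past_length = past_key_values.get_seq_length()
    return past_key_values, past_length, use_legacy_cache
```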


if position_ids is None:
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
@@ -514,10 +521,10 @@
)
use_cache = False

presents = () if use_cache else None
next_decoder_cache = None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
for i, block in enumerate(self.h):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

@@ -535,7 +542,7 @@
else:
outputs = block(
hidden_states=hidden_states,
layer_past=layer_past,
layer_past=past_key_values,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask[i],
Expand All @@ -545,7 +552,7 @@ def forward(

hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
next_decoder_cache = outputs[1]

if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
@@ -557,12 +564,18 @@
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache

if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
return tuple(
v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None
)

return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
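A small illustration (not from the diff) of the legacy round trip used above: a DynamicCache can be exported back to the old tuple-of-tuples format and rebuilt from it.

```python
import torch
from transformers import DynamicCache

cache = DynamicCache()
cache.update(torch.randn(1, 4, 3, 8), torch.randn(1, 4, 3, 8), 0)

legacy = cache.to_legacy_cache()        # tuple of (key, value) tuples, one per layer
print(len(legacy), legacy[0][0].shape)  # 1 torch.Size([1, 4, 3, 8])

rebuilt = DynamicCache.from_legacy_cache(legacy)
print(rebuilt.get_seq_length())         # 3
```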
@@ -593,9 +606,12 @@ def set_output_embeddings(self, new_embeddings):

def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, past_key_values=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
attention_mask = kwargs.get("attention_mask", None)
past_length = 0
# Omit tokens covered by past_key_values
if past_key_values:
past_length = past_key_values[0][0].shape[2]
past_length = cache_length = past_key_values.get_seq_length()
max_cache_length = past_key_values.get_max_length()

# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
@@ -608,7 +624,14 @@
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -input_ids.shape[1] :]

attention_mask = kwargs.get("attention_mask", None)
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
if (
max_cache_length is not None
and attention_mask is not None
and cache_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask[:, -max_cache_length:]

position_ids = kwargs.get("position_ids", None)

if attention_mask is not None and position_ids is None:
@@ -619,7 +642,7 @@
position_ids = position_ids[:, -input_ids.shape[1] :]

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
if inputs_embeds is not None and past_length == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids.contiguous()}
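A hedged sketch of the trimming logic in prepare_inputs_for_generation above, written as a standalone function (trim_inputs is a made-up name) so it can be read in isolation; for DynamicCache, get_max_length() returns None, so the attention-mask cropping only applies to bounded cache types.

```python
def trim_inputs(input_ids, attention_mask, past_key_values):
    """Hypothetical standalone version of the input/attention-mask trimming above."""
    past_length = cache_length = past_key_values.get_seq_length()
    max_cache_length = past_key_values.get_max_length()

    # drop tokens whose key/value states are already in the cache
    if input_ids.shape[1] > past_length:
        input_ids = input_ids[:, past_length:]
    else:
        # some generation methods already pass only the last input id
        input_ids = input_ids[:, -1:]

    # crop the attention mask if this step would exceed a bounded cache
    if (
        max_cache_length is not None
        and attention_mask is not None
        and cache_length + input_ids.shape[1] > max_cache_length
    ):
        attention_mask = attention_mask[:, -max_cache_length:]
    return input_ids, attention_mask
```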
@@ -644,7 +667,7 @@
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor]]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
@@ -622,6 +622,7 @@ def _flash_attention_forward(
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.

zucchini-nlp (Member Author), Jun 14, 2024:

These kinds of changes come from fix-copies and are not related to this PR at all, but let's leave them here since they are anyway about code consistency in the library.

Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
Expand Down