GPT2Model StaticCache support #35761

Open · wants to merge 20 commits into base: main from the gpt_static branch
Conversation

@poedator (Contributor) commented Jan 18, 2025

I copied _update_causal_mask() and _prepare_4d_causal_attention_mask_with_cache_position() from LlamaModel
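For readers unfamiliar with that helper, here is a minimal, hedged sketch of the idea behind a Llama-style `_prepare_4d_causal_attention_mask_with_cache_position` (an illustration, not the transformers source; the function name `make_4d_causal_mask` and its exact arguments are made up for the example): build a (batch, 1, query_len, kv_len) additive mask where keys after each query's absolute position (taken from `cache_position`) get the dtype minimum, then fold in the 2D padding mask.

import torch


def make_4d_causal_mask(attention_mask, query_len, kv_len, cache_position, dtype):
    # Start from a fully-masked (query_len, kv_len) grid of the dtype minimum,
    # then zero out (i.e. allow) every key position <= the query's absolute
    # position taken from cache_position.
    min_value = torch.finfo(dtype).min
    mask = torch.full((query_len, kv_len), min_value, dtype=dtype)
    key_positions = torch.arange(kv_len)
    mask = mask * (key_positions[None, :] > cache_position[:, None])
    # Broadcast to (batch, 1, query_len, kv_len) and fold in the padding mask.
    mask = mask[None, None, :, :].expand(attention_mask.shape[0], 1, -1, -1)
    padding = attention_mask[:, None, None, :kv_len].eq(0)
    return mask.masked_fill(padding, min_value)


# Example: two new tokens appended after three cached ones, static KV length 8.
mask = make_4d_causal_mask(torch.ones(1, 8), 2, 8, torch.tensor([3, 4]), torch.float32)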

some tests are still failing:

  1. tests/models/gpt2/test_modeling_gpt2.py::GPT2ModelTest::test_custom_4d_attention_mask
  2. test_modeling_vision_encoder_decoder.py::VIT2GPT2Test::test_save_and_load_from_pretrained

Both may be linked to attention implementations. So far I was unable to figure out the reasons for the failures. I'd appreciate advice or help from the maintainers.

cc: @gante

@poedator poedator mentioned this pull request Jan 18, 2025
@poedator poedator force-pushed the gpt_static branch 2 times, most recently from 278bcf7 to dedb154 Compare January 18, 2025 13:30
@poedator poedator marked this pull request as ready for review January 18, 2025 17:43
@ArthurZucker (Collaborator) left a comment

Thanks for the PR!
Not entirely sure it's worth adding, as GPT2 is a super small model, not super optimized anymore, and fairly old, so the amount of work is a bit high...

Let's make sure we test the cross-attention path with KV cache, as I am not even sure it was supported before.

Comment on lines +265 to +266
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
Collaborator

The big issue with this is that we are breaking backward compatibility for people who use layer_past. We need to deprecate layer_past!

@poedator (Contributor, Author) Jan 25, 2025

Added the @deprecate_kwarg decorator to this forward and three more (including in GPT2Model). Note that this is an inner attention class or inner block, so the external model interface is not affected.
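For reference, a self-contained sketch of the deprecation pattern described above (illustrative only, not the transformers `deprecate_kwarg` helper itself; the class name `GPT2AttentionLike` is made up): accept the old `layer_past` kwarg, warn, and forward it under the new `past_key_value` name.

import functools
import warnings


def deprecate_kwarg(old_name, new_name):
    # Minimal stand-in for the library helper: remap old_name -> new_name
    # with a FutureWarning so callers have time to migrate.
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if old_name in kwargs:
                warnings.warn(
                    f"`{old_name}` is deprecated, use `{new_name}` instead.",
                    FutureWarning,
                )
                kwargs.setdefault(new_name, kwargs.pop(old_name))
            return func(*args, **kwargs)

        return wrapper

    return decorator


class GPT2AttentionLike:
    @deprecate_kwarg("layer_past", "past_key_value")
    def forward(self, hidden_states, past_key_value=None, cache_position=None):
        return hidden_states  # placeholder body for illustration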

Comment on lines 830 to 859
# based on pattern from src/transformers/models/whisper/modeling_whisper.py::WhisperDecoder
return_legacy_cache = False
if use_cache:
    if past_key_values is not None:
        if isinstance(past_key_values, Cache):
            if self.config.add_cross_attention and not isinstance(past_key_values, EncoderDecoderCache):
                past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
        elif not isinstance(past_key_values, Cache):
            return_legacy_cache = True
            logger.warning_once(
                "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.49.0. "
                "You should pass an instance of `Cache` instead, e.g. "
                "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
            )
            if self.config.add_cross_attention:
                past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
            else:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
    elif past_key_values is None:
        return_legacy_cache = True
        logger.warning_once(
            "Passing `use_cache=True` and `past_key_values=None` will produce cache output in the legacy format. "
            "This behavior is deprecated and will be changed in Transformers v4.49.0. "
            "To obtain the output past_key_values as a `Cache` instance, pass an instance of `Cache` instead, e.g. "
            "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
        )
        if self.config.add_cross_attention:
            past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
        else:
            past_key_values = DynamicCache()
Collaborator

From the look of it, we are adding quite complex code, which I am not a super fan of.
Let's go with this for now, but it would be nice to have a single warning that just says the one or the other is deprecated. As is, this is not super readable and you have too many code paths, when you should have (see the sketch after this list):

  1. past_key_values is None -> create a DynamicCache
  2. past_key_values is not None -> convert to a DynamicCache (not even sure that a cross-attention cache was even supported)
  3. add_cross_attention -> create an EncoderDecoderCache with past_key_values and a new DynamicCache
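A hedged sketch of the flow outlined above (illustrative only; `normalize_cache` is a made-up helper name, and it glosses over how legacy cross-attention tuples would be converted):

from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache


def normalize_cache(past_key_values, add_cross_attention):
    return_legacy_cache = False
    if past_key_values is None:
        past_key_values = DynamicCache()
    elif not isinstance(past_key_values, Cache):
        # Legacy tuple format: convert now, convert back on output.
        return_legacy_cache = True
        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
    if add_cross_attention and not isinstance(past_key_values, EncoderDecoderCache):
        past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
    return past_key_values, return_legacy_cache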

@poedator (Contributor, Author)

I simplified this logic according to your outline.

poedator commented Jan 25, 2025

Not entirely sure it's worth adding ...

I agree in principle, but my friends use it for Tortoise text-to-speech and intend to compile it (with modifications for static shapes) to accelerate it.
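For context, a minimal usage sketch of the static-shape scenario being described (assumptions: the standard `generate()` API with `cache_implementation="static"` and a compiled forward; exact kwargs can vary across transformers versions):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

# Compile the forward pass; a StaticCache keeps KV shapes fixed so the
# compiled graph is not recompiled at every decoding step.
model.forward = torch.compile(model.forward, mode="reduce-overhead")

inputs = tokenizer("Static shapes help torch.compile", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=16, cache_implementation="static")
print(tokenizer.decode(out[0], skip_special_tokens=True))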

Let's make sure we test the cross-attention path with KV cache, as I am not even sure it was supported before

I made an effort to patch the cross-attention parts of the code as well. The relevant tests seem to pass.

poedator commented Jan 28, 2025

@Rocketknight1 @ArthurZucker, could you please approve the remaining check workflows?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@poedator poedator changed the title [WiP] GPT2Model StaticCache support GPT2Model StaticCache support Jan 30, 2025