GPT2Model StaticCache support #35761
base: main
Conversation
Force-pushed from 278bcf7 to dedb154.
Thanks for the PR!
Not entirely sure it's worth adding, as GPT2 is a very small model, no longer particularly optimized, and fairly old, so the amount of work is a bit high...
Let's make sure we test the cross-attention path with a kv cache, as I am not even sure it was supported before.
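An illustrative check of that path (not the PR's actual test): run a tiny GPT2 with add_cross_attention=True, decode the last token from the cache returned by a prefill pass, and compare with the full uncached forward. The APIs used are standard transformers/PyTorch; the config sizes are arbitrary.

import torch
from transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config(n_layer=2, n_head=2, n_embd=32, add_cross_attention=True)
model = GPT2LMHeadModel(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 5))
encoder_hidden_states = torch.randn(1, 7, config.n_embd)

with torch.no_grad():
    # full forward without cache vs. prefill on 4 tokens + cached decode of the 5th
    full = model(input_ids, encoder_hidden_states=encoder_hidden_states).logits
    prefill = model(input_ids[:, :4], encoder_hidden_states=encoder_hidden_states, use_cache=True)
    step = model(
        input_ids[:, 4:],
        past_key_values=prefill.past_key_values,
        encoder_hidden_states=encoder_hidden_states,
        use_cache=True,
    )

torch.testing.assert_close(step.logits[:, -1], full[:, -1], atol=1e-4, rtol=1e-4)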
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
The big issue with this is that we are breaking backward compatibility for people who use layer_past. We need to deprecate layer_past!
Added the @deprecate_kwarg decorator to this forward and 3 more (including in GPT2Model). Note that this is an inner class for attention / the inner block, so it does not affect the external model interface.
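For reference, a minimal sketch of how that decorator works, assuming it lives at transformers.utils.deprecation as in recent releases; the class name and target version below are made up for the example and are not the PR's exact values.

from typing import Optional

import torch
from transformers.cache_utils import Cache
from transformers.utils.deprecation import deprecate_kwarg


class ToyAttention(torch.nn.Module):
    # callers still passing layer_past= get the value remapped to past_key_value
    # with a deprecation warning, so old call sites keep working for now
    @deprecate_kwarg("layer_past", new_name="past_key_value", version="4.48")
    def forward(self, hidden_states, past_key_value: Optional[Cache] = None):
        return hidden_states, past_key_value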
# based on pattern from src/transformers/models/whisper/modeling_whisper.py::WhisperDecoder
return_legacy_cache = False
if use_cache:
    if past_key_values is not None:
        if isinstance(past_key_values, Cache):
            if self.config.add_cross_attention and not isinstance(past_key_values, EncoderDecoderCache):
                past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
        elif not isinstance(past_key_values, Cache):
            return_legacy_cache = True
            logger.warning_once(
                "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.49.0. "
                "You should pass an instance of `Cache` instead, e.g. "
                "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
            )
            if self.config.add_cross_attention:
                past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
            else:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
    elif past_key_values is None:
        return_legacy_cache = True
        logger.warning_once(
            "Passing `use_cache=True` and `past_key_values=None` will produce cache output in the legacy format. "
            "This behavior is deprecated and will be changed in Transformers v4.49.0. "
            "To obtain the output past_key_values as a `Cache` instance, pass an instance of `Cache` instead, e.g. "
            "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
        )
        if self.config.add_cross_attention:
            past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
        else:
            past_key_values = DynamicCache()
From the look of it, we are adding quite complex code here, which I am not a big fan of.
Let's go with this for now, but it would be nice to have a single warning simply saying that one or the other is deprecated. As it stands this is not very readable and there are too many code paths, when you should have (see the sketch below):
past_key_values is None -> create a DynamicCache
past_key_values is not None -> convert to a DynamicCache (not even sure the cross-attention cache was even supported)
add_cross_attention -> create an EncoderDecoderCache with past_key_values and a new DynamicCache
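A minimal sketch of that simplified flow; the names follow the snippet earlier in the thread, and the warning text is illustrative rather than the exact wording to ship.

return_legacy_cache = False
if use_cache:
    if past_key_values is None:
        past_key_values = DynamicCache()
    elif not isinstance(past_key_values, Cache):
        return_legacy_cache = True
        logger.warning_once(
            "Passing a tuple of `past_key_values` is deprecated; pass a `Cache` instance instead, "
            "e.g. `past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
        )
        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
    if self.config.add_cross_attention and not isinstance(past_key_values, EncoderDecoderCache):
        # pair the self-attention cache with a fresh cross-attention cache
        past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())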
I simplified this logic according to the outline above.
I agree in principle, but my friends use it for Tortoise text-to-speech and intend to compile it (with modifications for static shapes) to accelerate inference.
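For context, a sketch of the kind of usage this PR is meant to enable (not code from the PR; GPT2 accepting cache_implementation="static" is exactly what the PR adds):

from transformers import AutoTokenizer, GPT2LMHeadModel

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()
inputs = tokenizer("Hello, my name is", return_tensors="pt")

# "static" asks generate() to allocate a fixed-length StaticCache instead of a
# growing DynamicCache, which is what makes static-shape compilation possible
out = model.generate(
    **inputs,
    max_new_tokens=20,
    do_sample=False,
    cache_implementation="static",
)
print(tokenizer.decode(out[0], skip_special_tokens=True))

# once the shapes are static, the per-step forward can additionally be compiled, e.g.
# model.forward = torch.compile(model.forward, mode="reduce-overhead")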
I made an effort to patch the cross-attention parts of the code as well. The relevant tests seem to pass.
@Rocketknight1 @ArthurZucker, could you please approve the remaining check workflows?
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
I copied _update_causal_mask() and _prepare_4d_causal_attention_mask_with_cache_position() from LlamaModel.
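Roughly, the copied helper builds the causal mask from absolute cache positions rather than from the current sequence length, which is what lets it work with a fixed-size StaticCache. A simplified standalone sketch (not the exact Llama code, which also folds in padding masks):

import torch

def causal_mask_with_cache_position(query_len, kv_len, cache_position, dtype=torch.float32):
    # (1, 1, query_len, kv_len) additive mask: kv slot j is visible to query i
    # iff j <= cache_position[i], so a fixed kv_len (StaticCache) works too
    mask = torch.full((query_len, kv_len), torch.finfo(dtype).min, dtype=dtype)
    visible = torch.arange(kv_len) <= cache_position.reshape(-1, 1)
    return mask.masked_fill(visible, 0.0)[None, None, :, :]

# decode step: one new token at absolute position 5 against a static cache of length 8
print(causal_mask_with_cache_position(1, 8, torch.tensor([5])))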
Some tests are still failing; both failures may be linked to attention implementations. So far I have been unable to figure out the reasons for the failures. I'd appreciate advice or help from the maintainers.
cc: @gante