
Whisper: fix static cache CI (#35852)
* fix

* remove overridden method

* small change
zucchini-nlp authored Jan 30, 2025
1 parent 9725e5b commit 365fecb
Showing 4 changed files with 15 additions and 93 deletions.
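
For context, the CI covered by this commit exercises Whisper generation with a static KV cache. Below is a minimal sketch of that flow; it is not part of the diff, and the model id, dataset, and audio preprocessing are illustrative assumptions.

import torch
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# Short-form input features from a dummy ASR dataset (assumed example data)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
input_features = processor(
    ds[0]["audio"]["array"], sampling_rate=16000, return_tensors="pt"
).input_features

# Eager baseline
eager_generated_ids = model.generate(input_features, max_new_tokens=64)

# Static cache: generate() compiles the decoder forward for each decoding step internally
model.generation_config.cache_implementation = "static"
static_generated_ids = model.generate(input_features, max_new_tokens=64)

# The CI asserts that both runs produce the same tokens
assert torch.equal(eager_generated_ids, static_generated_ids)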
17 changes: 11 additions & 6 deletions src/transformers/generation/utils.py
@@ -406,23 +406,28 @@ def prepare_inputs_for_generation(
         model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)

         # 4. Create missing `position_ids` on the fly
+        attention_mask = (
+            kwargs.pop("decoder_attention_mask", None) if self.config.is_encoder_decoder else attention_mask
+        )
+        attention_mask_key = "decoder_attention_mask" if self.config.is_encoder_decoder else "attention_mask"
+        position_ids_key = "decoder_position_ids" if self.config.is_encoder_decoder else "position_ids"
         if (
             attention_mask is not None
-            and kwargs.get("position_ids") is None
-            and "position_ids" in set(inspect.signature(self.forward).parameters.keys())
+            and kwargs.get(position_ids_key) is None
+            and position_ids_key in set(inspect.signature(self.forward).parameters.keys())
         ):
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
-            kwargs["position_ids"] = position_ids  # placed in kwargs for further processing (see below)
+            kwargs[position_ids_key] = position_ids  # placed in kwargs for further processing (see below)

         # 5. Slice model inputs if it's an input that should have the same length as `input_ids`
-        for model_input_name in ["position_ids", "token_type_ids"]:
+        for model_input_name in ["position_ids", "token_type_ids", "decoder_position_ids"]:
             model_input = kwargs.get(model_input_name)
             if model_input is not None:
                 if past_key_values is not None:
                     current_input_length = (
                         model_inputs["inputs_embeds"].shape[1]
-                        if model_inputs["inputs_embeds"] is not None
+                        if model_inputs.get("inputs_embeds") is not None
                         else model_inputs[input_ids_key].shape[1]
                     )
                     model_input = model_input[:, -current_input_length:]
@@ -469,7 +474,7 @@ def prepare_inputs_for_generation(
                 past_key_values=past_key_values,
             )
         if attention_mask is not None:
-            model_inputs["attention_mask"] = attention_mask
+            model_inputs[attention_mask_key] = attention_mask

         # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
         for key, value in kwargs.items():
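
The hunk above derives position ids from the (decoder) attention mask with a cumulative sum. A standalone illustration of that recipe, using made-up mask values, is:

import torch

# One left-padded row and one full-length row (assumed example data)
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

position_ids = attention_mask.long().cumsum(-1) - 1  # running token index per row
position_ids.masked_fill_(attention_mask == 0, 1)    # padded slots get a dummy value of 1

print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])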
2 changes: 1 addition & 1 deletion src/transformers/models/whisper/generation_whisper.py
@@ -1234,7 +1234,7 @@ def _expand_variables_for_generation(
     def _setup_no_speech_detection(logits_processor, segment_input, decoder_input_ids, kwargs):
         set_inputs = _get_attr_from_logit_processors(logits_processor, WhisperNoSpeechDetection, "set_inputs")
         extra_kwargs = {k: v for k, v in kwargs.items() if torch.is_tensor(v)}
-        set_inputs({"inputs": segment_input, "decoder_input_ids": decoder_input_ids, **extra_kwargs})
+        set_inputs({"inputs": segment_input, "input_ids": decoder_input_ids, **extra_kwargs})

     @staticmethod
     def _retrieve_total_input_frames(input_features, input_stride, kwargs):
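
The renamed key above feeds WhisperNoSpeechDetection during long-form generation. As a hedged usage sketch (reusing the model from the earlier example, with illustrative threshold values), no-speech detection is enabled through generate kwargs:

# Long-form transcription with no-speech detection; `long_input_features` is a
# hypothetical tensor of features for audio longer than 30 s
generated_ids = model.generate(
    long_input_features,
    return_timestamps=True,
    temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),  # temperature fallback
    logprob_threshold=-1.0,
    no_speech_threshold=0.6,                     # segments judged as silence are skipped
)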
84 changes: 1 addition & 83 deletions src/transformers/models/whisper/modeling_whisper.py
@@ -1255,7 +1255,7 @@ def forward(
         )

         if position_ids is None:
-            position_ids = cache_position.unsqueeze(0)
+            position_ids = cache_position.unsqueeze(0).repeat(input_shape[0], 1)

         # embed positions
         if input_ids is not None:
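
The one-line change above repeats cache_position across the batch dimension so every sequence gets its own row of position ids. A tiny illustration with assumed shapes:

import torch

cache_position = torch.arange(3, 6)  # current decoding positions, e.g. steps 3, 4, 5
batch_size = 2

position_ids = cache_position.unsqueeze(0).repeat(batch_size, 1)
print(position_ids)
# tensor([[3, 4, 5],
#         [3, 4, 5]])  shape: (batch_size, seq_len)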
@@ -1806,88 +1806,6 @@ def forward(
             encoder_attentions=outputs.encoder_attentions,
         )

-    def prepare_inputs_for_generation(
-        self,
-        decoder_input_ids,
-        past_key_values=None,
-        use_cache=None,
-        encoder_outputs=None,
-        attention_mask=None,
-        decoder_attention_mask=None,
-        cache_position=None,
-        **kwargs,
-    ):
-        # Overwritten -- encoder-decoder whisper has custom logic, but it's close to the general function. Next time
-        # this function needs to be touched, let's try to sort out the commonalities between the two and remove the
-        # overwrite.
-
-        decoder_position_ids = None
-        if decoder_attention_mask is not None:
-            decoder_position_ids = (decoder_attention_mask.cumsum(-1) - 1).clamp(min=0)
-
-        past_length = 0
-        if past_key_values is not None:
-            if isinstance(past_key_values, EncoderDecoderCache):
-                past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
-            else:
-                past_length = past_key_values[0][0].shape[2]
-
-            # Some generation methods already pass only the last input ID
-            if decoder_input_ids.shape[1] > past_length:
-                remove_prefix_length = past_length
-            else:
-                # Default to old behavior: keep only final ID
-                remove_prefix_length = decoder_input_ids.shape[1] - 1
-
-            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
-
-            if decoder_position_ids is not None:
-                decoder_position_ids = decoder_position_ids[:, remove_prefix_length:]
-                # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
-                decoder_position_ids = decoder_position_ids.clone(memory_format=torch.contiguous_format)
-
-        if cache_position is None:
-            cache_position = torch.arange(
-                past_length, past_length + decoder_input_ids.shape[1], device=decoder_input_ids.device
-            )
-        elif use_cache:
-            cache_position = cache_position[-decoder_input_ids.shape[1] :]
-
-        # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
-        # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
-        decoder_input_ids = decoder_input_ids.contiguous()
-
-        if (
-            isinstance(past_key_values, EncoderDecoderCache)
-            and (
-                isinstance(past_key_values.self_attention_cache, StaticCache)
-                or isinstance(past_key_values.cross_attention_cache, StaticCache)
-            )
-            and decoder_attention_mask is not None
-            and decoder_attention_mask.ndim == 2
-        ):
-            batch_size, sequence_length = decoder_input_ids.shape
-
-            decoder_attention_mask = self.get_decoder()._prepare_4d_causal_attention_mask_with_cache_position(
-                decoder_attention_mask,
-                sequence_length=sequence_length,
-                target_length=past_key_values.self_attention_cache.get_max_cache_shape(),
-                dtype=self.proj_out.weight.dtype,
-                device=decoder_input_ids.device,
-                cache_position=cache_position,
-                batch_size=batch_size,
-            )
-
-        return {
-            "encoder_outputs": encoder_outputs,
-            "past_key_values": past_key_values,
-            "decoder_input_ids": decoder_input_ids,
-            "use_cache": use_cache,
-            "decoder_attention_mask": decoder_attention_mask,
-            "decoder_position_ids": decoder_position_ids,
-            "cache_position": cache_position,
-        }
-

 class WhisperDecoderWrapper(WhisperPreTrainedModel):
     """
5 changes: 2 additions & 3 deletions tests/models/whisper/test_modeling_whisper.py
@@ -3323,8 +3323,8 @@ def test_tiny_static_generation(self):
         input_features = input_features.to(torch_device)
         eager_generated_ids = model.generate(input_features, max_new_tokens=64)

+        # Using static cache compiles forward for each decoding step, so we don't have to manually compile
         model.generation_config.cache_implementation = "static"
-        model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

         # compile the forward pass and assert equivalence
         static_generated_ids = model.generate(input_features, max_new_tokens=64)
@@ -3379,9 +3379,8 @@ def test_tiny_static_generation_long_form(self):
         set_seed(42)
         eager_generated_ids = model.generate(**inputs, **gen_kwargs)

-        # compile the forward pass and assert equivalence
+        # Using static cache compiles forward for each decoding step, so we don't have to manually compile
         model.generation_config.cache_implementation = "static"
-        model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

         set_seed(42)
         static_generated_ids = model.generate(**inputs, **gen_kwargs)
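
Before this commit, the tests compiled the forward pass by hand on top of the static cache setting, as the removed lines show. A hedged sketch of that older pattern for comparison, reusing model and input_features from the first example:

# Manual compilation, the pattern the tests dropped; with cache_implementation="static",
# generate() now handles compilation of the decoding step on its own
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
static_generated_ids = model.generate(input_features, max_new_tokens=64)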
