diff --git a/examples/image-to-text/README.md b/examples/image-to-text/README.md
index 51f4a5dda2..254d5f6187 100644
--- a/examples/image-to-text/README.md
+++ b/examples/image-to-text/README.md
@@ -35,6 +35,8 @@ Models that have been validated:
 - [meta-llama/Llama-3.2-90B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-90B-Vision-Instruct)
 - [tiiuae/falcon-11B-vlm](https://huggingface.co/tiiuae/falcon-11B-vlm)
 - [google/paligemma-3b-mix-224](https://huggingface.co/google/paligemma-3b-mix-224)
+- [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct)
+
 
 ### Inference with BF16
 
@@ -122,6 +124,24 @@ python3 run_pipeline.py \
     --sdp_on_bf16
 ```
 
+To run Qwen/Qwen2-VL-2B-Instruct inference, use the following command:
+
+```bash
+python3 run_pipeline.py \
+    --model_name_or_path Qwen/Qwen2-VL-2B-Instruct \
+    --use_hpu_graphs \
+    --bf16
+```
+
+To run Qwen/Qwen2-VL-7B-Instruct inference, use the following command:
+
+```bash
+python3 run_pipeline.py \
+    --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
+    --use_hpu_graphs \
+    --bf16
+```
+
 ### Inference with FP8
 
 Inference for Llava-1.5-7b, Llava-1.5-13b, Llava-v1.6-mistral-7b and Llava-v1.6-vicuna-13b in FP8 precision is enabled using [Intel Neural Compressor (INC)](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html), which provides model measurement and quantization capabilities in PyTorch.
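For reference, the BF16 commands above reduce to roughly the following standalone Python. This is a minimal sketch of the pipeline path exercised by the `run_pipeline.py` changes below, assuming a Gaudi host with `optimum-habana` and `habana_frameworks` installed; the image URL and `max_new_tokens` value are only illustrative.

```python
import torch
from transformers import pipeline
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

# Patch transformers with the Gaudi-optimized model classes first
adapt_transformers_to_gaudi()

model_name = "Qwen/Qwen2-VL-2B-Instruct"
generator = pipeline(
    "image-to-text",
    model=model_name,
    config=model_name,
    tokenizer=model_name,
    image_processor=model_name,
    torch_dtype=torch.bfloat16,
    device="hpu",
)

# Mirror the Qwen2-VL branch added below: wrap only the language-model
# submodule in an HPU graph, leaving the vision tower unwrapped.
generator.model.model = wrap_in_hpu_graph(generator.model.model)

output = generator(
    "https://llava-vl.github.io/static/images/view.jpg",
    generate_kwargs={"max_new_tokens": 100},
)
print(output)
```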
diff --git a/examples/image-to-text/run_pipeline.py b/examples/image-to-text/run_pipeline.py
index 44eb8d575a..bcde0025d7 100644
--- a/examples/image-to-text/run_pipeline.py
+++ b/examples/image-to-text/run_pipeline.py
@@ -199,7 +199,7 @@ def main():
     config = AutoConfig.from_pretrained(args.model_name_or_path)
     model_type = config.model_type
 
-    if args.image_path is None and model_type in ["llava", "idefics2", "mllama"]:
+    if args.image_path is None and model_type in ["llava", "idefics2", "mllama", "qwen2_vl"]:
         args.image_path = ["https://llava-vl.github.io/static/images/view.jpg"]
     elif args.image_path is None and model_type == "paligemma":
         args.image_path = [
@@ -210,7 +210,7 @@ def main():
             "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
         ]
 
-    if model_type in ["llava", "idefics2", "llava_next", "mllama", "paligemma"]:
+    if model_type in ["llava", "idefics2", "llava_next", "mllama", "paligemma", "qwen2_vl"]:
         processor = AutoProcessor.from_pretrained(args.model_name_or_path)
         if args.prompt is None:
             if processor.chat_template is not None:
@@ -289,13 +289,19 @@ def main():
     generator = pipeline(
         "image-to-text",
         model=args.model_name_or_path,
+        config=args.model_name_or_path,
+        tokenizer=args.model_name_or_path,
+        image_processor=args.model_name_or_path,
         torch_dtype=model_dtype,
         device="hpu",
     )
     if args.use_hpu_graphs:
         from habana_frameworks.torch.hpu import wrap_in_hpu_graph
-
-        generator.model = wrap_in_hpu_graph(generator.model)
+        if "Qwen2-VL" in args.model_name_or_path:
+            # only wrap language model part
+            generator.model.model = wrap_in_hpu_graph(generator.model.model)
+        else:
+            generator.model = wrap_in_hpu_graph(generator.model)
 
     if "falcon-11B-vlm" in args.model_name_or_path:
         # WA falcon vlm issue that image_token_id == embed size.
@@ -310,7 +316,7 @@ def main():
         "limit_hpu_graphs": args.limit_hpu_graphs,
     }
 
-    if args.sdp_on_bf16:
+    if args.sdp_on_bf16 and "Llama-3.2-11B-Vision-Instruct" in args.model_name_or_path:
         torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)
 
     if args.use_kv_cache:
@@ -321,7 +327,7 @@ def main():
         htcore.hpu_initialize(generator.model)
 
     # delete once pipeline integrate AutoProcessor as preprocess engine
-    if model_type in ["idefics2", "mllama", "paligemma"]:
+    if model_type in ["idefics2", "mllama", "paligemma", "qwen2_vl"]:
         from transformers.image_utils import load_image
 
         def preprocess(self, image, prompt=None, timeout=None):
diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py
index 68b445c1b2..4a8f173804 100644
--- a/optimum/habana/transformers/generation/utils.py
+++ b/optimum/habana/transformers/generation/utils.py
@@ -109,6 +109,7 @@
     "qwen2_moe",
     "xglm",
     "whisper",
+    "qwen2_vl",
    "paligemma",
     "idefics2",
     "mllama",
@@ -860,7 +861,11 @@ def _prepare_cache_for_generation(
 
         # Use tuples by default (.i.e. legacy format).
         else:
-            return
+            model_kwargs[cache_name] = (
+                DynamicCache()
+                if not requires_cross_attention_cache
+                else EncoderDecoderCache(DynamicCache(), DynamicCache())
+            )
 
     @torch.no_grad()
     def generate(
diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py
index ee092ecff9..9ff165e565 100644
--- a/optimum/habana/transformers/modeling_utils.py
+++ b/optimum/habana/transformers/modeling_utils.py
@@ -28,6 +28,7 @@
     gaudi_StoppingCriteriaList_call,
 )
 from .models import (
+    GAUDI_QWEN2_VL_ATTENTION_CLASSES,
     GAUDI_WHISPER_ATTENTION_CLASSES,
     BaichuanConfig,
     BaichuanForCausalLM,
@@ -140,6 +141,10 @@
     GaudiQwen2MoeForCausalLM,
     GaudiQwen2MoeMLP,
     GaudiQwen2MoeModel,
+    GaudiQwen2VLDecoderLayer,
+    GaudiQwen2VLForConditionalGeneration,
+    GaudiQwen2VLModel,
+    GaudiQwen2VLSdpaAttention,
     GaudiStableLmAttention,
     GaudiStableLmDecoderLayer,
     GaudiStableLmForCausalLM,
@@ -684,6 +689,15 @@ def adapt_transformers_to_gaudi():
     transformers.models.whisper.modeling_whisper.WhisperForConditionalGeneration = GaudiWhisperForConditionalGeneration
     transformers.models.whisper.modeling_whisper.WHISPER_ATTENTION_CLASSES = GAUDI_WHISPER_ATTENTION_CLASSES
 
+    # Optimization for Qwen2-VL on Gaudi
+    transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_ATTENTION_CLASSES = GAUDI_QWEN2_VL_ATTENTION_CLASSES
+    transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLDecoderLayer = GaudiQwen2VLDecoderLayer
+    transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLSdpaAttention = GaudiQwen2VLSdpaAttention
+    transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLModel = GaudiQwen2VLModel
+    transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLForConditionalGeneration = (
+        GaudiQwen2VLForConditionalGeneration
+    )
+
     # Optimization for mllama on Gaudi
     transformers.models.mllama.modeling_mllama.MllamaSelfAttentionDecoderLayer = GaudiMllamaSelfAttentionDecoderLayer
     transformers.models.mllama.modeling_mllama.MllamaCrossAttentionDecoderLayer = GaudiMllamaCrossAttentionDecoderLayer
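A quick way to see the effect of the `adapt_transformers_to_gaudi()` registrations above (a sketch, assuming `optimum-habana` is installed alongside a Transformers release that ships `qwen2_vl`):

```python
import transformers
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

adapt_transformers_to_gaudi()

# The stock Qwen2-VL symbols now resolve to the Gaudi implementations, so any
# later `from_pretrained(...)` call picks them up transparently.
cls = transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLForConditionalGeneration
print(cls.__name__)  # GaudiQwen2VLForConditionalGeneration
```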
diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py
index 2a5e685942..74e9a0e75b 100644
--- a/optimum/habana/transformers/models/__init__.py
+++ b/optimum/habana/transformers/models/__init__.py
@@ -247,6 +247,13 @@
     gaudi_qwen2moe_block_sparse_moe_forward,
     gaudi_qwen2moe_rmsnorm_forward,
 )
+from .qwen2_vl import (
+    GAUDI_QWEN2_VL_ATTENTION_CLASSES,
+    GaudiQwen2VLDecoderLayer,
+    GaudiQwen2VLForConditionalGeneration,
+    GaudiQwen2VLModel,
+    GaudiQwen2VLSdpaAttention,
+)
 from .seamless_m4t import (
     gaudi_SeamlessM4TAttention_forward,
     gaudi_SeamlessM4TCodeHifiGan_get_output_hifigan_lengths,
diff --git a/optimum/habana/transformers/models/qwen2_vl/__init__.py b/optimum/habana/transformers/models/qwen2_vl/__init__.py
new file mode 100644
index 0000000000..ed4b3ed8a5
--- /dev/null
+++ b/optimum/habana/transformers/models/qwen2_vl/__init__.py
@@ -0,0 +1,7 @@
+from .modeling_qwen2_vl import (
+    GAUDI_QWEN2_VL_ATTENTION_CLASSES,
+    GaudiQwen2VLDecoderLayer,
+    GaudiQwen2VLForConditionalGeneration,
+    GaudiQwen2VLModel,
+    GaudiQwen2VLSdpaAttention,
+)
diff --git a/optimum/habana/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/optimum/habana/transformers/models/qwen2_vl/modeling_qwen2_vl.py
new file mode 100644
index 0000000000..ed3c9f6551
--- /dev/null
+++ b/optimum/habana/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -0,0 +1,544 @@
+# coding=utf-8
+# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch.nn import CrossEntropyLoss
+from transformers.cache_utils import Cache, DynamicCache, StaticCache
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.models.qwen2_vl.modeling_qwen2_vl import (
+    Qwen2VLAttention,
+    Qwen2VLCausalLMOutputWithPast,
+    Qwen2VLDecoderLayer,
+    Qwen2VLFlashAttention2,
+    Qwen2VLForConditionalGeneration,
+    Qwen2VLModel,
+    Qwen2VLSdpaAttention,
+    apply_multimodal_rotary_pos_emb,
+    repeat_kv,
+)
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class GaudiQwen2VLSdpaAttention(Qwen2VLSdpaAttention):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        token_idx: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """
+        The only differences are:
+        - add new arg token_idx
+        - optimize KV cache
+        """
+        if output_attentions:
+            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+            logger.warning_once(
+                "Qwen2VLModel is using Qwen2VLSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+            )
+
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+        if position_embeddings is None:
+            logger.warning_once(
+                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+                "removed and `position_embeddings` will be mandatory."
+            )
+            cos, sin = self.rotary_emb(value_states, position_ids)
+        else:
+            cos, sin = position_embeddings
+        query_states, key_states = apply_multimodal_rotary_pos_emb(
+            query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
+        )
+
+        if past_key_value is not None:
+            if token_idx is not None:
+                if 0 <= self.layer_idx < len(past_key_value.key_cache):
+                    past_key_value.key_cache[self.layer_idx].index_copy_(2, token_idx - 1, key_states)
+                    past_key_value.value_cache[self.layer_idx].index_copy_(2, token_idx - 1, value_states)
+                    key_states = past_key_value.key_cache[self.layer_idx]
+                    value_states = past_key_value.value_cache[self.layer_idx]
+                else:
+                    past_key_value.key_cache.append(key_states)
+                    past_key_value.value_cache.append(value_states)
+            else:
+                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
+                key_states, value_states = past_key_value.update(
+                    key_states, value_states, self.layer_idx, cache_kwargs
+                )
+
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        causal_mask = attention_mask
+        if attention_mask is not None:  # no matter the length, we just slice it
+            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+
+        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+        # Reference: https://github.com/pytorch/pytorch/issues/112577.
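+        # Note: this contiguity workaround is CUDA-specific; on HPU ("hpu" device type) the branch below is not taken.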
+        if query_states.device.type == "cuda" and attention_mask is not None:
+            query_states = query_states.contiguous()
+            key_states = key_states.contiguous()
+            value_states = value_states.contiguous()
+
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+        is_causal = True if causal_mask is None and q_len > 1 else False
+
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=causal_mask,
+            dropout_p=self.attention_dropout if self.training else 0.0,
+            is_causal=is_causal,
+        )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output, None, past_key_value
+
+
+# Currently, only the default GaudiQwen2VLSdpaAttention is supported
+GAUDI_QWEN2_VL_ATTENTION_CLASSES = {
+    "eager": Qwen2VLAttention,
+    "flash_attention_2": Qwen2VLFlashAttention2,
+    "sdpa": GaudiQwen2VLSdpaAttention,
+}
+
+
+class GaudiQwen2VLDecoderLayer(Qwen2VLDecoderLayer):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        token_idx: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+            token_idx=token_idx,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
+
+
+class GaudiQwen2VLModel(Qwen2VLModel):
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        token_idx: Optional[torch.Tensor] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        if use_cache and past_key_values is None:
+            past_key_values = DynamicCache()
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
+
+        # the hard coded `3` is for temporal, height and width.
+        if position_ids is None:
+            position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
+        elif position_ids.dim() == 2:
+            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
+
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+        )
+
+        hidden_states = inputs_embeds
+
+        # create position embeddings to be shared across the decoder layers
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        next_decoder_cache = None
+
+        for decoder_layer in self.layers:
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
+                    hidden_states,
+                    causal_mask,
+                    position_ids,
+                    past_key_values,
+                    output_attentions,
+                    use_cache,
+                    cache_position,
+                    position_embeddings,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=causal_mask,
+                    position_ids=position_ids,
+                    past_key_value=past_key_values,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                    cache_position=cache_position,
+                    position_embeddings=position_embeddings,
+                    token_idx=token_idx,
+                )
+
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+        hidden_states = self.norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )
+
+
+class GaudiQwen2VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        pixel_values: Optional[torch.Tensor] = None,
+        pixel_values_videos: Optional[torch.FloatTensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
+        rope_deltas: Optional[torch.LongTensor] = None,
+        token_idx: Optional[torch.Tensor] = None,
+    ):
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if inputs_embeds is None:
+            inputs_embeds = self.model.embed_tokens(input_ids)
+            if pixel_values is not None:
+                pixel_values = pixel_values.type(self.visual.get_dtype())
+                image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+                n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
+                n_image_features = image_embeds.shape[0]
+                if n_image_tokens != n_image_features:
+                    raise ValueError(
+                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                    )
+                image_mask = (
+                    (input_ids == self.config.image_token_id)
+                    .unsqueeze(-1)
+                    .expand_as(inputs_embeds)
+                    .to(inputs_embeds.device)
+                )
+                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
+            if pixel_values_videos is not None:
+                pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
+                video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
+                n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
+                n_video_features = video_embeds.shape[0]
+                if n_video_tokens != n_video_features:
+                    raise ValueError(
+                        f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
+                    )
+                video_mask = (
+                    (input_ids == self.config.video_token_id)
+                    .unsqueeze(-1)
+                    .expand_as(inputs_embeds)
+                    .to(inputs_embeds.device)
+                )
+                video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
+
+            if attention_mask is not None:
+                attention_mask = attention_mask.to(inputs_embeds.device)
+
+        if position_ids is None and input_ids is not None:
+            position_ids, _ = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask)
+
+        outputs = self.model(
+            input_ids=None,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            token_idx=token_idx,
+        )
+
+        hidden_states = outputs[0]
+        logits = self.lm_head(hidden_states)
+
+        loss = None
+        if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return Qwen2VLCausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            rope_deltas=rope_deltas,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        cache_position=None,
+        position_ids=None,
+        use_cache=True,
+        pixel_values=None,
+        pixel_values_videos=None,
+        image_grid_thw=None,
+        video_grid_thw=None,
+        **kwargs,
+    ):
+        token_idx = kwargs.get("token_idx", None)
+        if token_idx is None:
+            return super().prepare_inputs_for_generation(
+                input_ids=input_ids,
+                past_key_values=past_key_values,
+                attention_mask=attention_mask,
+                inputs_embeds=inputs_embeds,
+                cache_position=cache_position,
+                position_ids=position_ids,
+                use_cache=use_cache,
+                pixel_values=pixel_values,
+                pixel_values_videos=pixel_values_videos,
+                image_grid_thw=image_grid_thw,
+                video_grid_thw=video_grid_thw,
+                **kwargs,
+            )
+
+        if past_key_values:
+            input_ids = input_ids[:, token_idx - 1].unsqueeze(-1)
+
+        rope_deltas = kwargs.get("rope_deltas", None)
+        if attention_mask is not None and position_ids is None:
+            if cache_position is None or (cache_position is not None and cache_position[0] == 0):
+                position_ids, rope_deltas = self.get_rope_index(
+                    input_ids, image_grid_thw, video_grid_thw, attention_mask
+                )
+            else:
+                batch_size, seq_length = input_ids.shape
+                delta = (
+                    cache_position[0] + rope_deltas if cache_position is not None and rope_deltas is not None else 0
+                )
+                position_ids = torch.arange(seq_length, device=input_ids.device)
+                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
+                position_ids = position_ids.add(delta)
+                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
+
+        if cache_position[0] != 0:
+            pixel_values = None
+            pixel_values_videos = None
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and cache_position[0] == 0:
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
+        else:
+            model_inputs = {"input_ids": input_ids, "inputs_embeds": None}
+
+        if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = inputs_embeds.shape
+                device = inputs_embeds.device
+            else:
+                batch_size, sequence_length = input_ids.shape
+                device = input_ids.device
+
+            attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
+                attention_mask,
+                sequence_length=sequence_length,
+                target_length=past_key_values.get_max_cache_shape(),
+                dtype=self.lm_head.weight.dtype,
+                device=device,
+                cache_position=cache_position,
+                batch_size=batch_size,
+                config=self.config,
+                past_key_values=past_key_values,
+            )
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": use_cache,
+                "attention_mask": attention_mask,
+                "pixel_values": pixel_values,
+                "pixel_values_videos": pixel_values_videos,
+                "image_grid_thw": image_grid_thw,
+                "video_grid_thw": video_grid_thw,
+                "rope_deltas": rope_deltas,
+                "token_idx": token_idx,
+            }
+        )
+        return model_inputs
diff --git a/tests/test_image_to_text_example.py b/tests/test_image_to_text_example.py
index c73d4d0565..fe584d597e 100644
--- a/tests/test_image_to_text_example.py
+++ b/tests/test_image_to_text_example.py
@@ -23,6 +23,8 @@
         ("HuggingFaceM4/idefics2-8b", 1, 21.89944593215077),
         ("meta-llama/Llama-3.2-11B-Vision-Instruct", 1, 18.974541922240313),
         ("tiiuae/falcon-11B-vlm", 1, 23.69260849957278),
+        ("Qwen/Qwen2-VL-2B-Instruct", 1, 13.319195250861407),
+        ("Qwen/Qwen2-VL-7B-Instruct", 1, 7.491897460793705),
     ],
     "fp8": [
         ("llava-hf/llava-1.5-7b-hf", 1, 98.72578382705062),