[Qwen2_5_VL] Fix dtype mismatch in FA2 by removing forced float() cast in rotary embeddings #35966

Open · wants to merge 1 commit into main
Conversation

sockeye44

What does this PR do?

This PR fixes a FlashAttention assertion error in the Qwen2_5_VL vision blocks that occurs when the query tensor is forcibly cast to float32 while the cos/sin embedding tensors remain in lower precision (bf16 or fp16). The mismatch triggers the following error during training runs on mixed (text-only + text-image) batches:

AssertionError: Input and cos/sin must have the same dtype, got torch.float32 and torch.bfloat16
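For illustration, here is a minimal sketch in plain PyTorch (not the actual FlashAttention kernel) of the condition that trips the assert: the query is force-cast to float32 while the cos/sin tables stay in bf16. Shapes are made up for the example; running it raises the same message.

import torch

# bf16 activations and bf16 rotary tables, as in a model loaded with
# torch_dtype=torch.bfloat16 (shapes are illustrative only)
q = torch.randn(256, 16, 80, dtype=torch.bfloat16)
cos = torch.randn(256, 40, dtype=torch.bfloat16)
sin = torch.randn(256, 40, dtype=torch.bfloat16)

q_ = q.float()  # the forced cast this PR removes

# FlashAttention's apply_rotary_emb performs an equivalent dtype check; this
# assert fails, reproducing the error text above
assert q_.dtype == cos.dtype == sin.dtype, (
    f"Input and cos/sin must have the same dtype, got {q_.dtype} and {cos.dtype}"
)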

Changes:

  • Removed the forced .float() cast in apply_rotary_pos_emb_flashatt so that the query tensor and the cos/sin tensors share the same dtype (a sketch of the surrounding helper follows the diff):
- tensor_ = tensor.float()
+ tensor_ = tensor
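For context, a sketch of the surrounding helper with the change applied, paraphrased from modeling_qwen2_5_vl.py (the exact code may differ slightly from the revision this PR is based on):

import torch
from flash_attn.layers.rotary import apply_rotary_emb

def apply_rotary_pos_emb_flashatt(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # Keep the input in its original dtype so it matches cos/sin, which are
    # derived from `freqs` and stay in the model's compute dtype (bf16/fp16)
    tensor_ = tensor  # previously: tensor.float()
    cos = freqs.cos()
    sin = freqs.sin()
    return apply_rotary_emb(tensor_, cos, sin).type_as(tensor)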

Why:

  • FlashAttention requires all inputs to apply_rotary_emb to have identical dtypes
  • Removing the forced cast ensures that the Qwen2_5_VL vision path works correctly in mixed or lower-precision settings, preventing upstream assertion failures and allowing memory-efficient training in bf16 or fp16
  • Verified the fix with seeded runs in mixed-precision (bf16 and fp16) settings, on both text-only and mixed text-image batches, to confirm correct behavior (a hypothetical verification sketch follows this list)
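Below is a hypothetical verification sketch along the lines described above (not the exact script used for this PR): a single bf16 forward/backward pass with flash_attention_2 on a text+image batch, which previously hit the assertion in the vision blocks. The checkpoint name, message contents, and dummy image are placeholders.

import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

model_id = "Qwen/Qwen2.5-VL-7B-Instruct"  # placeholder checkpoint
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="cuda",
)
processor = AutoProcessor.from_pretrained(model_id)

# Dummy text+image batch; a text-only batch can be built by dropping the image entry
image = Image.new("RGB", (224, 224), color="white")
messages = [
    {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Describe the image."}]}
]
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[text], images=[image], return_tensors="pt").to(model.device)

# Training-style forward/backward in bf16; before this fix the vision rotary
# path raised "Input and cos/sin must have the same dtype ..."
labels = inputs["input_ids"].clone()
out = model(**inputs, labels=labels)
out.loss.backward()
print("loss:", out.loss.item())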

@sockeye44 (Author)

@amyeroberts Hi, could you please review this PR?

@zucchini-nlp (Member) left a comment


Hey @sockeye44! Thanks for opening a PR. Can you share a small reproducer? I can't reproduce it with the following. I believe the cos/sin are usually in full precision when doing RoPE:

import torch
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2.5-VL-7B-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
# `self.messages` and `image` come from the reviewer's test setup
text = processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[text], images=[image], return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=30)
