[Qwen2_5_VL] Fix dtype mismatch in FA2 by removing forced float() cast in rotary embeddings #35966
What does this PR do?
This PR fixes a FlashAttention assertion error that occurs in the Qwen2_5_VL vision blocks when the query tensor is forcibly cast to `float32` but the `cos`/`sin` embedding tensors remain in lower precision (`bf16` or `fp16`). This mismatch triggers the error during training runs on mixed (text-only + text-image) batches.

Changes:
Removed the `.float()` cast in `apply_rotary_pos_emb_flashatt` so that the query tensor and the `cos`/`sin` tensors share the same dtype.

Why:
FlashAttention's `apply_rotary_emb` requires the input tensor and the `cos`/`sin` tensors to have identical dtypes (`bf16` or `fp16`).
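To illustrate the failure mode and the fix, here is a minimal CPU-only sketch in plain PyTorch. It does not call FlashAttention. `strict_rotary` is a hypothetical stand-in for a fused rotary kernel: it applies non-interleaved (GPT-NeoX-style) rotary embeddings and rejects mismatched dtypes. `rope_before` and `rope_after` are hypothetical helpers modeled on the description above, not the actual `transformers` source. The sketch assumes `freqs` (and therefore `cos`/`sin`) is already in `bf16`, matching the scenario described in this PR.

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Non-interleaved layout: element i is paired with element i + d/2.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def strict_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for a fused rotary kernel with a strict dtype check.

    x: (batch, seqlen, nheads, headdim); cos/sin: (seqlen, headdim // 2).
    """
    if not (x.dtype == cos.dtype == sin.dtype):
        raise AssertionError(f"x and cos/sin must share a dtype, got {x.dtype} and {cos.dtype}")
    # Expand cos/sin to the full head dim and broadcast over batch and heads.
    cos = torch.cat((cos, cos), dim=-1)[None, :, None, :]
    sin = torch.cat((sin, sin), dim=-1)[None, :, None, :]
    return x * cos + rotate_half(x) * sin


def rope_before(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # Old behaviour (as described): the query is forced to float32,
    # while cos/sin stay in the dtype of freqs.
    cos, sin = freqs.cos(), freqs.sin()
    return strict_rotary(tensor.float(), cos, sin).type_as(tensor)


def rope_after(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # New behaviour (as described): no forced cast, so the query and
    # cos/sin keep the same dtype.
    cos, sin = freqs.cos(), freqs.sin()
    return strict_rotary(tensor, cos, sin).type_as(tensor)


if __name__ == "__main__":
    torch.manual_seed(0)
    q = torch.randn(1, 4, 2, 8, dtype=torch.bfloat16)  # (batch, seqlen, heads, headdim)
    freqs = torch.randn(4, 4, dtype=torch.bfloat16)    # (seqlen, headdim // 2)

    try:
        rope_before(q, freqs)
    except AssertionError as err:
        print("before:", err)

    out = rope_after(q, freqs)
    print("after:", out.dtype, tuple(out.shape))
```

The strict check is what turns the mismatch into a hard error: plain PyTorch arithmetic on a `float32` query and `bf16` `cos`/`sin` would instead promote the result to `float32` without complaint.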