
Fix bug in apply_rotary_pos_emb_flashatt in Qwen2-5-VL #36046

Closed

Conversation

DeepWaved
Contributor

What does this PR do?

This PR addresses an issue that arises when using Flash Attention 2 together with mixed-precision training frameworks such as DeepSpeed, which leads to an AssertionError:
"Input and cos/sin must have the same dtype, got torch.float32 and torch.float16."
The root cause is that DeepSpeed implicitly converts inv_freq (L:118) to float16, while the RoPE computation in the SDPA and default (eager) paths, specifically the apply_rotary_pos_emb_vision function, converts the tensors to float32. To prevent this mismatch, the same conversion has been added to apply_rotary_pos_emb_flashatt, ensuring consistent dtype handling and resolving the issue.
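
For illustration, here is a minimal sketch of the dtype alignment described above. The function signature is assumed from the flash-attn rotary path; the exact function in modeling_qwen2_5_vl.py may differ in detail.

import torch
from flash_attn.layers.rotary import apply_rotary_emb

def apply_rotary_pos_emb_flashatt(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # The query/key tensor is upcast to float32 before calling the flash-attn kernel.
    tensor_ = tensor.float()
    # Upcast cos/sin as well, mirroring apply_rotary_pos_emb_vision. Without the
    # .float() calls, DeepSpeed's implicit fp16 conversion of inv_freq leaves
    # cos/sin in float16 and the kernel raises the dtype AssertionError quoted above.
    cos = freqs.cos().float()
    sin = freqs.sin().float()
    output = apply_rotary_emb(tensor_, cos, sin).type_as(tensor)
    return output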

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Test

The following script reproduces the AssertionError; it uses the DeepSpeed config (ds.json) shown after it:

import deepspeed
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

def train():
    MODEL_PATH = "Qwen2-5-VL-3B-Instruct"
    
    # Load the model with Flash Attention 2 so the flash-attn rotary path is used.
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_PATH,
        torch_dtype="auto",
        device_map="cuda:0",
        attn_implementation="flash_attention_2",
    )
    processor = AutoProcessor.from_pretrained(MODEL_PATH)

    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
                    "max_pixels":128*128,
                },
                {"type": "text", "text": "Describe this image."},
            ],
        }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to("cuda")

    # Initializing DeepSpeed with the fp16 config below implicitly casts the
    # model's parameters and buffers (including inv_freq) to float16.
    model_engine, optimizer, _, _ = deepspeed.initialize(
        args=None,
        model=model,
        model_parameters=model.parameters(),
        config_params="ds.json",
    )
    model_engine.train()
    # The forward pass then fails with the dtype AssertionError described above.
    outputs = model_engine(**inputs)

if __name__ == "__main__":
    train()

ds.json:

{
    "fp16": {
        "enabled": "true"
    },
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 1e-5,
            "betas": [
                0.9,
                0.98
            ],
            "eps": 1e-8
        }
    },
    "zero_optimization": {
        "stage": 2
    },
    "train_batch_size": 1,
    "train_micro_batch_size_per_gpu": 1
}

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@zucchini-nlp

Member

Hey! Thanks for the PR. I believe it is a duplicate of #35966, but that one got stale, so let's merge this PR.

Thanks for the clear explanation, I agree with the fix. Can you run make fix-copies to make CI green?

@DeepWaved DeepWaved closed this Feb 5, 2025
@DeepWaved DeepWaved deleted the qwen2-5-vl-rope-dtype-mismatch branch February 5, 2025 12:52
@zucchini-nlp
Member

Hehe, I think the PR got closed because make fix-copies deleted the changes. For models like qwen2-5-vl, the changes have to be made in the modular_xxx.py files, and fix-copies will automatically apply them to the model files :)

Feel free to re-open again

@DeepWaved
Contributor Author

Hehe, I think the PR got closed because make fix-copies deleted the changes. For models like qwen2-5-vl, the changes have to be made in the modular_xxx.py files, and fix-copies will automatically apply them to the model files :)

Feel free to re-open again

Thank you for your guidance. I will try to recommit. I have a question: after I modify modular_qwen2_5_vl.py and run make fix-copies, modeling_qwen2_5_vl.py is successfully updated, but additional unintended changes are also introduced, for example in modeling_zamba2.py. How can I avoid this? Thanks a lot.

@zucchini-nlp
Member

zucchini-nlp commented Feb 6, 2025

@DeepWaved that's weird; with the latest main branch there should usually be no files reformatted. Can you try to rebase?

UPDATE: oh just found out there are a few files that will be reformatted even with main, so it's okay yes
