
fix(FA): QKV not being casted to target_dtype for FA with dpo lora #35834

Merged

merged 1 commit into huggingface:main on Jan 28, 2025

Conversation

@NanoCode012 (Contributor) commented on Jan 22, 2025

What does this PR do?

In DPO LoRA, the QKV states aren't cast to target_dtype, raising RuntimeError: FlashAttention only support fp16 and bf16 data type.

This issue was introduced by the refactor in #35342.

The prior code checked if query.dtype == torch.float32 and always cast it to fp16.
The current code checks if query.dtype == torch.float32, and then fa_peft_integration_check checks if value.dtype == torch.float32 before casting to target_dtype.

From my debugging,

    TARGET DTYPE torch.bfloat16
    QUERY DTYPE  torch.float32
    KEY DTYPE    torch.float32
    VALUE DTYPE  torch.bfloat16

Since value.dtype is bf16, the cast does not occur and the fp32 query/key states reach FlashAttention, triggering the error above. This PR fixes fa_peft_integration_check to use the same query.dtype condition.
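
For illustration, here is a minimal sketch of the casting behaviour this PR restores. The helper name below is hypothetical; the real fa_peft_integration_check in transformers has extra handling (e.g. deriving target_dtype), so this is only an outline of the dtype check, not the merged code:

    import torch

    def cast_states_for_flash_attention(query, key, value, target_dtype):
        # Key the decision on the *query* dtype rather than the value dtype.
        # With PEFT/LoRA the value states can already be bf16 while query/key
        # are still fp32 (see the debug output above), so checking value.dtype
        # skips the cast and FlashAttention rejects the fp32 inputs.
        if query.dtype == torch.float32:
            query = query.to(target_dtype)
            key = key.to(target_dtype)
            value = value.to(target_dtype)
        return query, key, value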

How has this been tested?

I ran the Axolotl CI, found the commit that caused the error, tracked down this issue, and verified that the fix resolves it.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@Cyrilvallez

Comment on lines +212 to 213:

    input_dtype = query.dtype
    if input_dtype == torch.float32:

Contributor commented:

Suggested change:

      input_dtype = query.dtype
    - if input_dtype == torch.float32:
    + if any([module.dtype == torch.float32 for module in [query, key, value]]):

might it be better to check all the modules? or is there a performance penalty in doing this?

Collaborator replied:
might not be a huge loss, but does not feel great
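
For reference, the broader check floated in this thread (casting whenever any of the three states is still fp32) could look roughly like the hypothetical sketch below; it is not what was merged:

    import torch

    def cast_if_any_fp32(query, key, value, target_dtype):
        # Hypothetical alternative from the review discussion: cast all three
        # states whenever any of them is still in float32.
        if any(t.dtype == torch.float32 for t in (query, key, value)):
            query, key, value = (t.to(target_dtype) for t in (query, key, value))
        return query, key, value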

@ArthurZucker (Collaborator) left a comment:
super sorry if we broke this, waiting for some feedback but sounds good


@Cyrilvallez (Member) left a comment:

Hey @NanoCode012! Super sorry for the delay! LGTM, thanks for fixing! Indeed it was not matching what we are doing in flash_attention.py 🤗 And until now we were only ever checking the query, so it should be fine as is; I also don't like checking them all that much! But happy to revisit if this happens to be an issue in the future!

@ArthurZucker ArthurZucker merged commit 478c4f2 into huggingface:main Jan 28, 2025
23 checks passed