fix(FA): QKV not being casted to target_dtype for FA with dpo lora #35834
Conversation
input_dtype = query.dtype
if input_dtype == torch.float32:
Suggested change:
  input_dtype = query.dtype
- if input_dtype == torch.float32:
+ if any([module.dtype == torch.float32 for module in [query, key, value]]):
might it be better to check all the modules? or is there a performance penalty in doing this?
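For illustration, here is a minimal, self-contained sketch of the "check all three states" variant suggested above (hypothetical helper name, not the actual transformers code):

```python
import torch

def needs_downcast(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> bool:
    # Reviewer's suggested check: trigger the downcast if *any* of the three
    # states was upcast to fp32 (e.g. by a PEFT/LoRA layer), not just the query.
    # The dtype comparisons are scalar attribute checks, so inspecting all
    # three tensors adds no meaningful overhead.
    return any(t.dtype == torch.float32 for t in (query, key, value))

q = torch.randn(1, 8, 16, 64, dtype=torch.float32)
k = torch.randn(1, 8, 16, 64, dtype=torch.bfloat16)
v = torch.randn(1, 8, 16, 64, dtype=torch.bfloat16)
print(needs_downcast(q, k, v))  # True: the query was upcast to fp32
```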
might not be a huge loss, but does not feel great
super sorry if we broke this, waiting for some feedback but sounds good
Hey @NanoCode012! Super sorry for the delay! LGTM, thanks for fixing! Indeed, it was not matching what we are doing in flash_attention.py 🤗. And until now we were always only checking the query, so it should be fine as is; I also don't like checking them all that much! But happy to revisit if this happens to be an issue in the future!
What does this PR do?
In DPO LoRA, the QKV states aren't being cast to `target_dtype`, raising `RuntimeError: FlashAttention only support fp16 and bf16 data type`. This issue was introduced by the refactor in #35342.
The prior code checked whether `query.dtype == torch.float32` and, if so, always cast to fp16. The current code checks `query.dtype == torch.float32`, and then `fa_peft_integration_check` checks `value.dtype == torch.float32` before casting to `target_dtype`. From my debugging, since `value.dtype` is bf16, the cast does not occur. This PR fixes the check to use the same `query.dtype` check.
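As an illustration of the fixed behavior, here is a minimal, self-contained sketch (hypothetical helper, not the exact `fa_peft_integration_check` implementation) that keys the cast on `query.dtype` rather than `value.dtype`:

```python
import torch

def cast_qkv_for_flash_attention(query, key, value, target_dtype=torch.float16):
    # Sketch of the fixed check: decide from query.dtype (as before the
    # refactor), so fp32 states produced by PEFT/LoRA upcasting are cast
    # back down before FlashAttention, which only supports fp16/bf16.
    if query.dtype == torch.float32:
        query = query.to(target_dtype)
        key = key.to(target_dtype)
        value = value.to(target_dtype)
    return query, key, value

# With an fp32 query but a bf16 value (the DPO LoRA case described above),
# the value-based check would skip the cast; the query-based check does not.
q = torch.randn(1, 8, 16, 64, dtype=torch.float32)
k = torch.randn(1, 8, 16, 64, dtype=torch.bfloat16)
v = torch.randn(1, 8, 16, 64, dtype=torch.bfloat16)
q, k, v = cast_qkv_for_flash_attention(q, k, v)
print(q.dtype, k.dtype, v.dtype)  # all torch.float16
```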
How has this been tested?
I ran Axolotl CI, found the related commit to cause the error, tracked down this issue, and verified the fix solves it.
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@Cyrilvallez