roll back fmha/common.py #5

Merged
merged 3 commits into from
Feb 29, 2024

Conversation

tenpercent
Collaborator

@tenpercent tenpercent commented Feb 28, 2024

Addressing this review comment:

The current interface of MHA is that H has to always match between qkv. If you want to do GQA - e.g. one kv-head for every n q-heads, you have to send 5D inputs. (Thus we're forcing the user to be very explicit.) Do we really want to relax that rule in this PR?

Also merging test_mqa_forward into test_mqa_decoding, as suggested in:

To follow up on @bottler's question about @rocm_only - it'd be better for fmha.ck.FwOp to be covered by the generic test_forward and test_mqa_decoding. Then we don't need a separate test function (and eventually won't need @rocm_only after all such cases are refactored)

Not sure if this should be blocking the merge or can be done as a follow-up

... so users are forced to provide rank-5 inputs for mqa/gqa
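
For readers unfamiliar with the rank-5 convention, here is a minimal NumPy sketch of how 4-D GQA tensors could be turned into explicit 5-D inputs. It assumes a `[B, M, G, H, K]` layout (G kv-head groups, H query heads per group) and that query heads sharing a kv-head are stored contiguously; neither detail is stated in this thread, and `to_rank5_gqa` is a hypothetical helper, not part of the library.

```python
import numpy as np


def to_rank5_gqa(q, k, v):
    """Hypothetical helper: reshape 4-D [B, M, H, K] tensors into rank-5 GQA inputs.

    q has Hq heads; k and v have Hkv heads, with Hq a multiple of Hkv.
    Assumed output layout: [B, M, G, H, K], G = Hkv groups, H = Hq // Hkv.
    """
    B, Mq, Hq, K = q.shape
    _, Mkv, Hkv, _ = k.shape
    if v.shape[:3] != k.shape[:3]:
        raise ValueError("k and v must agree on batch, length and head count")
    if Hq % Hkv != 0:
        raise ValueError("number of q-heads must be a multiple of kv-heads")
    n = Hq // Hkv  # q-heads served by each kv-head

    # Assumption: q-head h uses kv-head h // n (contiguous grouping).
    q5 = q.reshape(B, Mq, Hkv, n, K)
    # Broadcast each kv-head across its n q-heads; this is a read-only view,
    # so no data is copied.
    k5 = np.broadcast_to(k[:, :, :, None, :], (B, Mkv, Hkv, n, K))
    v5 = np.broadcast_to(v[:, :, :, None, :], (B, Mkv, Hkv, n, v.shape[-1]))
    return q5, k5, v5
```

With this layout, the head dimension H matches across q, k and v inside every group, so the "H must match" rule quoted above still holds.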
@qianfengz
Collaborator

qianfengz commented Feb 29, 2024

@tenpercent

  • The current ck.FwOp, together with its underlying ck-tiled implementation, is able to support mqa/gqa even though the input tensors it accepts are 4-D
  • However, ref_attention in test_mem_eff_attention.py cannot handle mqa/gqa with 4-D inputs, which is why ref_mqa_attention was added
  • test_mqa_forward was added to explicitly verify this functionality
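
To illustrate the point about a reference that accepts 4-D mqa/gqa inputs, here is a rough NumPy sketch. It is not the `ref_mqa_attention` from this PR (whose code is not shown here); `ref_gqa_attention_4d` is a hypothetical name, and the sketch assumes `[B, M, H, K]` inputs, no mask or bias, and that q-head `h` reads kv-head `h // n`.

```python
import numpy as np


def ref_gqa_attention_4d(q, k, v, scale=None):
    """Hypothetical reference attention for 4-D inputs with fewer kv-heads.

    q: [B, Mq, Hq, K]; k: [B, Mkv, Hkv, K]; v: [B, Mkv, Hkv, Kv].
    Returns an array of shape [B, Mq, Hq, Kv].
    """
    Hq, K = q.shape[2], q.shape[3]
    Hkv = k.shape[2]
    if Hq % Hkv != 0:
        raise ValueError("number of q-heads must be a multiple of kv-heads")
    n = Hq // Hkv

    # Materialise one kv-head per q-head (assumed mapping: h -> h // n).
    k = np.repeat(k, n, axis=2)
    v = np.repeat(v, n, axis=2)

    if scale is None:
        scale = K ** -0.5
    # scores[b, h, i, j] = <q[b, i, h], k[b, j, h]> * scale
    scores = np.einsum("bqhd,bkhd->bhqk", q, k) * scale
    # Numerically stable softmax over the key axis.
    scores = scores - scores.max(axis=-1, keepdims=True)
    probs = np.exp(scores)
    probs = probs / probs.sum(axis=-1, keepdims=True)
    return np.einsum("bhqk,bkhd->bqhd", probs, v)
```

A test could compare an operator's output on 4-D mqa/gqa inputs against a reference of this kind within a tolerance.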

@qianfengz
Collaborator

Since we added too many scripts for this functionality, we can just remove them; I will keep the scripts privately for testing/verification.

@qianfengz qianfengz merged commit 99947ff into dev_upstream Feb 29, 2024
3 checks passed
@tenpercent tenpercent deleted the roll-back-fmha-common branch March 28, 2024 18:45