Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix kernel cache miss and add RDNA configs #246

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from

Conversation

hyoon1
Copy link

@hyoon1 hyoon1 commented Oct 25, 2024

  • added Navi configurations (Related PR: add RDNA Config triton#640)
  • resolved cache miss issue during flash attention calls by fixing max_seqlen_q/k to 0

@@ -795,8 +880,8 @@ def forward(
HQ=nheads_q,
HK=nheads_k,
ACTUAL_BLOCK_DMODEL=head_size,
MAX_SEQLENS_Q=max_seqlens_q,
MAX_SEQLENS_K=max_seqlens_k,
MAX_SEQLENS_Q=0,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is the reason to zero seq lens?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Below attention fwd kernel is called when we run the model with vllm:

However, MAX_SEQLENS_Q/K differs every step, and it occurs different key value and compilation for the triton kernel each step, which leads to the performance degradation.
https://github.com/triton-lang/triton/blob/cf34004b8a67d290a962da166f5aa2fc66751326/python/triton/runtime/jit.py#L620
https://github.com/triton-lang/triton/blob/cf34004b8a67d290a962da166f5aa2fc66751326/python/triton/runtime/jit.py#L660

Currently, VARLEN is always set, and MAX_SEQLENS_Q/K are not used in this case when you look at the kernel in vllm.

Therefore, we just set MAX_SEQLENS_Q/K as a fixed value when we call the kernel for a workaround.

@@ -207,103 +209,186 @@ def _attn_fwd_inner(
return acc, l_i, m_i


@triton.autotune(
configs=[
def get_gfx_version():

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems like not used, right?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right. Removed it.

).arch in ('gfx940', 'gfx941', 'gfx942', 'gfx90a', 'gfx908')


def is_rdna():

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

probably worth to use:

def is_navi() -> bool:

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As per my knowledge AMD has two lines of HW for vllm: MI and Navi. So not navi should work better for future generations of MIs

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

@maleksan85
Copy link

As per @gshtras we need to merge into develop branch instead of main for now. Please correct.

return None


def is_hip():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All this functionality is implemented in a cross-architecture fashion in the platform/rocm.py and its superclasses

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

@hyoon1 hyoon1 changed the base branch from main to develop November 13, 2024 19:19
@hyoon1
Copy link
Author

hyoon1 commented Nov 13, 2024

@maleksan85 @gshtras
Thanks for the comments, I agree with your suggestions. However, I have a few concerns regarding this matter and I'm seeking advice on how to proceed. First, I believe that triton_flash_attention.py in vllm is essentially a copy of the file from ROCm/triton. The modifications in this pull request directly apply the changes from the ROCm/triton repository's pull request #640. While it might be fine to change the relevant functions only in vllm, there is a risk of misalignment later on.

Secondly, our team is using the v0.6.2+rocm release, and I understand that functions like is_navi() are not supported in that version. Implementing them would require significant modifications. Therefore, maintaining backward compatibility is also a concern.

Given these considerations, I would greatly appreciate your advice on how to proceed with the modifications.

@gshtras
Copy link
Collaborator

gshtras commented Nov 13, 2024

As for your last point, whatever changes will be made here will not have any effect on the previous tags, so v0.6.2+rocm will not get affected.
That's true, our kernel is a snapshot of the one you mentioned, taken in the beginning of 2024. Our attempts to catch up in the past resulted in performance regressions on various models in different configs, which was noted to the team, but I don't believe it was thoroughly investigated. We may take another round of this experiment, but regardless, I think utilizing an existing infrastructure APIs that vllm provides is better here, if nothing else, for uniformity and avoiding code duplication.

@hyoon1 hyoon1 force-pushed the fix_max_seq branch 3 times, most recently from 3f81ad2 to 4cc77c2 Compare November 19, 2024 07:05
@@ -207,103 +209,149 @@ def _attn_fwd_inner(
return acc, l_i, m_i


@triton.autotune(
configs=[
def get_cdna_autotune_configs():

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are you sure that those new commits will not decrease performance on MI. If so, what models did you tested?
cc @gshtras

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have tested on Navi31. I thought it was tested by triton team for other models because they modified configs for better performance. https://github.com/ROCm/triton/blob/db2ca015159c6592c30a6bfcd77b9cc540063a8e/python/perf-kernels/flash-attention.py#L334

Beside those configs for autoconfig, I believe fixing MAX_SEQLENS_Q/K to 0 will increase the performance for MI as well.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have tested chatglm2-6b, qwen-14b-chat, baichuan2-13b, llama-2-70b-chat, glm-4-9b-chat, qwen1.5-72b-chat-gptq, etc. on Navi31, w/o this change, triton-based FA2 has no positive perf lifting; while with this change, triton-based FA2 shows 2-5% gain. (and by debugging, it is confirmed that triton FA2 kernel cache is missed). We believe this should also provide positive impact on MI, especially during early triton kernel cache built-up period.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

discussed in the chat to separate things for MI from this PR.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

restored autotune configs for MI series

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hyoon1, could you please make this change only applicable to Navi? I will ask engineers in China to confirm the perf gain on Navi32 (although such cache misses issue has no dependencies on what GPU used). Thanks.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated. MI will use original configs for autotune.

- added Navi configurations (Related PR: ROCm/triton#640)
- resolved cache miss issue during flash attention calls by fixing max_seqlen_q/k to 0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants