Fix kernel cache miss and add RDNA configs #246

Open · wants to merge 1 commit into base: develop
197 changes: 139 additions & 58 deletions vllm/attention/ops/triton_flash_attention.py
@@ -24,6 +24,8 @@
import triton
import triton.language as tl

from vllm.utils import is_navi

torch_dtype: tl.constexpr = torch.float16


@@ -207,103 +209,178 @@ def _attn_fwd_inner(
return acc, l_i, m_i


def get_cdna_autotune_configs():

Are you sure that these new commits will not decrease performance on MI? If so, what models did you test?
cc @gshtras

Author:

I have tested on Navi31. I assumed the Triton team had tested other models, since they modified these configs for better performance: https://github.com/ROCm/triton/blob/db2ca015159c6592c30a6bfcd77b9cc540063a8e/python/perf-kernels/flash-attention.py#L334

Besides those autotune configs, I believe setting MAX_SEQLENS_Q/K to 0 will improve performance on MI as well.
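
For readers unfamiliar with the cache behaviour being referenced here, below is an illustrative sketch (not code from this PR, and the exact specialization rules depend on the Triton version) of how a varying integer kernel argument such as a per-batch max seqlen can compile several kernel variants, while a pinned value keeps reusing one cached binary:

```python
# Standalone illustration of Triton argument specialization, assuming current
# Triton behaviour (integers are specialized e.g. on == 1 and % 16 == 0).
import torch
import triton
import triton.language as tl


@triton.jit
def _touch_kernel(out_ptr, max_seqlen, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    # max_seqlen is only read, but its runtime value still feeds the
    # specialization that keys the compile cache.
    vals = tl.zeros((BLOCK, ), dtype=tl.int32) + max_seqlen
    tl.store(out_ptr + offs, vals)


out = torch.empty(16, dtype=torch.int32, device="cuda")
for seqlen in (1, 17, 32):
    # 1, "other", and "divisible by 16" fall into different specialization
    # buckets, so this loop can trigger up to three separate compilations.
    _touch_kernel[(1, )](out, seqlen, BLOCK=16)

# Pinning the argument to a constant (as this PR does with MAX_SEQLENS_Q/K on
# Navi) means every launch hits the same cached kernel.
_touch_kernel[(1, )](out, 0, BLOCK=16)
```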


We have tested chatglm2-6b, qwen-14b-chat, baichuan2-13b, llama-2-70b-chat, glm-4-9b-chat, qwen1.5-72b-chat-gptq, etc. on Navi31. Without this change, Triton-based FA2 shows no positive performance lift; with this change, Triton-based FA2 shows a 2-5% gain (and debugging confirmed that the Triton FA2 kernel cache was indeed being missed). We believe this should also have a positive impact on MI, especially during the early Triton kernel-cache build-up period.


As discussed in the chat, the MI-specific changes will be separated from this PR.

Author:

Restored the autotune configs for the MI series.


@hyoon1, could you please make this change applicable only to Navi? I will ask engineers in China to confirm the perf gain on Navi32 (although such a cache-miss issue has no dependency on which GPU is used). Thanks.

Author:

Updated. MI will use the original autotune configs.

Author (@hyoon1), Nov 27, 2024:

Additional chatglm3-6b throughput test results on Navi32 (16 GB), using the Triton backend with num-prompts 512 and max-model-len 512:
original:  input 1234.33 toks/s, output 921.11 toks/s, Throughput 5.31 requests/s (2544.00 tokens/s)
w/ update: input 1386.34 toks/s, output 1034.54 toks/s, Throughput 5.96 requests/s (2856.15 tokens/s)
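
A hypothetical reproduction sketch (the exact benchmark invocation is not given in the thread; the model path, prompts, and sampling settings below are assumptions), using vLLM's offline API with the Triton flash-attention backend forced on:

```python
# Assumed reproduction setup; the numbers above came from the author's own
# run, not from this script.
import os
import time

os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "1"  # force the Triton FA backend

from vllm import LLM, SamplingParams

llm = LLM(model="THUDM/chatglm3-6b", max_model_len=512, trust_remote_code=True)
prompts = ["Write a short story about a robot."] * 512
params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=128)

start = time.perf_counter()
outputs = llm.generate(prompts, params)
elapsed = time.perf_counter() - start

out_tokens = sum(len(o.outputs[0].token_ids) for o in outputs)
print(f"Throughput: {len(prompts) / elapsed:.2f} requests/s, "
      f"{out_tokens / elapsed:.2f} output tokens/s")
```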

    return [
        triton.Config(
            {
                'BLOCK_M': 256,
                'BLOCK_N': 64,
                'waves_per_eu': 2,
                'PRE_LOAD_V': False
            },
            num_stages=1,
            num_warps=8),
        triton.Config(
            {
                'BLOCK_M': 128,
                'BLOCK_N': 128,
                'waves_per_eu': 2,
                'PRE_LOAD_V': False
            },
            num_stages=1,
            num_warps=4),
        triton.Config(
            {
                'BLOCK_M': 256,
                'BLOCK_N': 128,
                'waves_per_eu': 2,
                'PRE_LOAD_V': False
            },
            num_stages=1,
            num_warps=8),
        triton.Config(
            {
                'BLOCK_M': 128,
                'BLOCK_N': 64,
                'waves_per_eu': 1,
                'PRE_LOAD_V': False
            },
            num_stages=1,
            num_warps=4),
        triton.Config(
            {
                'BLOCK_M': 128,
                'BLOCK_N': 64,
                'waves_per_eu': 3,
                'PRE_LOAD_V': True
            },
            num_stages=1,
            num_warps=4),
        triton.Config(
            {
                'BLOCK_M': 128,
                'BLOCK_N': 64,
                'waves_per_eu': 3,
                'PRE_LOAD_V': False
            },
            num_stages=1,
            num_warps=4),
        triton.Config(
            {
                'BLOCK_M': 64,
                'BLOCK_N': 64,
                'waves_per_eu': 4,
                'PRE_LOAD_V': False
            },
            num_stages=1,
            num_warps=8),
        triton.Config(
            {
                'BLOCK_M': 32,
                'BLOCK_N': 32,
                'waves_per_eu': 4,
                'PRE_LOAD_V': False
            },
            num_stages=1,
            num_warps=8),
        # TODO: This config fails with head_size not pow2 with data mismatches.
        # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1,
        #                'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
        triton.Config(
            {
                'BLOCK_M': 16,
                'BLOCK_N': 16,
                'waves_per_eu': 1,
                'PRE_LOAD_V': False
            },
            num_stages=1,
            num_warps=4),
    ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL']


def get_rdna_autotune_configs():
return [
triton.Config(
{
'BLOCK_M': 32,
'BLOCK_N': 32,
'waves_per_eu': 4,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=2),
triton.Config(
{
'BLOCK_M': 32,
'BLOCK_N': 32,
'waves_per_eu': 2,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=2),
triton.Config(
{
'BLOCK_M': 32,
'BLOCK_N': 16,
'waves_per_eu': 4,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=2),
triton.Config(
{
'BLOCK_M': 32,
'BLOCK_N': 16,
'waves_per_eu': 2,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=2),
triton.Config(
{
'BLOCK_M': 16,
'BLOCK_N': 16,
'waves_per_eu': 4,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=2),
triton.Config(
{
'BLOCK_M': 16,
'BLOCK_N': 16,
'waves_per_eu': 2,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=2),
# Fall-back config.
triton.Config(
{
'BLOCK_M': 16,
'BLOCK_N': 16,
'waves_per_eu': 1,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=2),
], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL']


def get_autotune_configs():
if is_navi():
return get_rdna_autotune_configs()
else:
return get_cdna_autotune_configs()


autotune_configs, autotune_keys = get_autotune_configs()
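
For context, a minimal sketch of what the `is_navi()` helper imported at the top of the file is assumed to look like (the actual `vllm.utils.is_navi` may differ in details): it reports whether the active ROCm device is an RDNA part, which is what drives the config selection above.

```python
# Assumed shape of the helper; RDNA (Navi) GPUs report gfx10xx/gfx11xx arch
# names under ROCm, while CDNA (MI) parts report gfx9xx names.
from functools import lru_cache

import torch


@lru_cache(maxsize=None)
def is_navi() -> bool:
    if torch.version.hip is None or not torch.cuda.is_available():
        return False
    arch = torch.cuda.get_device_properties(0).gcnArchName
    return "gfx1" in arch
```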


@triton.autotune(
configs=autotune_configs,
key=autotune_keys,
use_cuda_graph=True,
)
@triton.jit
def attn_fwd(
@@ -773,6 +850,10 @@ def forward(
else:
bias_strides = (0, 0, 0, 0)

# Per the PR discussion: pin the max-seqlen arguments on Navi so that
# per-batch variation in these integer values does not cause Triton
# kernel cache misses.
if is_navi():
    max_seqlens_q = 0
    max_seqlens_k = 0

attn_fwd[grid](
q,
k,