Skip to content

Commit

Permalink
feat: add support for flash_attn>=2.6.0 (#70)
Browse files Browse the repository at this point in the history
  • Loading branch information
Eigensystem authored Aug 26, 2024
1 parent 478c3a2 commit 4bfd585
Show file tree
Hide file tree
Showing 24 changed files with 121 additions and 7 deletions.
4 changes: 4 additions & 0 deletions benchmark/benchmark_longctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def benchmark(num_iter=100, forward_only=True, log=True, profile=False):
dropout_p=dropout_p,
causal=causal,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=False,
Expand All @@ -171,6 +172,7 @@ def benchmark(num_iter=100, forward_only=True, log=True, profile=False):
dropout_p=dropout_p,
causal=causal,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=False,
Expand All @@ -194,6 +196,7 @@ def benchmark(num_iter=100, forward_only=True, log=True, profile=False):
dropout_p=dropout_p,
causal=causal,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=False,
Expand All @@ -215,6 +218,7 @@ def benchmark(num_iter=100, forward_only=True, log=True, profile=False):
dropout_p=dropout_p,
causal=causal,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=False,
Expand Down
3 changes: 3 additions & 0 deletions benchmark/benchmark_longctx_qkvpacked.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def benchmark(num_iter=100, forward_only=True, log=True):
dropout_p=dropout_p,
causal=causal,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=False,
Expand All @@ -134,6 +135,7 @@ def benchmark(num_iter=100, forward_only=True, log=True):
dropout_p=dropout_p,
causal=causal,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=False,
Expand All @@ -147,6 +149,7 @@ def benchmark(num_iter=100, forward_only=True, log=True):
dropout_p=dropout_p,
causal=causal,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=False,
Expand Down
4 changes: 4 additions & 0 deletions benchmark/benchmark_qkvpacked_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def benchmark(f, num_iter=100, forward_only=True, log=True):
dropout_p=dropout_p,
causal=causal,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=False,
Expand All @@ -65,6 +66,7 @@ def benchmark(f, num_iter=100, forward_only=True, log=True):
dropout_p=dropout_p,
causal=causal,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=False,
Expand All @@ -82,6 +84,7 @@ def benchmark(f, num_iter=100, forward_only=True, log=True):
dropout_p=dropout_p,
causal=causal,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=False,
Expand All @@ -95,6 +98,7 @@ def benchmark(f, num_iter=100, forward_only=True, log=True):
dropout_p=dropout_p,
causal=causal,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=False,
Expand Down
4 changes: 4 additions & 0 deletions benchmark/benchmark_ring_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def benchmark(f, num_iter=100, forward_only=True, log=True):
dropout_p=dropout_p,
causal=causal,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=False,
Expand All @@ -90,6 +91,7 @@ def benchmark(f, num_iter=100, forward_only=True, log=True):
dropout_p=dropout_p,
causal=causal,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=False,
Expand All @@ -109,6 +111,7 @@ def benchmark(f, num_iter=100, forward_only=True, log=True):
dropout_p=dropout_p,
causal=causal,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=False,
Expand All @@ -126,6 +129,7 @@ def benchmark(f, num_iter=100, forward_only=True, log=True):
dropout_p=dropout_p,
causal=causal,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=False,
Expand Down
2 changes: 2 additions & 0 deletions benchmark/benchmark_varlen_qkvpacked_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def benchmark(f, num_iter=100, forward_only=True, log=True):
dropout_p=dropout_p,
causal=causal,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=False,
Expand All @@ -78,6 +79,7 @@ def benchmark(f, num_iter=100, forward_only=True, log=True):
dropout_p=dropout_p,
causal=causal,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=False,
Expand Down
Empty file modified scripts/run_gqa.sh
100644 → 100755
Empty file.
Empty file modified scripts/run_qkvpack_compare.sh
100644 → 100755
Empty file.
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

setup(
name="yunchang",
version="0.2",
version="0.3",
author="Jiarui Fang, Zilin Zhu, Yang Yu",
url="https://github.com/feifeibear/long-context-attention",
packages=find_packages(exclude=['test', 'benchmark']),
install_requires=[
'flash-attn',
'flash-attn>=2.6.0',
],
)
2 changes: 2 additions & 0 deletions test/test_hybrid_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def log(msg, a, rank0_only=False):
dropout_p=dropout_p,
causal=causal,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=True,
Expand Down Expand Up @@ -137,6 +138,7 @@ def log(msg, a, rank0_only=False):
dropout_p=dropout_p,
causal=causal,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=True,
Expand Down
2 changes: 2 additions & 0 deletions test/test_hybrid_qkvpacked_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def test(ring_impl_type="basic"):
dropout_p=dropout_p,
causal=causal,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=True,
Expand All @@ -124,6 +125,7 @@ def test(ring_impl_type="basic"):
dropout_p=dropout_p,
causal=causal,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=True,
Expand Down
2 changes: 2 additions & 0 deletions test/test_ring_flash_attn_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def log(msg, a, rank0_only=False):
dropout_p=dropout_p,
causal=causal,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=True,
Expand All @@ -88,6 +89,7 @@ def log(msg, a, rank0_only=False):
dropout_p=dropout_p,
causal=causal,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=True,
Expand Down
2 changes: 2 additions & 0 deletions test/test_ring_flash_attn_varlen_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def extract_lse(lse, cu_seqlens):
dropout_p=dropout_p,
causal=causal,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=True,
Expand All @@ -114,6 +115,7 @@ def extract_lse(lse, cu_seqlens):
dropout_p=dropout_p,
causal=causal,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=True,
Expand Down
2 changes: 2 additions & 0 deletions test/test_stripe_flash_attn_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def extract_local(value, rank, world_size, dim=1):
dropout_p=dropout_p,
causal=causal,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=True,
Expand All @@ -93,6 +94,7 @@ def extract_local(value, rank, world_size, dim=1):
dropout_p=dropout_p,
causal=causal,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=True,
Expand Down
2 changes: 2 additions & 0 deletions test/test_ulysses_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def log(msg, a, rank0_only=False):
dropout_p=dropout_p,
causal=causal,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=True,
Expand All @@ -119,6 +120,7 @@ def log(msg, a, rank0_only=False):
dropout_p=dropout_p,
causal=causal,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=True,
Expand Down
2 changes: 2 additions & 0 deletions test/test_zigzag_ring_flash_attn_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def extract_local(value, rank, world_size, dim=1):
dropout_p=dropout_p,
causal=causal,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=True,
Expand All @@ -93,6 +94,7 @@ def extract_local(value, rank, world_size, dim=1):
dropout_p=dropout_p,
causal=causal,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=True,
Expand Down
2 changes: 2 additions & 0 deletions test/test_zigzag_ring_flash_attn_varlen_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def extract_lse(lse, cu_seqlens):
dropout_p=dropout_p,
causal=causal,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=True,
Expand All @@ -119,6 +120,7 @@ def extract_lse(lse, cu_seqlens):
dropout_p=dropout_p,
causal=causal,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=True,
Expand Down
2 changes: 2 additions & 0 deletions yunchang/hybrid/async_attn_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def forward(
softmax_scale=None,
causal=False,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
Expand Down Expand Up @@ -148,6 +149,7 @@ def forward(
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=return_attn_probs,
Expand Down
5 changes: 5 additions & 0 deletions yunchang/hybrid/attn_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def forward(
softmax_scale=None,
causal=False,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
Expand Down Expand Up @@ -84,6 +85,7 @@ def forward(
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=return_attn_probs,
Expand All @@ -108,6 +110,7 @@ def forward(
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=return_attn_probs,
Expand Down Expand Up @@ -166,6 +169,7 @@ def forward(
softmax_scale=None,
causal=False,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
Expand Down Expand Up @@ -198,6 +202,7 @@ def forward(
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=return_attn_probs,
Expand Down
Loading

0 comments on commit 4bfd585

Please sign in to comment.