From 4bfd585da93ee70cec779d3ba5464d3b4318f96b Mon Sep 17 00:00:00 2001 From: Jinzhe Pan <48981407+Eigensystem@users.noreply.github.com> Date: Mon, 26 Aug 2024 22:10:35 +0800 Subject: [PATCH] feat: add support for flash_attn>=2.6.0 (#70) --- benchmark/benchmark_longctx.py | 4 ++++ benchmark/benchmark_longctx_qkvpacked.py | 3 +++ benchmark/benchmark_qkvpacked_func.py | 4 ++++ benchmark/benchmark_ring_func.py | 4 ++++ benchmark/benchmark_varlen_qkvpacked_func.py | 2 ++ scripts/run_gqa.sh | 0 scripts/run_qkvpack_compare.sh | 0 setup.py | 4 ++-- test/test_hybrid_attn.py | 2 ++ test/test_hybrid_qkvpacked_attn.py | 2 ++ test/test_ring_flash_attn_func.py | 2 ++ test/test_ring_flash_attn_varlen_func.py | 2 ++ test/test_stripe_flash_attn_func.py | 2 ++ test/test_ulysses_attn.py | 2 ++ test/test_zigzag_ring_flash_attn_func.py | 2 ++ .../test_zigzag_ring_flash_attn_varlen_func.py | 2 ++ yunchang/hybrid/async_attn_layer.py | 2 ++ yunchang/hybrid/attn_layer.py | 5 +++++ yunchang/ring/ring_flash_attn.py | 16 +++++++++++++++- yunchang/ring/ring_flash_attn_varlen.py | 16 +++++++++++++++- yunchang/ring/stripe_flash_attn.py | 18 +++++++++++++++++- yunchang/ring/zigzag_ring_flash_attn.py | 16 +++++++++++++++- yunchang/ring/zigzag_ring_flash_attn_varlen.py | 16 +++++++++++++++- yunchang/ulysses/attn_layer.py | 2 ++ 24 files changed, 121 insertions(+), 7 deletions(-) mode change 100644 => 100755 scripts/run_gqa.sh mode change 100644 => 100755 scripts/run_qkvpack_compare.sh diff --git a/benchmark/benchmark_longctx.py b/benchmark/benchmark_longctx.py index f6378b6..ef8c075 100644 --- a/benchmark/benchmark_longctx.py +++ b/benchmark/benchmark_longctx.py @@ -157,6 +157,7 @@ def benchmark(num_iter=100, forward_only=True, log=True, profile=False): dropout_p=dropout_p, causal=causal, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=deterministic, return_attn_probs=False, @@ -171,6 +172,7 @@ def benchmark(num_iter=100, forward_only=True, log=True, profile=False): dropout_p=dropout_p, causal=causal, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=deterministic, return_attn_probs=False, @@ -194,6 +196,7 @@ def benchmark(num_iter=100, forward_only=True, log=True, profile=False): dropout_p=dropout_p, causal=causal, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=deterministic, return_attn_probs=False, @@ -215,6 +218,7 @@ def benchmark(num_iter=100, forward_only=True, log=True, profile=False): dropout_p=dropout_p, causal=causal, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=deterministic, return_attn_probs=False, diff --git a/benchmark/benchmark_longctx_qkvpacked.py b/benchmark/benchmark_longctx_qkvpacked.py index 01d65ec..6abb978 100644 --- a/benchmark/benchmark_longctx_qkvpacked.py +++ b/benchmark/benchmark_longctx_qkvpacked.py @@ -117,6 +117,7 @@ def benchmark(num_iter=100, forward_only=True, log=True): dropout_p=dropout_p, causal=causal, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=deterministic, return_attn_probs=False, @@ -134,6 +135,7 @@ def benchmark(num_iter=100, forward_only=True, log=True): dropout_p=dropout_p, causal=causal, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=deterministic, return_attn_probs=False, @@ -147,6 +149,7 @@ def benchmark(num_iter=100, forward_only=True, log=True): dropout_p=dropout_p, causal=causal, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=deterministic, return_attn_probs=False, diff --git a/benchmark/benchmark_qkvpacked_func.py b/benchmark/benchmark_qkvpacked_func.py index 429a1c5..719d4f0 100644 --- a/benchmark/benchmark_qkvpacked_func.py +++ b/benchmark/benchmark_qkvpacked_func.py @@ -56,6 +56,7 @@ def benchmark(f, num_iter=100, forward_only=True, log=True): dropout_p=dropout_p, causal=causal, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=deterministic, return_attn_probs=False, @@ -65,6 +66,7 @@ def benchmark(f, num_iter=100, forward_only=True, log=True): dropout_p=dropout_p, causal=causal, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=deterministic, return_attn_probs=False, @@ -82,6 +84,7 @@ def benchmark(f, num_iter=100, forward_only=True, log=True): dropout_p=dropout_p, causal=causal, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=deterministic, return_attn_probs=False, @@ -95,6 +98,7 @@ def benchmark(f, num_iter=100, forward_only=True, log=True): dropout_p=dropout_p, causal=causal, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=deterministic, return_attn_probs=False, diff --git a/benchmark/benchmark_ring_func.py b/benchmark/benchmark_ring_func.py index cb0f88f..853d7d2 100644 --- a/benchmark/benchmark_ring_func.py +++ b/benchmark/benchmark_ring_func.py @@ -79,6 +79,7 @@ def benchmark(f, num_iter=100, forward_only=True, log=True): dropout_p=dropout_p, causal=causal, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=deterministic, return_attn_probs=False, @@ -90,6 +91,7 @@ def benchmark(f, num_iter=100, forward_only=True, log=True): dropout_p=dropout_p, causal=causal, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=deterministic, return_attn_probs=False, @@ -109,6 +111,7 @@ def benchmark(f, num_iter=100, forward_only=True, log=True): dropout_p=dropout_p, causal=causal, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=deterministic, return_attn_probs=False, @@ -126,6 +129,7 @@ def benchmark(f, num_iter=100, forward_only=True, log=True): dropout_p=dropout_p, causal=causal, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=deterministic, return_attn_probs=False, diff --git a/benchmark/benchmark_varlen_qkvpacked_func.py b/benchmark/benchmark_varlen_qkvpacked_func.py index 534b794..1f46421 100644 --- a/benchmark/benchmark_varlen_qkvpacked_func.py +++ b/benchmark/benchmark_varlen_qkvpacked_func.py @@ -64,6 +64,7 @@ def benchmark(f, num_iter=100, forward_only=True, log=True): dropout_p=dropout_p, causal=causal, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=deterministic, return_attn_probs=False, @@ -78,6 +79,7 @@ def benchmark(f, num_iter=100, forward_only=True, log=True): dropout_p=dropout_p, causal=causal, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=deterministic, return_attn_probs=False, diff --git a/scripts/run_gqa.sh b/scripts/run_gqa.sh old mode 100644 new mode 100755 diff --git a/scripts/run_qkvpack_compare.sh b/scripts/run_qkvpack_compare.sh old mode 100644 new mode 100755 diff --git a/setup.py b/setup.py index 7d0faf0..e8f3c4f 100644 --- a/setup.py +++ b/setup.py @@ -2,11 +2,11 @@ setup( name="yunchang", - version="0.2", + version="0.3", author="Jiarui Fang, Zilin Zhu, Yang Yu", url="https://github.com/feifeibear/long-context-attention", packages=find_packages(exclude=['test', 'benchmark']), install_requires=[ - 'flash-attn', + 'flash-attn>=2.6.0', ], ) diff --git a/test/test_hybrid_attn.py b/test/test_hybrid_attn.py index 2aafc7b..af9f71c 100644 --- a/test/test_hybrid_attn.py +++ b/test/test_hybrid_attn.py @@ -110,6 +110,7 @@ def log(msg, a, rank0_only=False): dropout_p=dropout_p, causal=causal, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=deterministic, return_attn_probs=True, @@ -137,6 +138,7 @@ def log(msg, a, rank0_only=False): dropout_p=dropout_p, causal=causal, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=deterministic, return_attn_probs=True, diff --git a/test/test_hybrid_qkvpacked_attn.py b/test/test_hybrid_qkvpacked_attn.py index 07c5e37..8eae3b1 100644 --- a/test/test_hybrid_qkvpacked_attn.py +++ b/test/test_hybrid_qkvpacked_attn.py @@ -111,6 +111,7 @@ def test(ring_impl_type="basic"): dropout_p=dropout_p, causal=causal, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=deterministic, return_attn_probs=True, @@ -124,6 +125,7 @@ def test(ring_impl_type="basic"): dropout_p=dropout_p, causal=causal, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=deterministic, return_attn_probs=True, diff --git a/test/test_ring_flash_attn_func.py b/test/test_ring_flash_attn_func.py index 3990470..8424f09 100644 --- a/test/test_ring_flash_attn_func.py +++ b/test/test_ring_flash_attn_func.py @@ -73,6 +73,7 @@ def log(msg, a, rank0_only=False): dropout_p=dropout_p, causal=causal, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=deterministic, return_attn_probs=True, @@ -88,6 +89,7 @@ def log(msg, a, rank0_only=False): dropout_p=dropout_p, causal=causal, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=deterministic, return_attn_probs=True, diff --git a/test/test_ring_flash_attn_varlen_func.py b/test/test_ring_flash_attn_varlen_func.py index 1c79543..46ac5ee 100644 --- a/test/test_ring_flash_attn_varlen_func.py +++ b/test/test_ring_flash_attn_varlen_func.py @@ -99,6 +99,7 @@ def extract_lse(lse, cu_seqlens): dropout_p=dropout_p, causal=causal, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=deterministic, return_attn_probs=True, @@ -114,6 +115,7 @@ def extract_lse(lse, cu_seqlens): dropout_p=dropout_p, causal=causal, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=deterministic, return_attn_probs=True, diff --git a/test/test_stripe_flash_attn_func.py b/test/test_stripe_flash_attn_func.py index 82421eb..f42a5e9 100644 --- a/test/test_stripe_flash_attn_func.py +++ b/test/test_stripe_flash_attn_func.py @@ -80,6 +80,7 @@ def extract_local(value, rank, world_size, dim=1): dropout_p=dropout_p, causal=causal, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=deterministic, return_attn_probs=True, @@ -93,6 +94,7 @@ def extract_local(value, rank, world_size, dim=1): dropout_p=dropout_p, causal=causal, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=deterministic, return_attn_probs=True, diff --git a/test/test_ulysses_attn.py b/test/test_ulysses_attn.py index c81ad3d..b3481e7 100644 --- a/test/test_ulysses_attn.py +++ b/test/test_ulysses_attn.py @@ -93,6 +93,7 @@ def log(msg, a, rank0_only=False): dropout_p=dropout_p, causal=causal, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=deterministic, return_attn_probs=True, @@ -119,6 +120,7 @@ def log(msg, a, rank0_only=False): dropout_p=dropout_p, causal=causal, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=deterministic, return_attn_probs=True, diff --git a/test/test_zigzag_ring_flash_attn_func.py b/test/test_zigzag_ring_flash_attn_func.py index 66a1bb5..26b863c 100644 --- a/test/test_zigzag_ring_flash_attn_func.py +++ b/test/test_zigzag_ring_flash_attn_func.py @@ -80,6 +80,7 @@ def extract_local(value, rank, world_size, dim=1): dropout_p=dropout_p, causal=causal, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=deterministic, return_attn_probs=True, @@ -93,6 +94,7 @@ def extract_local(value, rank, world_size, dim=1): dropout_p=dropout_p, causal=causal, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=deterministic, return_attn_probs=True, diff --git a/test/test_zigzag_ring_flash_attn_varlen_func.py b/test/test_zigzag_ring_flash_attn_varlen_func.py index c329f40..9ab0426 100644 --- a/test/test_zigzag_ring_flash_attn_varlen_func.py +++ b/test/test_zigzag_ring_flash_attn_varlen_func.py @@ -104,6 +104,7 @@ def extract_lse(lse, cu_seqlens): dropout_p=dropout_p, causal=causal, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=deterministic, return_attn_probs=True, @@ -119,6 +120,7 @@ def extract_lse(lse, cu_seqlens): dropout_p=dropout_p, causal=causal, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=deterministic, return_attn_probs=True, diff --git a/yunchang/hybrid/async_attn_layer.py b/yunchang/hybrid/async_attn_layer.py index 29e9a29..9150be4 100644 --- a/yunchang/hybrid/async_attn_layer.py +++ b/yunchang/hybrid/async_attn_layer.py @@ -50,6 +50,7 @@ def forward( softmax_scale=None, causal=False, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=False, return_attn_probs=False, @@ -148,6 +149,7 @@ def forward( softmax_scale=softmax_scale, causal=causal, window_size=window_size, + softcap=softcap, alibi_slopes=alibi_slopes, deterministic=deterministic, return_attn_probs=return_attn_probs, diff --git a/yunchang/hybrid/attn_layer.py b/yunchang/hybrid/attn_layer.py index 833b776..c81feef 100644 --- a/yunchang/hybrid/attn_layer.py +++ b/yunchang/hybrid/attn_layer.py @@ -49,6 +49,7 @@ def forward( softmax_scale=None, causal=False, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=False, return_attn_probs=False, @@ -84,6 +85,7 @@ def forward( softmax_scale=softmax_scale, causal=causal, window_size=window_size, + softcap=softcap, alibi_slopes=alibi_slopes, deterministic=deterministic, return_attn_probs=return_attn_probs, @@ -108,6 +110,7 @@ def forward( softmax_scale=softmax_scale, causal=causal, window_size=window_size, + softcap=softcap, alibi_slopes=alibi_slopes, deterministic=deterministic, return_attn_probs=return_attn_probs, @@ -166,6 +169,7 @@ def forward( softmax_scale=None, causal=False, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=False, return_attn_probs=False, @@ -198,6 +202,7 @@ def forward( softmax_scale=softmax_scale, causal=causal, window_size=window_size, + softcap=softcap, alibi_slopes=alibi_slopes, deterministic=deterministic, return_attn_probs=return_attn_probs, diff --git a/yunchang/ring/ring_flash_attn.py b/yunchang/ring/ring_flash_attn.py index d531540..ce24188 100644 --- a/yunchang/ring/ring_flash_attn.py +++ b/yunchang/ring/ring_flash_attn.py @@ -13,6 +13,7 @@ def ring_flash_attn_forward( dropout_p=0, causal=True, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=False, ): @@ -38,6 +39,7 @@ def ring_flash_attn_forward( softmax_scale, causal=causal and step == 0, window_size=window_size, + softcap=softcap, alibi_slopes=alibi_slopes, return_softmax=True and dropout_p > 0, ) @@ -65,6 +67,7 @@ def ring_flash_attn_backward( dropout_p=0, causal=True, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=False, ): @@ -101,6 +104,7 @@ def ring_flash_attn_backward( softmax_scale, bwd_causal, window_size, + softcap, alibi_slopes, deterministic, rng_state=None, @@ -145,6 +149,7 @@ def forward( softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, return_softmax, @@ -165,6 +170,7 @@ def forward( dropout_p=dropout_p, causal=causal, window_size=window_size, + softcap=softcap, alibi_slopes=alibi_slopes, deterministic=False, ) @@ -174,6 +180,7 @@ def forward( ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size + ctx.softcap = softcap ctx.alibi_slopes = alibi_slopes ctx.deterministic = deterministic ctx.group = group @@ -194,10 +201,11 @@ def backward(ctx, dout, *args): dropout_p=ctx.dropout_p, causal=ctx.causal, window_size=ctx.window_size, + softcap=ctx.softcap, alibi_slopes=ctx.alibi_slopes, deterministic=ctx.deterministic, ) - return dq, dk, dv, None, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None, None, None, None def ring_flash_attn_qkvpacked_func( @@ -206,6 +214,7 @@ def ring_flash_attn_qkvpacked_func( softmax_scale=None, causal=False, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=False, return_attn_probs=False, @@ -219,6 +228,7 @@ def ring_flash_attn_qkvpacked_func( softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, return_attn_probs, @@ -233,6 +243,7 @@ def ring_flash_attn_kvpacked_func( softmax_scale=None, causal=False, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=False, return_attn_probs=False, @@ -246,6 +257,7 @@ def ring_flash_attn_kvpacked_func( softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, return_attn_probs, @@ -261,6 +273,7 @@ def ring_flash_attn_func( softmax_scale=None, causal=False, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=False, return_attn_probs=False, @@ -274,6 +287,7 @@ def ring_flash_attn_func( softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, return_attn_probs, diff --git a/yunchang/ring/ring_flash_attn_varlen.py b/yunchang/ring/ring_flash_attn_varlen.py index 01c7364..d63343d 100644 --- a/yunchang/ring/ring_flash_attn_varlen.py +++ b/yunchang/ring/ring_flash_attn_varlen.py @@ -32,6 +32,7 @@ def ring_flash_attn_varlen_forward( dropout_p=0, causal=True, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=False, ): @@ -59,6 +60,7 @@ def ring_flash_attn_varlen_forward( softmax_scale, causal=causal and step == 0, window_size=window_size, + softcap=softcap, alibi_slopes=alibi_slopes, return_softmax=True and dropout_p > 0, ) @@ -92,6 +94,7 @@ def ring_flash_attn_varlen_backward( dropout_p=0, causal=True, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=False, ): @@ -131,6 +134,7 @@ def ring_flash_attn_varlen_backward( softmax_scale, bwd_causal, window_size, + softcap, alibi_slopes, deterministic, rng_state=None, @@ -177,6 +181,7 @@ def forward( softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, return_softmax, @@ -199,6 +204,7 @@ def forward( dropout_p=dropout_p, causal=causal, window_size=window_size, + softcap=softcap, alibi_slopes=alibi_slopes, deterministic=False, ) @@ -209,6 +215,7 @@ def forward( ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size + ctx.softcap = softcap ctx.alibi_slopes = alibi_slopes ctx.deterministic = deterministic ctx.group = group @@ -231,10 +238,11 @@ def backward(ctx, dout, *args): dropout_p=ctx.dropout_p, causal=ctx.causal, window_size=ctx.window_size, + softcap=ctx.softcap, alibi_slopes=ctx.alibi_slopes, deterministic=ctx.deterministic, ) - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None def ring_flash_attn_varlen_qkvpacked_func( @@ -245,6 +253,7 @@ def ring_flash_attn_varlen_qkvpacked_func( softmax_scale=None, causal=False, window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, alibi_slopes=None, deterministic=False, return_attn_probs=False, @@ -260,6 +269,7 @@ def ring_flash_attn_varlen_qkvpacked_func( softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, return_attn_probs, @@ -276,6 +286,7 @@ def ring_flash_attn_varlen_kvpacked_func( softmax_scale=None, causal=False, window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, alibi_slopes=None, deterministic=False, return_attn_probs=False, @@ -291,6 +302,7 @@ def ring_flash_attn_varlen_kvpacked_func( softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, return_attn_probs, @@ -308,6 +320,7 @@ def ring_flash_attn_varlen_func( softmax_scale=None, causal=False, window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, alibi_slopes=None, deterministic=False, return_attn_probs=False, @@ -323,6 +336,7 @@ def ring_flash_attn_varlen_func( softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, return_attn_probs, diff --git a/yunchang/ring/stripe_flash_attn.py b/yunchang/ring/stripe_flash_attn.py index 5682f37..6991aa5 100644 --- a/yunchang/ring/stripe_flash_attn.py +++ b/yunchang/ring/stripe_flash_attn.py @@ -13,6 +13,7 @@ def stripe_flash_attn_forward( dropout_p=0, causal=True, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=False, ): @@ -41,6 +42,7 @@ def stripe_flash_attn_forward( softmax_scale, causal=causal, window_size=window_size, + softcap=softcap, alibi_slopes=alibi_slopes, return_softmax=True and dropout_p > 0, ) @@ -54,6 +56,7 @@ def stripe_flash_attn_forward( softmax_scale, causal=causal, window_size=window_size, + softcap=softcap, alibi_slopes=alibi_slopes, return_softmax=True and dropout_p > 0, ) @@ -83,6 +86,7 @@ def stripe_flash_attn_backward( dropout_p=0, causal=True, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=False, ): @@ -122,6 +126,7 @@ def stripe_flash_attn_backward( softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, rng_state=None, @@ -144,6 +149,7 @@ def stripe_flash_attn_backward( softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, rng_state=None, @@ -195,6 +201,7 @@ def forward( softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, return_softmax, @@ -215,6 +222,7 @@ def forward( dropout_p=dropout_p, causal=causal, window_size=window_size, + softcap=softcap, alibi_slopes=alibi_slopes, deterministic=False, ) @@ -224,6 +232,7 @@ def forward( ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size + ctx.softcap = softcap ctx.alibi_slopes = alibi_slopes ctx.deterministic = deterministic ctx.group = group @@ -244,10 +253,11 @@ def backward(ctx, dout, *args): dropout_p=ctx.dropout_p, causal=ctx.causal, window_size=ctx.window_size, + softcap=ctx.softcap, alibi_slopes=ctx.alibi_slopes, deterministic=ctx.deterministic, ) - return dq, dk, dv, None, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None, None, None, None def stripe_flash_attn_qkvpacked_func( @@ -256,6 +266,7 @@ def stripe_flash_attn_qkvpacked_func( softmax_scale=None, causal=False, window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, alibi_slopes=None, deterministic=False, return_attn_probs=False, @@ -269,6 +280,7 @@ def stripe_flash_attn_qkvpacked_func( softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, return_attn_probs, @@ -283,6 +295,7 @@ def stripe_flash_attn_kvpacked_func( softmax_scale=None, causal=False, window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, alibi_slopes=None, deterministic=False, return_attn_probs=False, @@ -296,6 +309,7 @@ def stripe_flash_attn_kvpacked_func( softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, return_attn_probs, @@ -311,6 +325,7 @@ def stripe_flash_attn_func( softmax_scale=None, causal=False, window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, alibi_slopes=None, deterministic=False, return_attn_probs=False, @@ -324,6 +339,7 @@ def stripe_flash_attn_func( softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, return_attn_probs, diff --git a/yunchang/ring/zigzag_ring_flash_attn.py b/yunchang/ring/zigzag_ring_flash_attn.py index c753036..c58d20c 100644 --- a/yunchang/ring/zigzag_ring_flash_attn.py +++ b/yunchang/ring/zigzag_ring_flash_attn.py @@ -13,6 +13,7 @@ def zigzag_ring_flash_attn_forward( dropout_p=0, causal=True, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=False, ): @@ -35,6 +36,7 @@ def forward(q, k, v, causal): softmax_scale, causal=causal, window_size=window_size, + softcap=softcap, alibi_slopes=alibi_slopes, return_softmax=True and dropout_p > 0, ) @@ -86,6 +88,7 @@ def zigzag_ring_flash_attn_backward( dropout_p=0, causal=True, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=False, ): @@ -125,6 +128,7 @@ def backward(dout, q, k, v, out, softmax_lse, causal): softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, rng_state=None, @@ -188,6 +192,7 @@ def forward( softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, return_softmax, @@ -208,6 +213,7 @@ def forward( dropout_p=dropout_p, causal=causal, window_size=window_size, + softcap=softcap, alibi_slopes=alibi_slopes, deterministic=False, ) @@ -217,6 +223,7 @@ def forward( ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size + ctx.softcap = softcap ctx.alibi_slopes = alibi_slopes ctx.deterministic = deterministic ctx.group = group @@ -237,10 +244,11 @@ def backward(ctx, dout, *args): dropout_p=ctx.dropout_p, causal=ctx.causal, window_size=ctx.window_size, + softcap=ctx.softcap, alibi_slopes=ctx.alibi_slopes, deterministic=ctx.deterministic, ) - return dq, dk, dv, None, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None, None, None, None def zigzag_ring_flash_attn_qkvpacked_func( @@ -249,6 +257,7 @@ def zigzag_ring_flash_attn_qkvpacked_func( softmax_scale=None, causal=False, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=False, return_attn_probs=False, @@ -262,6 +271,7 @@ def zigzag_ring_flash_attn_qkvpacked_func( softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, return_attn_probs, @@ -276,6 +286,7 @@ def zigzag_ring_flash_attn_kvpacked_func( softmax_scale=None, causal=False, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=False, return_attn_probs=False, @@ -289,6 +300,7 @@ def zigzag_ring_flash_attn_kvpacked_func( softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, return_attn_probs, @@ -304,6 +316,7 @@ def zigzag_ring_flash_attn_func( softmax_scale=None, causal=False, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=False, return_attn_probs=False, @@ -317,6 +330,7 @@ def zigzag_ring_flash_attn_func( softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, return_attn_probs, diff --git a/yunchang/ring/zigzag_ring_flash_attn_varlen.py b/yunchang/ring/zigzag_ring_flash_attn_varlen.py index 7db7745..8a244f4 100644 --- a/yunchang/ring/zigzag_ring_flash_attn_varlen.py +++ b/yunchang/ring/zigzag_ring_flash_attn_varlen.py @@ -68,6 +68,7 @@ def zigzag_ring_flash_attn_varlen_forward( dropout_p=0, causal=True, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=False, ): @@ -103,6 +104,7 @@ def forward(q, k, v, causal): softmax_scale, causal=causal, window_size=window_size, + softcap=softcap, alibi_slopes=alibi_slopes, return_softmax=True and dropout_p > 0, ) @@ -166,6 +168,7 @@ def zigzag_ring_flash_attn_varlen_backward( dropout_p=0, causal=True, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=False, ): @@ -217,6 +220,7 @@ def backward(dout, q, k, v, out, softmax_lse, causal): softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, rng_state=None, @@ -281,6 +285,7 @@ def forward( softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, return_softmax, @@ -307,6 +312,7 @@ def forward( dropout_p=dropout_p, causal=causal, window_size=window_size, + softcap=softcap, alibi_slopes=alibi_slopes, deterministic=False, ) @@ -326,6 +332,7 @@ def forward( ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size + ctx.softcap = softcap ctx.alibi_slopes = alibi_slopes ctx.deterministic = deterministic ctx.group = group @@ -357,10 +364,11 @@ def backward(ctx, dout, *args): dropout_p=ctx.dropout_p, causal=ctx.causal, window_size=ctx.window_size, + softcap=ctx.softcap, alibi_slopes=ctx.alibi_slopes, deterministic=ctx.deterministic, ) - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None def zigzag_ring_flash_attn_varlen_qkvpacked_func( @@ -371,6 +379,7 @@ def zigzag_ring_flash_attn_varlen_qkvpacked_func( softmax_scale=None, causal=False, window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, alibi_slopes=None, deterministic=False, return_attn_probs=False, @@ -386,6 +395,7 @@ def zigzag_ring_flash_attn_varlen_qkvpacked_func( softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, return_attn_probs, @@ -402,6 +412,7 @@ def zigzag_ring_flash_attn_varlen_kvpacked_func( softmax_scale=None, causal=False, window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, alibi_slopes=None, deterministic=False, return_attn_probs=False, @@ -417,6 +428,7 @@ def zigzag_ring_flash_attn_varlen_kvpacked_func( softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, return_attn_probs, @@ -434,6 +446,7 @@ def zigzag_ring_flash_attn_varlen_func( softmax_scale=None, causal=False, window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, alibi_slopes=None, deterministic=False, return_attn_probs=False, @@ -449,6 +462,7 @@ def zigzag_ring_flash_attn_varlen_func( softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, return_attn_probs, diff --git a/yunchang/ulysses/attn_layer.py b/yunchang/ulysses/attn_layer.py index 03f0543..51ea44d 100644 --- a/yunchang/ulysses/attn_layer.py +++ b/yunchang/ulysses/attn_layer.py @@ -76,6 +76,7 @@ def forward( softmax_scale=None, causal=False, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=False, return_attn_probs=False, @@ -114,6 +115,7 @@ def forward( dropout_p=dropout_p, causal=causal, window_size=window_size, + softcap=softcap, alibi_slopes=alibi_slopes, deterministic=deterministic, return_attn_probs=return_attn_probs,