reduce for kqmax_new_j is unnecessary #1032
Merged
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
using this patch,the performance will increase about 1%-2% ,testing in A800
test-backend-ops -o FLASH_ATTN_EXT -b CUDA0 perf
i am do some trick to letting nb=1,2,3,7 will using flash_attn_vec_ext_f16(because A800 is capable for wmma) just for the eval performance
origin:
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 90112 runs - 11.51 us/run - 4.19 MFLOP/run - 364.41 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=2,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 73728 runs - 14.38 us/run - 8.39 MFLOP/run - 583.51 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 47688 runs - 21.18 us/run - 12.58 MFLOP/run - 594.02 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=7,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 30654 runs - 34.17 us/run - 29.36 MFLOP/run - 859.30 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 57344 runs - 19.03 us/run - 8.39 MFLOP/run - 440.74 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=2,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 41727 runs - 25.04 us/run - 16.78 MFLOP/run - 670.07 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 27818 runs - 36.84 us/run - 25.17 MFLOP/run - 683.06 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=7,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 15327 runs - 67.79 us/run - 58.72 MFLOP/run - 866.21 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=1,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 81920 runs - 12.68 us/run - 4.19 MFLOP/run - 330.79 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=2,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 65536 runs - 16.16 us/run - 8.39 MFLOP/run - 519.10 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 47688 runs - 22.11 us/run - 12.58 MFLOP/run - 569.02 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=7,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 27248 runs - 37.08 us/run - 29.36 MFLOP/run - 791.89 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=1,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 65536 runs - 16.47 us/run - 8.39 MFLOP/run - 509.47 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=2,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 47688 runs - 21.93 us/run - 16.78 MFLOP/run - 765.03 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 31792 runs - 35.48 us/run - 25.17 MFLOP/run - 709.22 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=7,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 17030 runs - 60.34 us/run - 58.72 MFLOP/run - 973.20 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=1,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 90112 runs - 11.23 us/run - 4.19 MFLOP/run - 373.53 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=2,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 73728 runs - 13.85 us/run - 8.39 MFLOP/run - 605.60 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 55636 runs - 20.40 us/run - 12.58 MFLOP/run - 616.89 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=7,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 27248 runs - 40.79 us/run - 29.36 MFLOP/run - 719.80 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=1,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 57344 runs - 18.64 us/run - 8.39 MFLOP/run - 450.01 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=2,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 47688 runs - 23.73 us/run - 16.78 MFLOP/run - 707.06 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 31792 runs - 34.49 us/run - 25.17 MFLOP/run - 729.75 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=7,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 15327 runs - 66.55 us/run - 58.72 MFLOP/run - 882.32 GFLOPS
apply this patch:
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 90112 runs - 11.14 us/run - 4.19 MFLOP/run - 376.67 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=2,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 73728 runs - 14.02 us/run - 8.39 MFLOP/run - 598.41 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 55636 runs - 20.66 us/run - 12.58 MFLOP/run - 609.01 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=7,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 30654 runs - 33.68 us/run - 29.36 MFLOP/run - 871.69 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 57344 runs - 18.82 us/run - 8.39 MFLOP/run - 445.67 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=2,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 41727 runs - 24.57 us/run - 16.78 MFLOP/run - 682.88 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 27818 runs - 36.41 us/run - 25.17 MFLOP/run - 691.19 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=7,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 15327 runs - 66.20 us/run - 58.72 MFLOP/run - 887.06 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=1,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 81920 runs - 12.70 us/run - 4.19 MFLOP/run - 330.27 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=2,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 65536 runs - 15.85 us/run - 8.39 MFLOP/run - 529.23 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 47688 runs - 21.69 us/run - 12.58 MFLOP/run - 580.05 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=7,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 27248 runs - 36.73 us/run - 29.36 MFLOP/run - 799.45 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=1,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 65536 runs - 16.19 us/run - 8.39 MFLOP/run - 518.10 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=2,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 47688 runs - 21.47 us/run - 16.78 MFLOP/run - 781.60 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 31792 runs - 34.47 us/run - 25.17 MFLOP/run - 730.00 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=7,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 17030 runs - 59.55 us/run - 58.72 MFLOP/run - 986.15 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=1,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 98304 runs - 10.93 us/run - 4.19 MFLOP/run - 383.64 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=2,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 81920 runs - 13.52 us/run - 8.39 MFLOP/run - 620.46 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 55636 runs - 19.85 us/run - 12.58 MFLOP/run - 633.98 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=7,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 27248 runs - 40.10 us/run - 29.36 MFLOP/run - 732.12 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=1,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 57344 runs - 18.36 us/run - 8.39 MFLOP/run - 456.96 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=2,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 47688 runs - 23.40 us/run - 16.78 MFLOP/run - 716.87 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 31792 runs - 33.66 us/run - 25.17 MFLOP/run - 747.69 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=7,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 15327 runs - 65.55 us/run - 58.72 MFLOP/run - 895.82 GFLOPS