diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py
index a4a619e63b8c..7b7ecdbd8a4b 100644
--- a/python/perf-kernels/flash-attention.py
+++ b/python/perf-kernels/flash-attention.py
@@ -1464,6 +1464,16 @@ def varlen_benchmark_configs():
                ]
     return configs
 
+def nonvarlen_backward_benchmark_configs():
+    configs=[(16, 16, 16, 1024, 1024),
+             (8, 16, 16, 2048, 2048),
+             (4, 16, 16, 4096, 4096),
+             (2, 16, 16, 8192, 8192),
+             (1, 16, 16, 16384, 16384),
+             (2, 48, 48, 1024, 1024),
+             ]
+    return configs
+
 
 def run_benchmark(custom):
     args = parse_args()
@@ -1471,7 +1481,7 @@ def run_benchmark(custom):
     hk = args.hq if not args.hk else args.hk
     sk = args.sq if not args.sk else args.sk
     head_size = 128 if not args.d else args.d
-    mode = 'fwd'
+    mode = args.direction
     x_names=['BATCH', 'HQ', 'HK', 'N_CTX_Q', 'N_CTX_K']
     causal = args.causal
     varlen = args.varlen
@@ -1481,6 +1491,8 @@ def run_benchmark(custom):
     else:
         if varlen:
             x_vals_list = varlen_benchmark_configs()
+        elif mode == 'bwd':
+            x_vals_list = nonvarlen_backward_benchmark_configs()
         else:
             x_vals_list = nonvarlen_benchmark_configs()
     print_time = args.return_time
@@ -1515,11 +1527,7 @@ def bench_flash_attention(
     # else:
     #     bias = None
     bias = None
-
-    # Bwd pass only supports causal=True right now
-    if mode == 'bwd':
-        causal = True
-
+
     flops_per_matmul = 0
     if varlen:
         q, k, v, input_metadata = varlen_input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype)
@@ -1567,6 +1575,7 @@ def parse_args():
     parser.add_argument("-causal", action='store_true', default=False)
     parser.add_argument("-varlen", action='store_true', default=False)
     parser.add_argument("-dtype", default='fp16')
+    parser.add_argument("-direction", default='fwd')
     parser.add_argument("-return_time", action='store_true', default=False)
     return parser.parse_args()
 