Skip to content

Commit

Permalink
Causal benchmarks with limited inputs for backwards
Browse files Browse the repository at this point in the history
  • Loading branch information
Joseph Groenenboom committed May 3, 2024
1 parent 1a323b9 commit 375fa6c
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions python/perf-kernels/flash-attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -1464,14 +1464,24 @@ def varlen_benchmark_configs():
]
return configs

def nonvarlen_backward_benchmark_configs():
    """Return the fixed-length shapes used when benchmarking the backward pass.

    Each entry is a 5-tuple of integers. run_benchmark assigns this list to
    x_vals_list alongside
    x_names = ['BATCH', 'HQ', 'HK', 'N_CTX_Q', 'N_CTX_K'], so the fields are
    presumably in that order -- confirm against the benchmark decorator.

    In the first five entries the first field times the fourth is always
    16384, with 16/16 in the second and third fields; the last entry uses
    48/48 there with 1024/1024 in the last two fields.

    A new list is built on every call, so callers may mutate the result.
    """
    return [
        (16, 16, 16, 1024, 1024),
        (8, 16, 16, 2048, 2048),
        (4, 16, 16, 4096, 4096),
        (2, 16, 16, 8192, 8192),
        (1, 16, 16, 16384, 16384),
        (2, 48, 48, 1024, 1024),
    ]

def run_benchmark(custom):

args = parse_args()
dtype = arg_to_torch_dtype[args.dtype]
hk = args.hq if not args.hk else args.hk
sk = args.sq if not args.sk else args.sk
head_size = 128 if not args.d else args.d
mode = 'fwd'
mode = args.direction
x_names=['BATCH', 'HQ', 'HK', 'N_CTX_Q', 'N_CTX_K']
causal = args.causal
varlen = args.varlen
Expand All @@ -1481,6 +1491,8 @@ def run_benchmark(custom):
else:
if varlen:
x_vals_list = varlen_benchmark_configs()
elif mode == 'bwd':
x_vals_list = nonvarlen_backward_benchmark_configs()
else:
x_vals_list = nonvarlen_benchmark_configs()
print_time = args.return_time
Expand Down Expand Up @@ -1515,11 +1527,7 @@ def bench_flash_attention(
# else:
# bias = None
bias = None

# Bwd pass only supports causal=True right now
if mode == 'bwd':
causal = True


flops_per_matmul = 0
if varlen:
q, k, v, input_metadata = varlen_input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype)
Expand Down Expand Up @@ -1567,6 +1575,7 @@ def parse_args():
parser.add_argument("-causal", action='store_true', default=False)
parser.add_argument("-varlen", action='store_true', default=False)
parser.add_argument("-dtype", default='fp16')
parser.add_argument("-direction", default='fwd')
parser.add_argument("-return_time", action='store_true', default=False)
return parser.parse_args()

Expand Down

0 comments on commit 375fa6c

Please sign in to comment.