Commit 29dcf4c: Update benchmarks
yzhangcs committed Jan 27, 2024 (1 parent: 760a0a2)
Showing 1 changed file with 72 additions and 0 deletions.

benchmarks/triton/benchmark_abc.py (new file, +72 lines)

# -*- coding: utf-8 -*-

import torch
import triton
from torch.nn import functional as F

from fla.ops.abc import chunk_abc
from fla.ops.gla import chunk_gla
from fla.ops.retention import chunk_retention

try:
    from flash_attn import flash_attn_func
    HAS_FLASH = True
except BaseException:
    HAS_FLASH = False


@triton.testing.perf_report(
    triton.testing.Benchmark(
        # argument names to use as an x-axis for the plot
        x_names=['seq_len'],
        # different possible values for `x_names`
        x_vals=[128 * 2 ** i for i in range(0, 8)],
        # argument name whose value corresponds to a different line in the plot
        line_arg='provider',
        # possible values for `line_arg`; only benchmark flash_bwd when flash-attn is installed
        line_vals=['abc', 'gla', 'abc_bwd', 'gla_bwd', 'retention_bwd'] + (['flash_bwd'] if HAS_FLASH else []),
        # label name for the lines
        line_names=['abc', 'gla', 'abc_bwd', 'gla_bwd', 'retention_bwd'] + (['flash_bwd'] if HAS_FLASH else []),
        # line styles
        styles=[('green', '-'), ('blue', '--'), ('red', '-.'),
                ('cyan', ':'), ('yellow', 'dotted'), ('black', ':')],
        ylabel="Execution Time (ms)",  # label name for the y-axis
        # name for the plot. Used also as a file name for saving the plot.
        plot_name="Performance",
        args={},
    )
)
def benchmark(seq_len, provider):
    device = 'cuda'
    dtype = torch.bfloat16
    requires_grad = True
    batch_size, n_heads, d_head, n_slots = 16, 8, 128, 64

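    # random query/key/value inputs of shape (batch_size, n_heads, seq_len, d_head)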
    q = torch.randn(batch_size, n_heads, seq_len, d_head, device=device, requires_grad=requires_grad, dtype=dtype)
    k = torch.randn(batch_size, n_heads, seq_len, d_head, device=device, requires_grad=requires_grad, dtype=dtype)
    v = torch.randn(batch_size, n_heads, seq_len, d_head, device=device, requires_grad=requires_grad, dtype=dtype)
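    # forget gates for GLA, kept in log space (via logsigmoid) and clamped for numerical stability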
    g = F.logsigmoid(torch.randn(batch_size, n_heads, seq_len, d_head, device=device, dtype=dtype))
    g = g.clamp_min(-5).requires_grad_(requires_grad)
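    # slot scores used by the ABC kernel, one set for the key side (sk) and one for the value side (sv)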
    sk = torch.randn(batch_size, n_heads, seq_len, n_slots, device=device, requires_grad=requires_grad, dtype=dtype)
    sv = torch.randn(batch_size, n_heads, seq_len, n_slots, device=device, requires_grad=requires_grad, dtype=dtype)

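    # upstream gradient fed to backward() in the *_bwd benchmarks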
    do = torch.ones_like(v, dtype=dtype)

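    # do_bench reports the median runtime, with the 20th/80th percentiles as lower/upper bounds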
    quantiles = [0.5, 0.2, 0.8]
    if provider == 'abc':
        results = triton.testing.do_bench(lambda: chunk_abc(q, k, v, sk, sv), quantiles=quantiles)
    elif provider == 'gla':
        results = triton.testing.do_bench(lambda: chunk_gla(q, k, v, g, None), quantiles=quantiles)
    elif provider == 'abc_bwd':
        results = triton.testing.do_bench(lambda: chunk_abc(q, k, v, sk, sv).backward(do), quantiles=quantiles)
    elif provider == 'gla_bwd':
        results = triton.testing.do_bench(lambda: chunk_gla(q, k, v, g, None).backward(do), quantiles=quantiles)
    elif provider == 'retention_bwd':
        results = triton.testing.do_bench(lambda: chunk_retention(q, k, v).backward(do), quantiles=quantiles)
    elif provider == 'flash_bwd':
        # flash_attn_func expects (batch_size, seq_len, n_heads, d_head) inputs, hence the transposes
        results = triton.testing.do_bench(lambda: flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=True).backward(do.transpose(1, 2)), quantiles=quantiles)
    return results


if __name__ == '__main__':
    benchmark.run(print_data=True, save_path='.')
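
Running the script on a CUDA GPU prints the timing table and, through Triton's perf_report, should also save Performance.png and Performance.csv to the current directory (save_path='.'); the flash_bwd curve is measured only when the optional flash-attn package is installed.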
