Allow n_heads specification in fp8 attention bench
Summary:
Currently the fp8 attention benchmark only allows specifying `embedding_dim` and the head dimension (`d_head`), which leads to inconsistencies with the bf16 benchmark.

- Adding support for providing `n_heads` as an input, and raising a `ValueError` in case of a mismatch between the provided inputs (see the sketch below the summary).
- Dropping the default value for `embedding_dim`.

Differential Revision: D68531784
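
For illustration, a minimal self-contained sketch of the argument handling and consistency check described above. This is not the benchmark's actual module: `check_consistency` is a hypothetical helper, only the flags touched by this commit are included, and the names and defaults mirror the diff below.

import argparse
from typing import List

def parse_op_args(args: List[str]) -> argparse.Namespace:
    # Only the flags relevant to this commit: --embedding-dim no longer has
    # a default, and --n-heads is newly accepted.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--embedding-dim",
        type=int,
        help="specify embedding dim, embedding dim = n_heads * head_dim",
    )
    parser.add_argument("--n-heads", type=int, default=48, help="Number of heads")
    parser.add_argument("--d-head", type=int, default=64, help="specify head dimension")
    return parser.parse_args(args)

def check_consistency(embedding_dim, n_heads, d_head) -> None:
    # Same condition as the check added in __init__: if embedding_dim is
    # provided, it must equal n_heads * d_head.
    if embedding_dim and n_heads != embedding_dim // d_head:
        raise ValueError(
            f"embedding_dim {embedding_dim} is inconsistent with n_heads "
            f"{n_heads}. embedding_dim = n_heads * d_head"
        )

args = parse_op_args(["--embedding-dim", "2048", "--n-heads", "32"])
check_consistency(args.embedding_dim, args.n_heads, args.d_head)  # 2048 == 32 * 64, passes
# check_consistency(2048, 48, 64) would raise: 48 * 64 == 3072 != 2048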
mandroid6 authored and facebook-github-bot committed Jan 23, 2025
1 parent 9dc83c9 commit fdf3a9c
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions tritonbench/operators/fp8_attention/operator.py
@@ -36,9 +36,9 @@ def parse_op_args(args: List[str]):
parser.add_argument(
"--embedding-dim",
type=int,
default=2048,
help="specify embedding dim, embedding dim = n_heads * head_dim",
)
parser.add_argument("--n-heads", type=int, default=48, help="Number of heads")
parser.add_argument("--d-head", type=int, default=64, help="specify head dimension")
parser.add_argument("--causal", action="store_true", help="enable causal")
return parser.parse_args(args)
@@ -56,6 +56,7 @@ def __init__(
self.BATCH = args.batch
self.SEQ_LEN = args.seq_len
self.embedding_dim = args.embedding_dim
self.H = args.n_heads
self.D_HEAD = args.d_head
self.causal = args.causal
# We always turn on causal for backward
@@ -65,6 +66,11 @@ def __init__(
self.requires_grad = not self.tb_args.mode == "fwd_no_grad"
self.sm_scale = 1.3

if self.embedding_dim and self.H != self.embedding_dim // self.D_HEAD:
raise ValueError(
f"embedding_dim {self.embedding_dim} is inconsistent with n_heads {self.H}. embedding_dim = n_heads * d_head "
)

def colfax_preprocess(self, q, k, v):
# colfax expects q,k: BATCH, N_CTX, H, D_HEAD and v: BATCH, D_HEAD, H, N_CTX
# passed-in: BATCH, H, N_CTX, D_HEAD
@@ -133,20 +139,16 @@ def triton_flash_v2_tma(
triton_q, triton_k, triton_v, self.causal, self.sm_scale, "tma"
)

def get_x_val(self, _example_inputs) -> Tuple[int, int, int, int]:
H = self.embedding_dim // self.D_HEAD
return (self.BATCH, self.N_CTX, H, self.D_HEAD)

def get_input_iter(self) -> Generator:
# The non-fp8 FA varies N_CTX and fixes other variables. Let's do the same for fp8.
# The autotune config only depends on N_CTX in OSS Triton tutorial.

BATCH = self.BATCH
D_HEAD = self.D_HEAD
SEQ_LEN_LOG2 = 7
H = self.H

def get_ctx_vals():
H = self.embedding_dim // D_HEAD
if self.SEQ_LEN:
yield (BATCH, H, self.SEQ_LEN, self.D_HEAD)
return
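
For reference, a hedged sketch of how get_input_iter now derives input shapes from the flags. H comes directly from --n-heads instead of embedding_dim // d_head; the sweep used when --seq-len is not given lies outside the visible hunk, so the range below is an assumption based on SEQ_LEN_LOG2, not taken from this diff.

def get_ctx_vals(batch, n_heads, d_head, seq_len=None, seq_len_log2=7, max_log2=14):
    # With an explicit --seq-len, a single shape is produced.
    if seq_len:
        yield (batch, n_heads, seq_len, d_head)
        return
    # Assumed sweep: N_CTX in powers of two starting at 2**seq_len_log2 (= 128);
    # the real upper bound is not shown in this diff.
    for i in range(seq_len_log2, max_log2):
        yield (batch, n_heads, 2**i, d_head)

# With the defaults --n-heads 48 --d-head 64 and an explicit --seq-len 4096:
print(list(get_ctx_vals(batch=4, n_heads=48, d_head=64, seq_len=4096)))
# [(4, 48, 4096, 64)]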
