diff --git a/tritonbench/operators/fp8_attention/operator.py b/tritonbench/operators/fp8_attention/operator.py
index f4eed360..37583f91 100644
--- a/tritonbench/operators/fp8_attention/operator.py
+++ b/tritonbench/operators/fp8_attention/operator.py
@@ -36,9 +36,9 @@ def parse_op_args(args: List[str]):
     parser.add_argument(
         "--embedding-dim",
         type=int,
-        default=2048,
         help="specify embedding dim, embedding dim = n_heads * head_dim",
     )
+    parser.add_argument("--n-heads", type=int, default=48, help="specify number of heads")
     parser.add_argument("--d-head", type=int, default=64, help="specify head dimension")
     parser.add_argument("--causal", action="store_true", help="enable causal")
     return parser.parse_args(args)
@@ -56,6 +56,7 @@ def __init__(
         self.BATCH = args.batch
         self.SEQ_LEN = args.seq_len
         self.embedding_dim = args.embedding_dim
+        self.H = args.n_heads
         self.D_HEAD = args.d_head
         self.causal = args.causal
         # We always turn on causal for backward
@@ -65,6 +66,11 @@ def __init__(
         self.requires_grad = not self.tb_args.mode == "fwd_no_grad"
         self.sm_scale = 1.3
 
+        if self.embedding_dim and self.H != self.embedding_dim // self.D_HEAD:
+            raise ValueError(
+                f"embedding_dim {self.embedding_dim} is inconsistent with n_heads {self.H}; embedding_dim must equal n_heads * d_head"
+            )
+
     def colfax_preprocess(self, q, k, v):
         # colfax expects q,k: BATCH, N_CTX, H, D_HEAD and v: BATCH, D_HEAD, H, N_CTX
         # passed-in: BATCH, H, N_CTX, D_HEAD
@@ -133,10 +139,6 @@ def triton_flash_v2_tma(
             triton_q, triton_k, triton_v, self.causal, self.sm_scale, "tma"
         )
 
-    def get_x_val(self, _example_inputs) -> Tuple[int, int, int, int]:
-        H = self.embedding_dim // self.D_HEAD
-        return (self.BATCH, self.N_CTX, H, self.D_HEAD)
-
     def get_input_iter(self) -> Generator:
         # The non-fp8 FA varies N_CTX and fixes other variables. Let's do the same for fp8.
         # The autotune config only depends on N_CTX in OSS Triton tutorial.
@@ -144,9 +146,9 @@ def get_input_iter(self) -> Generator:
         BATCH = self.BATCH
         D_HEAD = self.D_HEAD
         SEQ_LEN_LOG2 = 7
+        H = self.H
 
         def get_ctx_vals():
-            H = self.embedding_dim // D_HEAD
             if self.SEQ_LEN:
                 yield (BATCH, H, self.SEQ_LEN, self.D_HEAD)
                 return
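
Note on the new flag semantics: --n-heads (default 48) is now the source of truth for H, and --embedding-dim no longer has a default; when it is passed, it must agree with n_heads * d_head or __init__ raises. A minimal, self-contained sketch of that behavior follows; the argparse wiring mirrors the diff, but the standalone parse() harness and the example calls are illustrative only, not part of the change:

    # Sketch of the validated flag parsing added by this diff (harness is hypothetical).
    import argparse
    from typing import List


    def parse(argv: List[str]) -> argparse.Namespace:
        parser = argparse.ArgumentParser()
        parser.add_argument("--embedding-dim", type=int)  # no longer defaults to 2048
        parser.add_argument("--n-heads", type=int, default=48)
        parser.add_argument("--d-head", type=int, default=64)
        ns = parser.parse_args(argv)
        # Mirrors the consistency check added in __init__:
        # embedding_dim, if given, must equal n_heads * d_head.
        if ns.embedding_dim and ns.n_heads != ns.embedding_dim // ns.d_head:
            raise ValueError(
                f"embedding_dim {ns.embedding_dim} is inconsistent with n_heads {ns.n_heads}"
            )
        return ns


    print(parse([]))                           # H = 48 from the default
    print(parse(["--embedding-dim", "3072"]))  # OK: 3072 // 64 == 48
    print(parse(["--embedding-dim", "2048"]))  # raises: 2048 // 64 == 32 != 48

Under the old behavior the last invocation silently gave H = 32; with this diff it fails fast instead, and callers who want 2048 must also pass a matching --n-heads 32.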