Allow n_heads specification in fp8 attention bench #139

Closed · wants to merge 1 commit
14 changes: 8 additions & 6 deletions tritonbench/operators/fp8_attention/operator.py
@@ -36,9 +36,9 @@ def parse_op_args(args: List[str]):
parser.add_argument(
"--embedding-dim",
type=int,
default=2048,
help="specify embedding dim, embedding dim = n_heads * head_dim",
)
parser.add_argument("--n-heads", type=int, default=48, help="Number of heads")
parser.add_argument("--d-head", type=int, default=64, help="specify head dimension")
parser.add_argument("--causal", action="store_true", help="enable causal")
return parser.parse_args(args)
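
To show the new flag in isolation, here is a standalone mirror of the parser above, trimmed to the size-related options. It is a sketch for illustration, not the benchmark's actual module; in particular, the `--embedding-dim` default is left as `None` here as an assumption, since the rendered diff does not make clear which default survives the change.

```python
import argparse
from typing import List


def parse_op_args(args: List[str]) -> argparse.Namespace:
    # Standalone mirror of the parser above, trimmed to the size-related flags.
    # NOTE: default=None for --embedding-dim is an assumption for illustration;
    # the rendered diff does not show which default the PR keeps.
    parser = argparse.ArgumentParser()
    parser.add_argument("--embedding-dim", type=int, default=None,
                        help="embedding dim = n_heads * head_dim")
    parser.add_argument("--n-heads", type=int, default=48, help="Number of heads")
    parser.add_argument("--d-head", type=int, default=64, help="specify head dimension")
    return parser.parse_args(args)


args = parse_op_args(["--n-heads", "16", "--d-head", "128"])
print(args.n_heads, args.d_head)  # -> 16 128
```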
@@ -56,6 +56,7 @@ def __init__(
self.BATCH = args.batch
self.SEQ_LEN = args.seq_len
self.embedding_dim = args.embedding_dim
self.H = args.n_heads
self.D_HEAD = args.d_head
self.causal = args.causal
# We always turn on causal for backward
@@ -65,6 +66,11 @@ def __init__(
self.requires_grad = not self.tb_args.mode == "fwd_no_grad"
self.sm_scale = 1.3

if self.embedding_dim and self.H != self.embedding_dim // self.D_HEAD:
raise ValueError(
f"embedding_dim {self.embedding_dim} is inconsistent with n_heads {self.H}. embedding_dim = n_heads * d_head "
)

def colfax_preprocess(self, q, k, v):
# colfax expects q,k: BATCH, N_CTX, H, D_HEAD and v: BATCH, D_HEAD, H, N_CTX
# passed-in: BATCH, H, N_CTX, D_HEAD
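
The consistency check added in `__init__` only fires when `--embedding-dim` is supplied alongside an incompatible `--n-heads`/`--d-head` pair. Below is a minimal standalone sketch of the same rule with concrete numbers; the values are arbitrary examples, not benchmark defaults.

```python
def check_heads(embedding_dim, n_heads, d_head):
    # Same rule as the check above: only enforced when embedding_dim is set (truthy).
    if embedding_dim and n_heads != embedding_dim // d_head:
        raise ValueError(
            f"embedding_dim {embedding_dim} is inconsistent with n_heads {n_heads}: "
            "expected embedding_dim = n_heads * d_head"
        )


check_heads(2048, 32, 64)      # ok: 2048 // 64 == 32
check_heads(None, 48, 64)      # ok: no embedding_dim given, n_heads is used as-is
try:
    check_heads(2048, 48, 64)  # 2048 // 64 == 32, not 48
except ValueError as e:
    print(e)
```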
@@ -133,20 +139,16 @@ def triton_flash_v2_tma(
triton_q, triton_k, triton_v, self.causal, self.sm_scale, "tma"
)

def get_x_val(self, _example_inputs) -> Tuple[int, int, int, int]:
H = self.embedding_dim // self.D_HEAD
return (self.BATCH, self.N_CTX, H, self.D_HEAD)

def get_input_iter(self) -> Generator:
# The non-fp8 FA varies N_CTX and fixes other variables. Let's do the same for fp8.
# The autotune config only depends on N_CTX in OSS Triton tutorial.

BATCH = self.BATCH
D_HEAD = self.D_HEAD
SEQ_LEN_LOG2 = 7
H = self.H

def get_ctx_vals():
H = self.embedding_dim // D_HEAD
if self.SEQ_LEN:
yield (BATCH, H, self.SEQ_LEN, self.D_HEAD)
return
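
For context, the shape tuple produced on the fixed `--seq-len` path is unchanged; only the source of `H` moves from a derived value to the new flag. A small before/after sketch with arbitrary example values (BATCH=4 and SEQ_LEN=512 are illustrative, not benchmark defaults):

```python
BATCH, SEQ_LEN, D_HEAD = 4, 512, 64   # illustrative values, not defaults
embedding_dim, n_heads = 2048, 32

h_before = embedding_dim // D_HEAD    # old behavior: H derived from embedding dim
h_after = n_heads                     # new behavior: H taken directly from --n-heads

# Both produce the same (BATCH, H, SEQ_LEN, D_HEAD) tuple when the values agree.
print((BATCH, h_before, SEQ_LEN, D_HEAD))  # (4, 32, 512, 64)
print((BATCH, h_after, SEQ_LEN, D_HEAD))   # (4, 32, 512, 64)
```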