Disable tma on AMD by default on fp8_gemm_rowwise
Summary: Updates fp8_gemm_rowwise to set the TMA default based on whether we are running on NVIDIA or AMD hardware.

Reviewed By: danzimm

Differential Revision: D69680948

fbshipit-source-id: 4628b9a0a0a30d06e8f695f05e8810fdcd291110
Nick Riasanovsky authored and facebook-github-bot committed Feb 19, 2025
1 parent 592d65f commit 79d2949
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion tritonbench/operators/fp8_gemm_rowwise/operator.py
@@ -26,14 +26,22 @@ def parse_args(args: List[str]) -> argparse.Namespace:
     parser.add_argument(
         "--no_fp8_fast_accum", dest="fp8_fast_accum", action="store_false"
     )
-    parser.add_argument("--no_use_tma", dest="use_tma", action="store_false")
+    parser.add_argument(
+        "--no_use_tma", dest="use_tma", default=None, action="store_false"
+    )
+    parser.add_argument("--use_tma", dest="use_tma", action="store_true")
     parser.add_argument(
         "--no_use_persistent",
         dest="no_use_persistent",
         action="store_true",
     )
     parser.add_argument("--warp_specialization", action="store_true")
     parsed_args = parser.parse_args(args)
+    if parsed_args.use_tma is None:
+        # Default to True for CUDA, False for ROCm
+        parsed_args.use_tma = True if torch.version.hip is None else False
+    if torch.version.hip is not None and parsed_args.use_tma:
+        raise RuntimeError("TMA is not supported on ROCm")
     return parsed_args


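The tri-state flag pattern used in this diff (both `--use_tma` and `--no_use_tma` write to the same `dest`, with a `None` default resolved after parsing based on the backend) can be sketched in isolation. This is a minimal, hypothetical reconstruction: `is_rocm` here stands in for the `torch.version.hip is not None` check in the real operator, and the parser is stripped down to just the TMA flags.

```python
import argparse
from typing import List


def parse_args(args: List[str], is_rocm: bool = False) -> argparse.Namespace:
    """Sketch of the tri-state --use_tma/--no_use_tma default resolution."""
    parser = argparse.ArgumentParser()
    # Both flags share dest="use_tma". The explicit default=None lets us
    # detect "user did not pass either flag" after parsing.
    parser.add_argument(
        "--no_use_tma", dest="use_tma", default=None, action="store_false"
    )
    parser.add_argument("--use_tma", dest="use_tma", action="store_true")
    parsed = parser.parse_args(args)
    if parsed.use_tma is None:
        # Backend-dependent default: TMA on for CUDA/NVIDIA, off for ROCm/AMD.
        parsed.use_tma = not is_rocm
    if is_rocm and parsed.use_tma:
        # Explicitly opting in to TMA on ROCm is an error, as in the commit.
        raise RuntimeError("TMA is not supported on ROCm")
    return parsed
```

With no flags passed, `parse_args([], is_rocm=False).use_tma` resolves to `True` and `parse_args([], is_rocm=True).use_tma` to `False`; passing `--use_tma` together with `is_rocm=True` raises. Because argparse only applies the first action's default for a shared `dest`, the `store_true` action's implicit `False` default does not clobber the `None` sentinel.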
