diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index 5d2e13c45b97..4650950954c2 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ -30,9 +30,9 @@ namespace mlir::triton { const std::set ENV_VARS = { - "DISABLE_MMA_V3", "TRITON_DISABLE_LINE_INFO", "DISABLE_FAST_REDUCTION", - "ENABLE_TMA", "MLIR_ENABLE_DUMP", "LLVM_IR_ENABLE_DUMP", - "AMDGCN_ENABLE_DUMP", "TRUNCATE_F32_TO_BF16"}; + "DISABLE_MMA_V3", "TRITON_DISABLE_LINE_INFO", "DISABLE_FAST_REDUCTION", + "ENABLE_TMA", "MLIR_ENABLE_DUMP", "LLVM_IR_ENABLE_DUMP", + "AMDGCN_ENABLE_DUMP", "TRUNCATE_F32_TO_BF16", "TRITON_LLVM_FLAG"}; namespace tools { diff --git a/lib/Target/HSACO/HSACOTranslation.cpp b/lib/Target/HSACO/HSACOTranslation.cpp index 69e755ec3f5a..9d5de2514e51 100644 --- a/lib/Target/HSACO/HSACOTranslation.cpp +++ b/lib/Target/HSACO/HSACOTranslation.cpp @@ -135,6 +135,14 @@ std::string generate_amdgcn_assembly(llvm::Module *module, if (machine == nullptr) return ""; + std::string llvm_flag = mlir::triton::tools::getenv("TRITON_LLVM_FLAG"); + if (!llvm_flag.empty()) { + std::vector args; + args.push_back((char *)("triton")); + args.push_back((char *)(llvm_flag.c_str())); + llvm::cl::ParseCommandLineOptions(args.size(), &args[0]); + } + llvm::SmallVector buffer; llvm::legacy::PassManager pass; llvm::raw_svector_ostream stream(buffer); diff --git a/python/perf-kernels/06-fused-attention-fwd-transV.py b/python/perf-kernels/06-fused-attention-fwd-transV.py index 35a6da764746..64eb5e07c532 100644 --- a/python/perf-kernels/06-fused-attention-fwd-transV.py +++ b/python/perf-kernels/06-fused-attention-fwd-transV.py @@ -163,10 +163,9 @@ def forward(ctx, q, k, v, sm_scale): kpack = 1 else: ## D_HEAD = 128 - ## For fp16, pick BLOCK_M=256, num_warps=8 - ## For fp8, pick BLOCK_M=128, num_warps=4 + ## Tuning for MI300 ## TODO (zhanglx): add tuning infra for FA - BLOCK_M = 128 if TORCH_HAS_FP8E4 and q.dtype == torch.float8_e4m3fnuz else 256 + BLOCK_M = 128 BLOCK_N = 128 waves_per_eu = 2 num_warps = BLOCK_M // 32