Skip to content

Commit

Permalink
[tuner] dump outputs to tune_gemm/output (#663)
Browse files (browse the repository at this point in the history)
  • Loading branch information
makslevental authored Nov 13, 2024
1 parent 279cfa7 commit db2ca01
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 21 deletions.
13 changes: 8 additions & 5 deletions python/perf-kernels/tools/tune_gemm/tune_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import sys
import yaml
import os
from pathlib import Path
import glob

import torch
Expand All @@ -28,6 +29,7 @@
get_filename_compile_driver,
get_filename_myKernels,
get_filename_profile_driver,
get_output_dir,
name_to_tl_types,
patch_triton_compiler,
run_bash_command,
Expand Down Expand Up @@ -197,8 +199,9 @@ def profile_batch_kernels(M, N, K, gpuid, gpus, jobs, verbose):
kernel_name = get_filename_profile_driver(M, N, K, jobId)
if verbose:
print(f"profiling {kernel_name} on GPU {gpuid}")
here = Path(__file__).parent
run_bash_command_wrapper(
f"rocprof --stats -o results_{jobId}.csv python {get_filename_profile_driver(M, N, K, jobId)}",
f"PYTHONPATH={here} rocprof --stats -o {get_output_dir()}/results_{jobId}.csv python {get_filename_profile_driver(M, N, K, jobId)}",
capture=(verbose < 2))
jobId += ngpus

Expand All @@ -213,8 +216,8 @@ def tune_gemm_config(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type
# Generate kernel out of all configs
fname = generate_compile_driver(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, configs,
rotating_buffer_size, bias_size)

run_bash_command(f"python {fname} -n {num_threads}", capture=(verbose < 2))
here = Path(__file__).parent
run_bash_command(f"PYTHONPATH={here} python {fname} -n {num_threads}", capture=(verbose < 2))
compile_end = datetime.now()
compile_time = compile_end - start_time
if verbose:
Expand Down Expand Up @@ -245,7 +248,7 @@ def tune_gemm_config(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type
thread_pool = multiprocessing.Pool(processes=num_threads)
tasks = []
idx = 0
df_prof = [pd.read_csv(f"results_{i}.csv") for i in range(jobs)]
df_prof = [pd.read_csv(f"{get_output_dir()}/results_{i}.csv") for i in range(jobs)]
for config in configs:
file_idx = idx % jobs
tasks += [thread_pool.apply_async(extract_kernel_time, args=(M, N, K, config, df_prof[file_idx]))]
Expand Down Expand Up @@ -457,7 +460,7 @@ def parse_args():
args = parser.parse_args()
if not args.o:
if args.benchmark:
args.o = "benchmarking_results.csv"
args.o = f"{get_output_dir()}/benchmarking_results.csv"
else:
args.o = get_default_tuning_result_filename()

Expand Down
33 changes: 17 additions & 16 deletions python/perf-kernels/tools/tune_gemm/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import triton
import triton.language as tl

import os
from pathlib import Path
import subprocess
from datetime import datetime

Expand Down Expand Up @@ -48,25 +48,27 @@ def run_bash_command(commandstring, capture=True):
return None


def get_output_dir():
    """Return the directory where generated files and results are written.

    The directory is named ``output`` and lives next to the directory that
    contains this module (i.e. two levels above this file). It is created
    on first use.

    Returns:
        pathlib.Path: path to the existing ``output`` directory.
    """
    output_dir = Path(__file__).parent.parent / "output"
    # exist_ok avoids the check-then-create race: a separate exists() test
    # followed by mkdir() raises FileExistsError if another process creates
    # the directory in between the two calls.
    output_dir.mkdir(parents=True, exist_ok=True)
    return output_dir


def get_filename_myKernels():
path = os.path.dirname(os.path.abspath(__file__))
return f"{path}/../myKernels.py"
return f"{get_output_dir()}/myKernels.py"


def get_filename_without_extension(file_path):
base_name = os.path.basename(file_path)
file_name, _ = os.path.splitext(base_name)
return file_name
return Path(file_path).stem


def get_filename_compile_driver():
path = os.path.dirname(os.path.abspath(__file__))
return f"{path}/../compile_driver.py"
return f"{get_output_dir()}/compile_driver.py"


def get_filename_profile_driver(M, N, K, job_id):
path = os.path.dirname(os.path.abspath(__file__))
return f"{path}/../profile_driver_{M}x{N}x{K}_{job_id}.py"
return f"{get_output_dir()}/profile_driver_{M}x{N}x{K}_{job_id}.py"


def get_default_tuning_result_filename():
Expand All @@ -79,8 +81,7 @@ def get_default_tuning_result_filename():

dt_string = datetime.now().strftime("%m-%d-%Y-%H:%M:%S")

path = os.path.dirname(os.path.abspath(__file__))
defaultName = f"{path}/../tuning_results_{git_branch_name}@{git_commit_hash}_{dt_string}.yaml"
defaultName = f"{get_output_dir()}/tuning_results_{git_branch_name}@{git_commit_hash}_{dt_string}.yaml"
return defaultName


Expand All @@ -93,15 +94,15 @@ def patch_triton_compiler():
if not triton_location_str:
print("triton source not found from pip show triton")

triton_dir = triton_location_str[0].split()[-1].decode('utf-8')
triton_dir = Path(triton_location_str[0].split()[-1].decode('utf-8'))

jit_filename = os.path.join(triton_dir, "triton/runtime", "jit.py")
jit_filename = triton_dir / "triton" / "runtime" / "jit.py"

run_bash_command(f"sed -i 's/driver.active.get_current_device()/{device}/g' {jit_filename}")
run_bash_command(f"sed -i 's/driver.active.get_current_stream(device)/{stream}/g' {jit_filename}")

hip_driver_filename = os.path.join(triton_dir, "../third_party/amd/backend/", "driver.py")
cuda_driver_filename = os.path.join(triton_dir, "../third_party/nvidia/backend/", "driver.py")
hip_driver_filename = triton_dir.parent / "third_party" / "amd" / "backend" / "driver.py"
cuda_driver_filename = triton_dir.parent / "third_party" / "nvidia" / "backend" / "driver.py"

run_bash_command(f"sed -i 's/import torch/return True/g' {hip_driver_filename}")
run_bash_command(
Expand Down

0 comments on commit db2ca01

Please sign in to comment.