From 0f4c24593b8072b795a617b61259e0c3fac1bafd Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Mon, 3 Feb 2025 12:39:01 -0800
Subject: [PATCH] Add compile time Kineto trace (#148)

Summary:
Generates a Kineto trace focused on compile time so that we can look into where compilation time is spent.

Fixes https://github.com/pytorch-labs/tritonbench/issues/117

Pull Request resolved: https://github.com/pytorch-labs/tritonbench/pull/148

Test Plan:
```
$ python run.py --op softmax --num-inputs 1 --input-id 0 --metrics compile_trace --only triton_softmax
```
[screenshot]

Internal link only: https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree%2Ftritonbench%2Fcompile_time.json&bucket=tc_bench_ci

With autotuning:
[screenshot]

Internal test:
```
  x_val  triton_softmax-compile_trace
-------  -----------------------------------------------------------------------------------------------------------------------------------------------------------------
   2176  https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/tritonbench/tritonbench_triton_softmax_20250131_163154_6535183793.json&bucket=pyper_traces
```

Reviewed By: adamomainz

Differential Revision: D68898238

Pulled By: xuzhao9

fbshipit-source-id: f2b79e21b6fd94ec017412b724625f45ec8a349d
---
 .../components/compile_time/__init__.py      |  6 +-
 tritonbench/components/compile_time/trace.py | 49 +++++++++++++++-
 tritonbench/utils/triton_op.py               | 56 +++++++++++++++++--
 3 files changed, 105 insertions(+), 6 deletions(-)

diff --git a/tritonbench/components/compile_time/__init__.py b/tritonbench/components/compile_time/__init__.py
index 50c8f45f..51e7dfdc 100644
--- a/tritonbench/components/compile_time/__init__.py
+++ b/tritonbench/components/compile_time/__init__.py
@@ -1 +1,5 @@
-from .trace import do_compile_time_in_task, fbcode_do_compile_time_in_task  # noqa F401
+from .trace import (  # noqa F401
+    do_compile_kineto_trace_in_task,
+    do_compile_time_in_task,
+    fbcode_do_compile_time_in_task,
+)
diff --git a/tritonbench/components/compile_time/trace.py b/tritonbench/components/compile_time/trace.py
index fbb1d133..c0aeee24 100644
--- a/tritonbench/components/compile_time/trace.py
+++ b/tritonbench/components/compile_time/trace.py
@@ -1,11 +1,25 @@
-from typing import Callable, Dict
+import os
+import random
+import string
+from datetime import datetime
+from functools import partial
+from typing import Callable, Dict, Optional
 
 import torch
+import torch.profiler as profiler
 
 from tritonbench.utils.env_utils import fresh_triton_cache, is_fbcode
 
 if is_fbcode():
     from triton.fb.triton_util import triton_add_listener, TritonHook
 
+DEFAULT_PROFILE_OPTS = {
+    "record_shapes": True,
+    "profile_memory": True,
+    "with_stack": True,
+    "with_flops": True,
+    "with_modules": True,
+}
+
 def fbcode_do_compile_time_in_task(fn: Callable) -> Dict[str, float]:
     # not yet getting results that make sense to me
@@ -37,3 +51,36 @@ def do_compile_time_in_task(fn: Callable) -> float:
     torch.cuda.synchronize()  # Wait for the events to be recorded!
     latency_with_compile = start_event.elapsed_time(end_event)
     return latency_with_compile
+
+
+def do_compile_kineto_trace_in_task(
+    fn: Callable,
+    profile_opts: Optional[Dict[str, bool]] = None,
+    output_dir: Optional[str] = None,
+) -> Optional[str]:
+    """Profile compilation stage using Kineto."""
+    activity_groups = [
+        profiler.ProfilerActivity.CUDA,
+        profiler.ProfilerActivity.CPU,
+    ]
+    if not profile_opts:
+        profile_opts = DEFAULT_PROFILE_OPTS
+    prefix = f"tritonbench_{fn._name}"
+    name = f"{prefix}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{''.join(random.choices(string.digits, k=10))}.json"
+    trace_path = os.path.join(output_dir, name)
+    with fresh_triton_cache():
+        with profiler.profile(
+            schedule=profiler.schedule(wait=0, warmup=0, active=1, repeat=1),
+            activities=activity_groups,
+            record_shapes=profile_opts["record_shapes"],
+            profile_memory=profile_opts["profile_memory"],
+            with_stack=profile_opts["with_stack"],
+            with_flops=profile_opts["with_flops"],
+            with_modules=profile_opts["with_modules"],
+            on_trace_ready=(
+                partial(lambda name, prof: prof.export_chrome_trace(name), trace_path)
+            ),
+        ) as prof:
+            fn()
+            prof.step()
+    return trace_path
diff --git a/tritonbench/utils/triton_op.py b/tritonbench/utils/triton_op.py
index 1c1abd8d..1a20c6ab 100644
--- a/tritonbench/utils/triton_op.py
+++ b/tritonbench/utils/triton_op.py
@@ -199,6 +199,8 @@ class BenchmarkOperatorMetrics:
     compile_time: Optional[float] = None
     # stage breakdown of compile times
     compile_time_by_stage: Optional[Dict[str, float]] = None
+    # compile time with kineto trace
+    compile_trace: Optional[str] = None
     # ncu trace file
     ncu_trace: Optional[str] = None
     # ncu replay file
@@ -1145,6 +1147,10 @@ def _init_extra_metrics() -> Dict[str, Any]:
             metrics.compile_time = compile_time
             if compile_time_by_stage:
                 metrics.compile_time_by_stage = compile_time_by_stage
+        if "compile_trace" in self.required_metrics:
+            metrics.compile_trace = self.compile_time(
+                input_id, fn_name, metrics, kineto_trace=True
+            )
         if "ncu_trace" in self.required_metrics:
             metrics.ncu_trace = self.ncu_trace(input_id, fn_name)
         # Collect NCU metrics if any required metrics match the ncu analyzer
@@ -1236,6 +1242,29 @@ def _init_extra_metrics() -> Dict[str, Any]:
             metrics.all_configs = self.all_configs(fn)
         if "kernel_source_hash" in self.required_metrics:
             metrics.kernel_source_hash = self.kernel_hash(fn)
+        if "_compile_time_kineto_trace_in_task" in self.required_metrics:
+            assert (
+                self.required_metrics == ["_compile_time_kineto_trace_in_task"]
+                and len(self._only) == 1
+                and (self._input_id is not None)
+            ), (
+                "_compile_time_kineto_trace_in_task must be measured by itself. "
" + f"required_metrics: {self.required_metrics}, _only: {self._only}, _input_id: {self._input_id}" + ) + from tritonbench.components.compile_time import ( + do_compile_kineto_trace_in_task, + ) + + kineto_trace_output_dir = self.get_temp_path("kineto_trace") + kineto_trace_output_dir.mkdir(parents=True, exist_ok=True) + metrics.extra_metrics["_compile_time_kineto_trace_in_task"] = ( + do_compile_kineto_trace_in_task( + fn, output_dir=str(kineto_trace_output_dir) + ) + ) + self._compile_time_kineto_trace_in_task = metrics.extra_metrics[ + "_compile_time_kineto_trace_in_task" + ] if "_compile_time_in_task" in self.required_metrics: assert ( self.required_metrics == ["_compile_time_in_task"] @@ -1591,8 +1620,12 @@ def kineto_trace(self, input_id: int, fn: Callable) -> str: ) def compile_time( - self, input_id: int, fn_name: str, metrics: BenchmarkOperatorMetrics - ) -> float: + self, + input_id: int, + fn_name: str, + metrics: BenchmarkOperatorMetrics, + kineto_trace: bool = False, + ) -> Union[float, str]: # We need to spawn a subprocess when user wants to measure the compile time # of multiple sample inputs and backends. from tritonbench.operators.op_task import OpTask @@ -1611,12 +1644,27 @@ def compile_time( "--input-id", str(input_id), "--metrics", - "_compile_time_in_task", + ( + "_compile_time_in_task" + if not kineto_trace + else "_compile_time_kineto_trace_in_task" + ), ] ) - op_task = OpTask(name=self.name) + op_task = OpTask(name=self.name, save_output_dir=Path("/tmp/tritonbench")) op_task.make_operator_instance(args=op_task_args) op_task.run() + if kineto_trace: + kineto_trace_loc = op_task.get_attribute( + "_compile_time_kineto_trace_in_task" + ) + if IS_FBCODE: + from tritonbench.components.kineto.fb.run_utils import ( + manifold_upload_file, + ) + + return manifold_upload_file(kineto_trace_loc, perfdoctor=True) + return kineto_trace_loc if op_task.get_attribute("triton_hook_latency") is not None: compiled_time = op_task.get_attribute("triton_hook_latency") compile_time_by_stage = op_task.get_attribute("compile_time_by_stage")