Skip to content

Commit

Permalink
Add compile time Kineto trace (#148)
Browse files Browse the repository at this point in the history
Summary:
Generates Kineto trace focusing on compile time so that we can peek into compile time analysis.

Fixes #117

Pull Request resolved: #148

Test Plan:
```
$ python run.py --op softmax --num-inputs 1 --input-id 0 --metrics compile_trace --only triton_softmax
```

<img width="632" alt="image" src="https://github.com/user-attachments/assets/bfacd02d-2033-435a-9836-08b6d8bc03ae" />

Internal link only: https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree%2Ftritonbench%2Fcompile_time.json&bucket=tc_bench_ci

With autotuning:

<img width="1111" alt="image" src="https://github.com/user-attachments/assets/47982b37-5597-41c2-95b1-096315e5f5ea" />

Internal test:

```
  x_val                                                                                                                                             triton_softmax-compile_trace
-------  -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
   2176  https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/tritonbench/tritonbench_triton_softmax_20250131_163154_6535183793.json&bucket=pyper_traces
```

Reviewed By: adamomainz

Differential Revision: D68898238

Pulled By: xuzhao9

fbshipit-source-id: f2b79e21b6fd94ec017412b724625f45ec8a349d
  • Loading branch information
xuzhao9 authored and facebook-github-bot committed Feb 3, 2025
1 parent 9052263 commit 0f4c245
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 6 deletions.
6 changes: 5 additions & 1 deletion tritonbench/components/compile_time/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
from .trace import do_compile_time_in_task, fbcode_do_compile_time_in_task # noqa F401
from .trace import ( # noqa F401
do_compile_kineto_trace_in_task,
do_compile_time_in_task,
fbcode_do_compile_time_in_task,
)
49 changes: 48 additions & 1 deletion tritonbench/components/compile_time/trace.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,25 @@
from typing import Callable, Dict
import os
import random
import string
from datetime import datetime
from functools import partial
from typing import Callable, Dict, Optional

import torch
import torch.profiler as profiler
from tritonbench.utils.env_utils import fresh_triton_cache, is_fbcode

if is_fbcode():
from triton.fb.triton_util import triton_add_listener, TritonHook

DEFAULT_PROFILE_OPTS = {
"record_shapes": True,
"profile_memory": True,
"with_stack": True,
"with_flops": True,
"with_modules": True,
}


def fbcode_do_compile_time_in_task(fn: Callable) -> Dict[str, float]:
# not yet getting results that make sense to me
Expand Down Expand Up @@ -37,3 +51,36 @@ def do_compile_time_in_task(fn: Callable) -> float:
torch.cuda.synchronize() # Wait for the events to be recorded!
latency_with_compile = start_event.elapsed_time(end_event)
return latency_with_compile


def do_compile_kineto_trace_in_task(
fn: Callable,
profile_opts: Optional[Dict[str, bool]] = None,
output_dir: Optional[str] = None,
) -> Optional[str]:
"""Profile compilation stage using Kineto."""
activity_groups = [
profiler.ProfilerActivity.CUDA,
profiler.ProfilerActivity.CPU,
]
if not profile_opts:
profile_opts = DEFAULT_PROFILE_OPTS
prefix = f"tritonbench_{fn._name}"
name = f"{prefix}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{''.join(random.choices(string.digits, k=10))}.json"
trace_path = os.path.join(output_dir, name)
with fresh_triton_cache():
with profiler.profile(
schedule=profiler.schedule(wait=0, warmup=0, active=1, repeat=1),
activities=activity_groups,
record_shapes=profile_opts["record_shapes"],
profile_memory=profile_opts["profile_memory"],
with_stack=profile_opts["with_stack"],
with_flops=profile_opts["with_flops"],
with_modules=profile_opts["with_modules"],
on_trace_ready=(
partial(lambda name, prof: prof.export_chrome_trace(name), trace_path)
),
) as prof:
fn()
prof.step()
return trace_path
56 changes: 52 additions & 4 deletions tritonbench/utils/triton_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,8 @@ class BenchmarkOperatorMetrics:
compile_time: Optional[float] = None
# stage breakdown of compile times
compile_time_by_stage: Optional[Dict[str, float]] = None
# compile time with kineto trace
compile_trace: Optional[str] = None
# ncu trace file
ncu_trace: Optional[str] = None
# ncu replay file
Expand Down Expand Up @@ -1145,6 +1147,10 @@ def _init_extra_metrics() -> Dict[str, Any]:
metrics.compile_time = compile_time
if compile_time_by_stage:
metrics.compile_time_by_stage = compile_time_by_stage
if "compile_trace" in self.required_metrics:
metrics.compile_trace = self.compile_time(
input_id, fn_name, metrics, kineto_trace=True
)
if "ncu_trace" in self.required_metrics:
metrics.ncu_trace = self.ncu_trace(input_id, fn_name)
# Collect NCU metrics if any required metrics match the ncu analyzer
Expand Down Expand Up @@ -1236,6 +1242,29 @@ def _init_extra_metrics() -> Dict[str, Any]:
metrics.all_configs = self.all_configs(fn)
if "kernel_source_hash" in self.required_metrics:
metrics.kernel_source_hash = self.kernel_hash(fn)
if "_compile_time_kineto_trace_in_task" in self.required_metrics:
assert (
self.required_metrics == ["_compile_time_kineto_trace_in_task"]
and len(self._only) == 1
and (self._input_id is not None)
), (
"_compile_time_kineto_trace_in_task must be measured by itself. "
f"required_metrics: {self.required_metrics}, _only: {self._only}, _input_id: {self._input_id}"
)
from tritonbench.components.compile_time import (
do_compile_kineto_trace_in_task,
)

kineto_trace_output_dir = self.get_temp_path("kineto_trace")
kineto_trace_output_dir.mkdir(parents=True, exist_ok=True)
metrics.extra_metrics["_compile_time_kineto_trace_in_task"] = (
do_compile_kineto_trace_in_task(
fn, output_dir=str(kineto_trace_output_dir)
)
)
self._compile_time_kineto_trace_in_task = metrics.extra_metrics[
"_compile_time_kineto_trace_in_task"
]
if "_compile_time_in_task" in self.required_metrics:
assert (
self.required_metrics == ["_compile_time_in_task"]
Expand Down Expand Up @@ -1591,8 +1620,12 @@ def kineto_trace(self, input_id: int, fn: Callable) -> str:
)

def compile_time(
self, input_id: int, fn_name: str, metrics: BenchmarkOperatorMetrics
) -> float:
self,
input_id: int,
fn_name: str,
metrics: BenchmarkOperatorMetrics,
kineto_trace: bool = False,
) -> Union[float, str]:
# We need to spawn a subprocess when user wants to measure the compile time
# of multiple sample inputs and backends.
from tritonbench.operators.op_task import OpTask
Expand All @@ -1611,12 +1644,27 @@ def compile_time(
"--input-id",
str(input_id),
"--metrics",
"_compile_time_in_task",
(
"_compile_time_in_task"
if not kineto_trace
else "_compile_time_kineto_trace_in_task"
),
]
)
op_task = OpTask(name=self.name)
op_task = OpTask(name=self.name, save_output_dir=Path("/tmp/tritonbench"))
op_task.make_operator_instance(args=op_task_args)
op_task.run()
if kineto_trace:
kineto_trace_loc = op_task.get_attribute(
"_compile_time_kineto_trace_in_task"
)
if IS_FBCODE:
from tritonbench.components.kineto.fb.run_utils import (
manifold_upload_file,
)

return manifold_upload_file(kineto_trace_loc, perfdoctor=True)
return kineto_trace_loc
if op_task.get_attribute("triton_hook_latency") is not None:
compiled_time = op_task.get_attribute("triton_hook_latency")
compile_time_by_stage = op_task.get_attribute("compile_time_by_stage")
Expand Down

0 comments on commit 0f4c245

Please sign in to comment.