Skip to content

Commit

Permalink
[Inductor] Fix arguments passed to triton kernel launch hooks (pytorc…
Browse files Browse the repository at this point in the history
…h#128732)

`binary.launch_enter_hook` is treated as an instance method and will add a `self` argument to the hooks.
`CompiledKernel.launch_enter_hook` is a static method, which matches the hook calling convention of profilers (i.e., a single `LazyDict` argument only).

Pull Request resolved: pytorch#128732
Approved by: https://github.com/shunting314, https://github.com/bertmaher
  • Loading branch information
Jokeren authored and pytorchmergebot committed Jun 18, 2024
1 parent a0e1e20 commit d9c294c
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
4 changes: 2 additions & 2 deletions test/inductor/test_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,10 @@ def test_inductor_profiling_triton_hooks(self):

hooks_called = {"enter": False, "exit": False}

def launch_enter_hook(*args):
def launch_enter_hook(lazy_dict):
hooks_called["enter"] = True

def launch_exit_hook(*args):
def launch_exit_hook(lazy_dict):
hooks_called["exit"] = True

CompiledKernel.launch_enter_hook = launch_enter_hook
Expand Down
5 changes: 3 additions & 2 deletions torch/_inductor/runtime/triton_heuristics.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@

if triton is not None:
from triton import Config
from triton.compiler import CompiledKernel
from triton.runtime.autotuner import OutOfResources
from triton.runtime.jit import KernelInterface

Expand Down Expand Up @@ -453,8 +454,8 @@ def _precompile_config(self, cfg: Config, warm_cache_only: bool):
scope = {
"grid_meta": cfg.kwargs,
"bin": binary,
"launch_enter_hook": binary.launch_enter_hook,
"launch_exit_hook": binary.launch_exit_hook,
"launch_enter_hook": CompiledKernel.launch_enter_hook,
"launch_exit_hook": CompiledKernel.launch_exit_hook,
"metadata": binary.packed_metadata
if hasattr(binary, "packed_metadata")
else binary.metadata,
Expand Down

0 comments on commit d9c294c

Please sign in to comment.