diff --git a/test/inductor/test_profiler.py b/test/inductor/test_profiler.py index d2ff71dd73bb6..9d0270a9aae8d 100644 --- a/test/inductor/test_profiler.py +++ b/test/inductor/test_profiler.py @@ -158,10 +158,10 @@ def test_inductor_profiling_triton_hooks(self): hooks_called = {"enter": False, "exit": False} - def launch_enter_hook(*args): + def launch_enter_hook(lazy_dict): hooks_called["enter"] = True - def launch_exit_hook(*args): + def launch_exit_hook(lazy_dict): hooks_called["exit"] = True CompiledKernel.launch_enter_hook = launch_enter_hook diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 5396ccf3e70d5..82a25392b5e95 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -50,6 +50,7 @@ if triton is not None: from triton import Config + from triton.compiler import CompiledKernel from triton.runtime.autotuner import OutOfResources from triton.runtime.jit import KernelInterface @@ -453,8 +454,8 @@ def _precompile_config(self, cfg: Config, warm_cache_only: bool): scope = { "grid_meta": cfg.kwargs, "bin": binary, - "launch_enter_hook": binary.launch_enter_hook, - "launch_exit_hook": binary.launch_exit_hook, + "launch_enter_hook": CompiledKernel.launch_enter_hook, + "launch_exit_hook": CompiledKernel.launch_exit_hook, "metadata": binary.packed_metadata if hasattr(binary, "packed_metadata") else binary.metadata,