From db35ccf46334164d0b02e5be9450193d881edb40 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 21 Dec 2023 02:00:17 +0000 Subject: [PATCH] Revert "[innductor] make inductor work with new triton compile interface (#115878)" This reverts commit bbded928b3556cf5678edf8fa41109d418312bcc. Reverted https://github.com/pytorch/pytorch/pull/115878 on behalf of https://github.com/kit1980 due to Broke ROCm https://github.com/pytorch/pytorch/actions/runs/7282149837/job/19844618618 ([comment](https://github.com/pytorch/pytorch/pull/115878#issuecomment-1865369349)) --- test/inductor/test_indexing.py | 10 +-- torch/_dynamo/utils.py | 11 ---- torch/_dynamo/variables/functions.py | 17 +---- torch/_inductor/autotune_process.py | 11 +--- torch/_inductor/codegen/triton.py | 48 +-------------- torch/_inductor/codegen/triton_foreach.py | 2 - torch/_inductor/codegen/wrapper.py | 4 -- torch/_inductor/select_algorithm.py | 1 - torch/_inductor/triton_heuristics.py | 75 ++++++----------------- torch/_inductor/utils.py | 15 ++--- 10 files changed, 33 insertions(+), 161 deletions(-) diff --git a/test/inductor/test_indexing.py b/test/inductor/test_indexing.py index 4f1e044644bdb..311e3627252f7 100644 --- a/test/inductor/test_indexing.py +++ b/test/inductor/test_indexing.py @@ -2,7 +2,7 @@ import sympy from torch._inductor.codegen.cpp import cexpr -from torch._inductor.codegen.triton import texpr, TritonPrinter +from torch._inductor.codegen.triton import texpr from torch._inductor.codegen.wrapper import pexpr from torch._inductor.sizevars import SizeVarAllocator @@ -291,18 +291,14 @@ def test_print_Min_Max(self): (sympy.Min, "min"), (sympy.Max, "max"), ) - extra_arg = TritonPrinter._propagate_nan_arg() for f, s in cases: x = sympy.Symbol("x", integer=True) expr = f(-2, x) - self.assertEqual(texpr(expr), f"tl.math.{s}(-2, x{extra_arg})") + self.assertEqual(texpr(expr), f"tl.math.{s}(-2, x)") self.assertEqual(cexpr(expr), f"std::{s}(-2L, x)") expr = f(x, 2 * x, 3 * x) - self.assertEqual( - texpr(expr), - f"tl.math.{s}(x, tl.math.{s}(2*x, 3*x{extra_arg}){extra_arg})", - ) + self.assertEqual(texpr(expr), f"tl.math.{s}(x, tl.math.{s}(2*x, 3*x))") self.assertEqual(cexpr(expr), f"std::{s}({{x, 2L*x, 3L*x}})") diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 7dd2d8f75b638..b8d44975a1484 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -2374,14 +2374,3 @@ def to_fake_tensor(t, fake_mode): return fake_mode.from_tensor( t, static_shapes=False, symbolic_context=symbolic_context, source=source ) - - -def get_first_attr(obj, *attrs): - """ - Return the first available attribute or throw an exception if none is present. 
- """ - for attr in attrs: - if hasattr(obj, attr): - return getattr(obj, attr) - - raise AssertionError(f"{obj} does not has any of the attributes: {attrs}") diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index c78acd4a6b59e..3974cd753fa57 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -10,7 +10,7 @@ from ..bytecode_transformation import create_call_function, create_rot_n from ..exc import unimplemented, Unsupported from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource -from ..utils import get_first_attr, make_cell +from ..utils import make_cell from .base import typestr, VariableTracker @@ -654,20 +654,9 @@ def __init__(self, kernel, kernel_idx, grid, **kwargs): # We only support configs and keys arguments of triton.autotune # Make sure other arguments are defaulted defaults = inspect.signature(Autotuner).parameters - - # Newer version of triton change attribute name from warmup to num_warmup and rep to num_rep. - # The call to get_first_attr is to maintain backward-compatibility. if ( - ( - "warmup" in defaults - and defaults["warmup"].default - != get_first_attr(kernel, "num_warmups", "warmup") - ) - or ( - "rep" in defaults - and defaults["rep"].default - != get_first_attr(kernel, "num_reps", "rep") - ) + ("warmup" in defaults and defaults["warmup"].default != kernel.warmup) + or ("rep" in defaults and defaults["rep"].default != kernel.rep) or ( "prune_configs_by" in defaults and defaults["prune_configs_by"].default diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py index 4b89a9ce9283b..84cd09628ec78 100644 --- a/torch/_inductor/autotune_process.py +++ b/torch/_inductor/autotune_process.py @@ -513,15 +513,6 @@ def make_run_fn( ) run_method = getattr(mod, self.kernel_name).run - extra_args = list(self.extra_args) - - # Newer version of triton add warmup argument to JITFunction.run. - # This code handles backward-compatibility. - warmup_arg = {} - import inspect - - if "warmup" in inspect.signature(run_method).parameters: - warmup_arg["warmup"] = False return functools.partial( run_method, @@ -529,9 +520,9 @@ def make_run_fn( output_tensor, *self.extra_args, grid=self.grid, - **warmup_arg, num_stages=self.num_stages, num_warps=self.num_warps, + stream=torch.cuda.current_stream().cuda_stream, ) def __str__(self) -> str: diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 0a10248a77d73..6a43b4f5ca8cf 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -10,7 +10,6 @@ import operator import os import textwrap -from functools import lru_cache from typing import ( Any, Callable, @@ -32,7 +31,6 @@ from torch._prims_common import is_integer_dtype from torch.utils._sympy.functions import FloorDiv, ModularIndexing from torch.utils._sympy.value_ranges import ValueRanges -from torch.utils._triton import has_triton_package from ..._dynamo.utils import counters from .. import config, ir, scheduler @@ -93,30 +91,6 @@ def _print_Where(self, expr): q = self.doprint(expr.args[2]) return f"tl.where({c}, {p}, {q})" - @staticmethod - @lru_cache(None) - def _propagate_nan_arg(): - """ - Newer triton version added propagate_nan as required argument for - tl.math.{min, max}. This method make inductor work with both old - and new version of triton. 
- """ - - if not has_triton_package(): - # some tests run under environment without triton installed want to - # check that the generated code is as expected. - return "" - import inspect - - import triton.language as tl - - if "propagate_nan" in inspect.signature(tl.math.min).parameters: - # tl.PropagateNan.NONE is the default - propagate_nan_arg = ", tl.PropagateNan.NONE" - else: - propagate_nan_arg = "" - return propagate_nan_arg - def _print_Min(self, expr): nargs = len(expr.args) if len(expr.args) == 1: @@ -125,7 +99,7 @@ def _print_Min(self, expr): mid = len(expr.args) // 2 a = self._print(sympy.Min(*expr.args[:mid])) b = self._print(sympy.Min(*expr.args[mid:])) - return f"tl.math.min({a}, {b}{TritonPrinter._propagate_nan_arg()})" + return f"tl.math.min({a}, {b})" def _print_Max(self, expr): nargs = len(expr.args) @@ -135,8 +109,7 @@ def _print_Max(self, expr): mid = len(expr.args) // 2 a = self._print(sympy.Max(*expr.args[:mid])) b = self._print(sympy.Max(*expr.args[mid:])) - - return f"tl.math.max({a}, {b}{TritonPrinter._propagate_nan_arg()})" + return f"tl.math.max({a}, {b})" def _print_Abs(self, expr): assert len(expr.args) == 1 @@ -2099,20 +2072,6 @@ def imports_for_benchmark_kernel(self): """ ) - @staticmethod - @lru_cache(None) - def gen_attr_descriptor_import(): - """ - import AttrsDescriptor if the triton version is new enough to have this - class defined. - """ - import triton.compiler.compiler - - if hasattr(triton.compiler.compiler, "AttrsDescriptor"): - return "from triton.compiler.compiler import AttrsDescriptor" - else: - return "" - def codegen_kernel(self, name=None): from triton import next_power_of_2 @@ -2158,9 +2117,6 @@ def codegen_kernel(self, name=None): from torch._inductor import triton_helpers """ ) - if self.gen_attr_descriptor_import(): - code.splice(self.gen_attr_descriptor_import()) - if config.benchmark_kernel: code.splice(self.imports_for_benchmark_kernel()) diff --git a/torch/_inductor/codegen/triton_foreach.py b/torch/_inductor/codegen/triton_foreach.py index df75615747dca..8efd9fb6864a8 100644 --- a/torch/_inductor/codegen/triton_foreach.py +++ b/torch/_inductor/codegen/triton_foreach.py @@ -189,8 +189,6 @@ def codegen_kernel(self, name=None): from torch._inductor import triton_helpers """ ) - if TritonKernel.gen_attr_descriptor_import(): - code.splice(TritonKernel.gen_attr_descriptor_import()) argdefs, _, _ = self.args.python_argdefs() code.writeline(self.jit_line()) code.writeline( diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 4a55cdf876c8c..f59b6bd194b6c 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -926,10 +926,6 @@ def define_user_defined_triton_kernel(self, kernel, configs, kwargs): """, strip=True, ) - from .triton import TritonKernel - - if TritonKernel.gen_attr_descriptor_import(): - compile_wrapper.splice(TritonKernel.gen_attr_descriptor_import()) compile_wrapper.newline() from .common import SizeArg, TensorArg diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index e237494b9bfe9..fe468f18e0038 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -187,7 +187,6 @@ def hook(): "from torch._inductor.triton_heuristics import template", "from torch._inductor.utils import instance_descriptor", "from torch._inductor import triton_helpers", - TritonKernel.gen_attr_descriptor_import(), "", self.jit_line(), f"def {self.kernel_name}({', '.join(arg_defs)}):", diff --git 
a/torch/_inductor/triton_heuristics.py b/torch/_inductor/triton_heuristics.py index 02e3954c65b3b..153a5e23378c0 100644 --- a/torch/_inductor/triton_heuristics.py +++ b/torch/_inductor/triton_heuristics.py @@ -18,7 +18,7 @@ import torch.autograd.profiler as autograd_profiler from torch._dynamo.device_interface import get_interface_for_device -from torch._dynamo.utils import dynamo_timed, get_first_attr +from torch._dynamo.utils import dynamo_timed from torch.utils._triton import has_triton_package from . import config @@ -44,17 +44,11 @@ from triton import Config from triton.runtime.autotuner import OutOfResources from triton.runtime.jit import KernelInterface - - try: - from triton.compiler.compiler import ASTSource - except ImportError: - ASTSource = None else: Config = object triton = None KernelInterface = object OutOfResources = object - ASTSource = None _NUM_THREADS_PER_WARP = 32 @@ -292,44 +286,14 @@ def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: Optional[int] # Setting device_type="hip" required on ROCm to pass down to triton compile_meta["device_type"] = "cuda" if torch.version.hip is None else "hip" - device_type = compile_meta["device_type"] if warm_cache_only_with_cc: - cc = warm_cache_only_with_cc - else: - device_id = compile_meta["device"] - device_interface = get_interface_for_device(device_type) - device = torch.device(device_type, device_id) - cc = device_interface.get_compute_capability(device) - - compile_meta["cc"] = cc - - if ASTSource: - compile_args = ( - ASTSource( + return ( + triton.compile( self.fn, - compile_meta["signature"], - compile_meta["constants"], - compile_meta["configs"][0], + warm_cache_only=True, + cc=warm_cache_only_with_cc, + **compile_meta, ), - ) - - target = (device_type, cc) - options = { - "num_warps": compile_meta["num_warps"], - "num_stages": compile_meta["num_stages"], - "debug": compile_meta["debug"], - } - compile_kwargs = { - "target": target, - "options": options, - } - else: - compile_args = (self.fn,) - compile_kwargs = compile_meta - - if warm_cache_only_with_cc: - return ( - triton.compile(*compile_args, **compile_kwargs), None, ) @@ -337,8 +301,10 @@ def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: Optional[int] with torch.cuda.device(compile_meta["device"]): # need to initialize context torch.cuda.synchronize(torch.cuda.current_device()) - - binary = triton.compile(*compile_args, **compile_kwargs) + binary = triton.compile( + self.fn, + **compile_meta, + ) binary._init_handles() call_args = [ @@ -355,14 +321,6 @@ def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: Optional[int] "set_device": torch.cuda.set_device, "current_device": torch.cuda.current_device, } - - scope["runner"] = get_first_attr(binary, "run", "c_wrapper") - scope["function"] = get_first_attr(binary, "function", "cu_function") - cluster_dims = get_first_attr(binary, "cluster_dims", "clusterDims") - scope["cta_args"] = ( - (binary.num_ctas, *cluster_dims) if hasattr(binary, "num_ctas") else () - ) - exec( f""" def launcher({', '.join(def_args)}, grid, stream): @@ -371,10 +329,15 @@ def launcher({', '.join(def_args)}, grid, stream): else: grid_0, grid_1, grid_2 = grid - runner(grid_0, grid_1, grid_2, bin.num_warps, - *cta_args, bin.shared, - stream, function, None, None, None, - {', '.join(call_args)}) + if hasattr(bin, "num_ctas"): + bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, + bin.num_ctas, *bin.clusterDims, bin.shared, + stream, bin.cu_function, None, None, None, + {', '.join(call_args)}) + else: + 
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, + stream, bin.cu_function, None, None, None, + {', '.join(call_args)}) return bin """.lstrip(), scope, diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index ceb6468da73dd..2fbc8fedcc962 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -609,16 +609,11 @@ def has_incompatible_cudagraph_ops(gm): return False -try: - from triton.compiler.compiler import AttrsDescriptor as instance_descriptor -except ImportError: - # To support older version of triton which does not have AttrsDescriptor - # class - instance_descriptor = collections.namedtuple( # type: ignore[no-redef] - "instance_descriptor", - ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"], - defaults=[tuple(), tuple(), tuple(), tuple()], - ) +instance_descriptor = collections.namedtuple( + "instance_descriptor", + ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"], + defaults=[tuple(), tuple(), tuple(), tuple()], +) @functools.lru_cache(None)
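
Note (not part of the patch): the hunks above delete a small backward-compatibility helper, get_first_attr, that the original PR used to tolerate Triton attribute renames (e.g. binary.run vs binary.c_wrapper, binary.cluster_dims vs binary.clusterDims). The sketch below is a minimal, self-contained illustration of that pattern, based only on the removed lines shown above; the _OldBinary/_NewBinary stand-in classes are hypothetical and exist only to demonstrate the lookup.

# Minimal sketch of the backward-compat pattern removed by this revert,
# following the torch/_dynamo/utils.py and torch/_inductor/triton_heuristics.py
# hunks above. `get_first_attr` returns the first attribute that exists on an
# object, so call sites can work across Triton releases that renamed fields.

def get_first_attr(obj, *attrs):
    """Return the first available attribute, or raise if none is present."""
    for attr in attrs:
        if hasattr(obj, attr):
            return getattr(obj, attr)
    raise AssertionError(f"{obj} does not have any of the attributes: {attrs}")


class _OldBinary:
    """Hypothetical stand-in for a compiled kernel from an older Triton."""
    c_wrapper = staticmethod(lambda: "launched via c_wrapper")
    clusterDims = (1, 1, 1)


class _NewBinary:
    """Hypothetical stand-in for a compiled kernel from a newer Triton."""
    run = staticmethod(lambda: "launched via run")
    cluster_dims = (1, 1, 1)


if __name__ == "__main__":
    for binary in (_OldBinary(), _NewBinary()):
        # Prefer the newer attribute name, fall back to the older one.
        runner = get_first_attr(binary, "run", "c_wrapper")
        dims = get_first_attr(binary, "cluster_dims", "clusterDims")
        print(runner(), dims)

The revert replaces this lookup with an explicit hasattr(bin, "num_ctas") branch in the generated launcher, as shown in the triton_heuristics.py hunk.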