Commit

Revert "[innductor] make inductor work with new triton compile interf…
pytorchmergebot committed Dec 21, 2023
1 parent 65d3dde commit db35ccf
Showing 10 changed files with 33 additions and 161 deletions.
10 changes: 3 additions & 7 deletions test/inductor/test_indexing.py
@@ -2,7 +2,7 @@
import sympy

from torch._inductor.codegen.cpp import cexpr
from torch._inductor.codegen.triton import texpr, TritonPrinter
from torch._inductor.codegen.triton import texpr
from torch._inductor.codegen.wrapper import pexpr

from torch._inductor.sizevars import SizeVarAllocator
@@ -291,18 +291,14 @@ def test_print_Min_Max(self):
(sympy.Min, "min"),
(sympy.Max, "max"),
)
extra_arg = TritonPrinter._propagate_nan_arg()
for f, s in cases:
x = sympy.Symbol("x", integer=True)
expr = f(-2, x)
self.assertEqual(texpr(expr), f"tl.math.{s}(-2, x{extra_arg})")
self.assertEqual(texpr(expr), f"tl.math.{s}(-2, x)")
self.assertEqual(cexpr(expr), f"std::{s}(-2L, x)")

expr = f(x, 2 * x, 3 * x)
self.assertEqual(
texpr(expr),
f"tl.math.{s}(x, tl.math.{s}(2*x, 3*x{extra_arg}){extra_arg})",
)
self.assertEqual(texpr(expr), f"tl.math.{s}(x, tl.math.{s}(2*x, 3*x))")
self.assertEqual(cexpr(expr), f"std::{s}({{x, 2L*x, 3L*x}})")


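The restored expectations can be exercised directly with the printer the test imports. A minimal check, assuming a PyTorch checkout at this commit with sympy installed:

import sympy

from torch._inductor.codegen.triton import texpr

x = sympy.Symbol("x", integer=True)
# With the propagate_nan plumbing reverted, the printer emits the plain two-argument form again.
print(texpr(sympy.Min(-2, x)))            # tl.math.min(-2, x)
print(texpr(sympy.Max(x, 2 * x, 3 * x)))  # tl.math.max(x, tl.math.max(2*x, 3*x))
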
11 changes: 0 additions & 11 deletions torch/_dynamo/utils.py
@@ -2374,14 +2374,3 @@ def to_fake_tensor(t, fake_mode):
return fake_mode.from_tensor(
t, static_shapes=False, symbolic_context=symbolic_context, source=source
)


def get_first_attr(obj, *attrs):
"""
Return the first available attribute or throw an exception if none is present.
"""
for attr in attrs:
if hasattr(obj, attr):
return getattr(obj, attr)

raise AssertionError(f"{obj} does not has any of the attributes: {attrs}")
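
For reference, the helper removed here resolved the first attribute present on an object, which is how the compatibility code papered over triton's attribute renames (num_warmups vs. warmup, num_reps vs. rep). A minimal sketch of that pattern, using a hypothetical stand-in object rather than a real triton Autotuner:

class FakeKernel:
    """Hypothetical stand-in for a triton Autotuner; only exists to illustrate the helper."""

    num_warmups = 25  # newer triton spelling of the attribute


def get_first_attr(obj, *attrs):
    """Return the first attribute that exists on obj, or raise if none is present."""
    for attr in attrs:
        if hasattr(obj, attr):
            return getattr(obj, attr)
    raise AssertionError(f"{obj} does not have any of the attributes: {attrs}")


# Resolves via num_warmups on newer triton and would fall back to warmup on older triton.
print(get_first_attr(FakeKernel(), "num_warmups", "warmup"))  # 25
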
17 changes: 3 additions & 14 deletions torch/_dynamo/variables/functions.py
@@ -10,7 +10,7 @@
from ..bytecode_transformation import create_call_function, create_rot_n
from ..exc import unimplemented, Unsupported
from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource
from ..utils import get_first_attr, make_cell
from ..utils import make_cell
from .base import typestr, VariableTracker


@@ -654,20 +654,9 @@ def __init__(self, kernel, kernel_idx, grid, **kwargs):
# We only support configs and keys arguments of triton.autotune
# Make sure other arguments are defaulted
defaults = inspect.signature(Autotuner).parameters

# Newer version of triton change attribute name from warmup to num_warmup and rep to num_rep.
# The call to get_first_attr is to maintain backward-compatibility.
if (
(
"warmup" in defaults
and defaults["warmup"].default
!= get_first_attr(kernel, "num_warmups", "warmup")
)
or (
"rep" in defaults
and defaults["rep"].default
!= get_first_attr(kernel, "num_reps", "rep")
)
("warmup" in defaults and defaults["warmup"].default != kernel.warmup)
or ("rep" in defaults and defaults["rep"].default != kernel.rep)
or (
"prune_configs_by" in defaults
and defaults["prune_configs_by"].default
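
The restored condition compares the wrapped kernel's attributes against the defaults of triton.runtime.autotuner.Autotuner, so that only configs and key may deviate. A rough sketch of that default-comparison pattern; autotuner_stub and KernelStub are made-up stand-ins and their default values are illustrative, not triton's:

import inspect


def autotuner_stub(fn, configs, key, warmup=25, rep=100):
    """Made-up stand-in with the rough shape of triton's Autotuner signature."""


defaults = inspect.signature(autotuner_stub).parameters


class KernelStub:
    """Hypothetical kernel object exposing the attributes the check reads."""

    warmup = 25  # matches the stub default, so acceptable
    rep = 40     # differs from the stub default, so it would be rejected


kernel = KernelStub()
non_default = (
    ("warmup" in defaults and defaults["warmup"].default != kernel.warmup)
    or ("rep" in defaults and defaults["rep"].default != kernel.rep)
)
print(non_default)  # True, because rep deviates from the default
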
11 changes: 1 addition & 10 deletions torch/_inductor/autotune_process.py
@@ -513,25 +513,16 @@ def make_run_fn(
)

run_method = getattr(mod, self.kernel_name).run
extra_args = list(self.extra_args)

# Newer version of triton add warmup argument to JITFunction.run.
# This code handles backward-compatibility.
warmup_arg = {}
import inspect

if "warmup" in inspect.signature(run_method).parameters:
warmup_arg["warmup"] = False

return functools.partial(
run_method,
*input_tensors,
output_tensor,
*self.extra_args,
grid=self.grid,
**warmup_arg,
num_stages=self.num_stages,
num_warps=self.num_warps,
stream=torch.cuda.current_stream().cuda_stream,
)

def __str__(self) -> str:
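
The shim dropped here probed JITFunction.run with inspect.signature and forwarded warmup=False only when the installed triton accepts it. A minimal sketch of that feature detection, with a placeholder run function standing in for a real compiled kernel:

import functools
import inspect


def run_method(*args, grid=None, num_stages=None, num_warps=None, stream=None, warmup=False):
    """Placeholder for JITFunction.run on a triton that already accepts warmup."""
    return grid, warmup


# Forward warmup=False only if the run method actually takes that parameter.
warmup_arg = {}
if "warmup" in inspect.signature(run_method).parameters:
    warmup_arg["warmup"] = False

run_fn = functools.partial(run_method, grid=(1,), num_stages=1, num_warps=4, stream=0, **warmup_arg)
print(run_fn())  # ((1,), False) with the placeholder above
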
48 changes: 2 additions & 46 deletions torch/_inductor/codegen/triton.py
@@ -10,7 +10,6 @@
import operator
import os
import textwrap
from functools import lru_cache
from typing import (
Any,
Callable,
@@ -32,7 +31,6 @@
from torch._prims_common import is_integer_dtype
from torch.utils._sympy.functions import FloorDiv, ModularIndexing
from torch.utils._sympy.value_ranges import ValueRanges
from torch.utils._triton import has_triton_package

from ..._dynamo.utils import counters
from .. import config, ir, scheduler
@@ -93,30 +91,6 @@ def _print_Where(self, expr):
q = self.doprint(expr.args[2])
return f"tl.where({c}, {p}, {q})"

@staticmethod
@lru_cache(None)
def _propagate_nan_arg():
"""
Newer triton version added propagate_nan as required argument for
tl.math.{min, max}. This method make inductor work with both old
and new version of triton.
"""

if not has_triton_package():
# some tests run under environment without triton installed want to
# check that the generated code is as expected.
return ""
import inspect

import triton.language as tl

if "propagate_nan" in inspect.signature(tl.math.min).parameters:
# tl.PropagateNan.NONE is the default
propagate_nan_arg = ", tl.PropagateNan.NONE"
else:
propagate_nan_arg = ""
return propagate_nan_arg

def _print_Min(self, expr):
nargs = len(expr.args)
if len(expr.args) == 1:
@@ -125,7 +99,7 @@ def _print_Min(self, expr):
mid = len(expr.args) // 2
a = self._print(sympy.Min(*expr.args[:mid]))
b = self._print(sympy.Min(*expr.args[mid:]))
return f"tl.math.min({a}, {b}{TritonPrinter._propagate_nan_arg()})"
return f"tl.math.min({a}, {b})"

def _print_Max(self, expr):
nargs = len(expr.args)
@@ -135,8 +109,7 @@ def _print_Max(self, expr):
mid = len(expr.args) // 2
a = self._print(sympy.Max(*expr.args[:mid]))
b = self._print(sympy.Max(*expr.args[mid:]))

return f"tl.math.max({a}, {b}{TritonPrinter._propagate_nan_arg()})"
return f"tl.math.max({a}, {b})"

def _print_Abs(self, expr):
assert len(expr.args) == 1
@@ -2099,20 +2072,6 @@ def imports_for_benchmark_kernel(self):
"""
)

@staticmethod
@lru_cache(None)
def gen_attr_descriptor_import():
"""
import AttrsDescriptor if the triton version is new enough to have this
class defined.
"""
import triton.compiler.compiler

if hasattr(triton.compiler.compiler, "AttrsDescriptor"):
return "from triton.compiler.compiler import AttrsDescriptor"
else:
return ""

def codegen_kernel(self, name=None):
from triton import next_power_of_2

@@ -2158,9 +2117,6 @@ def codegen_kernel(self, name=None):
from torch._inductor import triton_helpers
"""
)
if self.gen_attr_descriptor_import():
code.splice(self.gen_attr_descriptor_import())

if config.benchmark_kernel:
code.splice(self.imports_for_benchmark_kernel())

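
The deleted _propagate_nan_arg helper inspected tl.math.min for a propagate_nan parameter and appended ", tl.PropagateNan.NONE" only when present, so the generated code worked with both old and new triton. A sketch of that detection applied to arbitrary callables; the two stubs below stand in for the old and new triton signatures:

import inspect


def detect_propagate_nan_arg(min_fn) -> str:
    """Return the extra argument text to append to tl.math.min/max calls."""
    if "propagate_nan" in inspect.signature(min_fn).parameters:
        # tl.PropagateNan.NONE is the default on the newer signature
        return ", tl.PropagateNan.NONE"
    return ""


def old_min(a, b):  # older-style signature
    ...


def new_min(a, b, propagate_nan=None):  # newer-style signature
    ...


print(detect_propagate_nan_arg(old_min))  # ""
print(detect_propagate_nan_arg(new_min))  # ", tl.PropagateNan.NONE"
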
2 changes: 0 additions & 2 deletions torch/_inductor/codegen/triton_foreach.py
@@ -189,8 +189,6 @@ def codegen_kernel(self, name=None):
from torch._inductor import triton_helpers
"""
)
if TritonKernel.gen_attr_descriptor_import():
code.splice(TritonKernel.gen_attr_descriptor_import())
argdefs, _, _ = self.args.python_argdefs()
code.writeline(self.jit_line())
code.writeline(
4 changes: 0 additions & 4 deletions torch/_inductor/codegen/wrapper.py
@@ -926,10 +926,6 @@ def define_user_defined_triton_kernel(self, kernel, configs, kwargs):
""",
strip=True,
)
from .triton import TritonKernel

if TritonKernel.gen_attr_descriptor_import():
compile_wrapper.splice(TritonKernel.gen_attr_descriptor_import())
compile_wrapper.newline()

from .common import SizeArg, TensorArg
1 change: 0 additions & 1 deletion torch/_inductor/select_algorithm.py
@@ -187,7 +187,6 @@ def hook():
"from torch._inductor.triton_heuristics import template",
"from torch._inductor.utils import instance_descriptor",
"from torch._inductor import triton_helpers",
TritonKernel.gen_attr_descriptor_import(),
"",
self.jit_line(),
f"def {self.kernel_name}({', '.join(arg_defs)}):",
75 changes: 19 additions & 56 deletions torch/_inductor/triton_heuristics.py
@@ -18,7 +18,7 @@

import torch.autograd.profiler as autograd_profiler
from torch._dynamo.device_interface import get_interface_for_device
from torch._dynamo.utils import dynamo_timed, get_first_attr
from torch._dynamo.utils import dynamo_timed
from torch.utils._triton import has_triton_package

from . import config
@@ -44,17 +44,11 @@
from triton import Config
from triton.runtime.autotuner import OutOfResources
from triton.runtime.jit import KernelInterface

try:
from triton.compiler.compiler import ASTSource
except ImportError:
ASTSource = None
else:
Config = object
triton = None
KernelInterface = object
OutOfResources = object
ASTSource = None


_NUM_THREADS_PER_WARP = 32
@@ -292,53 +286,25 @@ def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: Optional[int]
# Setting device_type="hip" required on ROCm to pass down to triton
compile_meta["device_type"] = "cuda" if torch.version.hip is None else "hip"

device_type = compile_meta["device_type"]
if warm_cache_only_with_cc:
cc = warm_cache_only_with_cc
else:
device_id = compile_meta["device"]
device_interface = get_interface_for_device(device_type)
device = torch.device(device_type, device_id)
cc = device_interface.get_compute_capability(device)

compile_meta["cc"] = cc

if ASTSource:
compile_args = (
ASTSource(
return (
triton.compile(
self.fn,
compile_meta["signature"],
compile_meta["constants"],
compile_meta["configs"][0],
warm_cache_only=True,
cc=warm_cache_only_with_cc,
**compile_meta,
),
)

target = (device_type, cc)
options = {
"num_warps": compile_meta["num_warps"],
"num_stages": compile_meta["num_stages"],
"debug": compile_meta["debug"],
}
compile_kwargs = {
"target": target,
"options": options,
}
else:
compile_args = (self.fn,)
compile_kwargs = compile_meta

if warm_cache_only_with_cc:
return (
triton.compile(*compile_args, **compile_kwargs),
None,
)

# load binary to the correct device
with torch.cuda.device(compile_meta["device"]):
# need to initialize context
torch.cuda.synchronize(torch.cuda.current_device())

binary = triton.compile(*compile_args, **compile_kwargs)
binary = triton.compile(
self.fn,
**compile_meta,
)
binary._init_handles()

call_args = [
@@ -355,14 +321,6 @@ def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: Optional[int]
"set_device": torch.cuda.set_device,
"current_device": torch.cuda.current_device,
}

scope["runner"] = get_first_attr(binary, "run", "c_wrapper")
scope["function"] = get_first_attr(binary, "function", "cu_function")
cluster_dims = get_first_attr(binary, "cluster_dims", "clusterDims")
scope["cta_args"] = (
(binary.num_ctas, *cluster_dims) if hasattr(binary, "num_ctas") else ()
)

exec(
f"""
def launcher({', '.join(def_args)}, grid, stream):
@@ -371,10 +329,15 @@ def launcher({', '.join(def_args)}, grid, stream):
else:
grid_0, grid_1, grid_2 = grid
runner(grid_0, grid_1, grid_2, bin.num_warps,
*cta_args, bin.shared,
stream, function, None, None, None,
{', '.join(call_args)})
if hasattr(bin, "num_ctas"):
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps,
bin.num_ctas, *bin.clusterDims, bin.shared,
stream, bin.cu_function, None, None, None,
{', '.join(call_args)})
else:
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared,
stream, bin.cu_function, None, None, None,
{', '.join(call_args)})
return bin
""".lstrip(),
scope,
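
The reverted branch selected between the legacy triton.compile(fn, **compile_meta) call and the newer ASTSource-based interface at runtime; the revert keeps only the legacy form. A compressed sketch of that dispatch, reconstructed from the removed lines above, with the triton module and ASTSource passed in as parameters so nothing here asserts a particular installed API:

def compile_kernel(fn, compile_meta, triton_mod, ASTSource=None):
    """Dispatch between old- and new-style triton.compile, as the reverted code did.

    triton_mod and ASTSource are whatever the caller imported; ASTSource is None
    when the installed triton predates the new compile interface.
    """
    if ASTSource is not None:
        # Newer interface: a single source object plus target/options keywords.
        src = ASTSource(
            fn,
            compile_meta["signature"],
            compile_meta["constants"],
            compile_meta["configs"][0],
        )
        options = {
            "num_warps": compile_meta["num_warps"],
            "num_stages": compile_meta["num_stages"],
            "debug": compile_meta["debug"],
        }
        target = (compile_meta["device_type"], compile_meta["cc"])
        return triton_mod.compile(src, target=target, options=options)
    # Legacy interface: everything is forwarded as keyword arguments.
    return triton_mod.compile(fn, **compile_meta)
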
15 changes: 5 additions & 10 deletions torch/_inductor/utils.py
@@ -609,16 +609,11 @@ def has_incompatible_cudagraph_ops(gm):
return False


try:
from triton.compiler.compiler import AttrsDescriptor as instance_descriptor
except ImportError:
# To support older version of triton which does not have AttrsDescriptor
# class
instance_descriptor = collections.namedtuple( # type: ignore[no-redef]
"instance_descriptor",
["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"],
defaults=[tuple(), tuple(), tuple(), tuple()],
)
instance_descriptor = collections.namedtuple(
"instance_descriptor",
["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"],
defaults=[tuple(), tuple(), tuple(), tuple()],
)


@functools.lru_cache(None)
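
With the revert, instance_descriptor is always the namedtuple shown above; the reverted code only fell back to it when importing AttrsDescriptor from triton failed. A small usage example of the restored namedtuple (the field values below are illustrative):

import collections

# Same shape as the restored definition above; every field defaults to an empty tuple.
instance_descriptor = collections.namedtuple(
    "instance_descriptor",
    ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"],
    defaults=[tuple(), tuple(), tuple(), tuple()],
)

print(instance_descriptor())                                          # all hints empty
print(instance_descriptor(divisible_by_16=(0, 1), equal_to_1=(3,)))   # selected hints set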
