Don't add non-integer Triton kernel arg 1 to equal_to_1 (pytorch#123886)
Summary: The Triton compiler adds a constant argument 1 to `equal_to_1` [only when it's an int](https://github.com/openai/triton/blob/8c5e33c77ef83e0cb99c744e58842930e602df31/python/triton/runtime/jit.py#L275). Here we restrict Inductor's `equal_to_1` in the same way.
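
For illustration, here is a minimal sketch of the rule this change enforces when collecting `equal_to_1` indices. The helper below is hypothetical, not the actual Inductor code (which goes through `V.graph.sizevars.statically_known_equals`, as the diff below shows): an argument qualifies only if it is an `int` or `sympy.Integer` that equals 1; a float 1.0, whether literal or symbolic, is left out.

```
# Hypothetical sketch: only integer-typed args equal to 1 are specialized.
import sympy

def equal_to_1_indices(args):
    return tuple(
        i
        for i, arg in enumerate(args)
        if isinstance(arg, (int, sympy.Integer)) and int(arg) == 1
    )

# The float 1.0 and sympy.Float(1.0) are excluded; only the int 1 at index 1 remains.
print(equal_to_1_indices([64, 1, 1.0, sympy.Float(1.0)]))  # -> (1,)
```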

Test Plan:

```
$ python test/inductor/test_triton_kernels.py -k test_triton_kernel_equal_to_1_float_arg
...
----------------------------------------------------------------------
Ran 1 test in 6.528s

OK

$ python test/inductor/test_triton_kernels.py -k test_triton_kernel_equal_to_1_arg
...
----------------------------------------------------------------------
Ran 2 tests in 10.142s

OK
```

Pull Request resolved: pytorch#123886
Approved by: https://github.com/oulgen
ghstack dependencies: pytorch#123703
aakhundov authored and pytorchmergebot committed Apr 14, 2024
1 parent 19f5033 commit 03a05e7
Showing 6 changed files with 98 additions and 3 deletions.
39 changes: 39 additions & 0 deletions test/inductor/test_aot_inductor.py
@@ -43,6 +43,7 @@
add_kernel_2d_autotuned,
add_kernel_autotuned,
add_kernel_with_optional_param,
add_kernel_with_scaling,
)

if IS_WINDOWS and IS_CI:
@@ -1846,6 +1847,44 @@ def forward(self, x, y):

self.check_model(Model(), example_inputs)

@skipIfRocm
@common_utils.parametrize("dynamic", [False, True])
def test_triton_kernel_equal_to_1_float_arg(self, dynamic):
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")

class Model(torch.nn.Module):
def forward(self, x, y):
out = torch.empty_like(x)
n_elements = x.numel()
scaling_factor = (n_elements**0) / 1.0
add_kernel_with_scaling[(n_elements,)](
x,
y,
out,
n_elements,
scaling_factor,
BLOCK_SIZE=16,
)
return out

dynamic_shapes = None
if dynamic:
dim0_xy = Dim("s0", min=2, max=1024)
dynamic_shapes = {
"x": {0: dim0_xy, 1: None},
"y": {0: dim0_xy, 1: None},
}
example_inputs = (
torch.randn(2, device=self.device),
torch.randn(2, device=self.device),
)
self.check_model(
Model(),
example_inputs,
dynamic_shapes=dynamic_shapes,
)

def test_shifted_constraint_ranges(self):
class Model(torch.nn.Module):
def __init__(self):
30 changes: 30 additions & 0 deletions test/inductor/test_triton_kernels.py
@@ -1012,6 +1012,36 @@ def f(x, y):
self.assertTrue("equal_to_1=(3,)" in sources[0])
self.assertEqual(compiled_out, eager_out)

@requires_cuda
@skipIfRocm
@common_utils.parametrize("dynamic", [False, True])
def test_triton_kernel_equal_to_1_float_arg(self, dynamic):
def f(x, y):
out = torch.empty_like(x)
n_elements = x.numel()
scaling_factor = (n_elements**0) / 1.0
add_kernel_with_scaling[(n_elements,)](
x,
y,
out,
n_elements,
scaling_factor,
BLOCK_SIZE=16,
)
return out

x = torch.randn(2, device="cuda")
y = torch.randn(2, device="cuda")
eager_out = f(x, y)
compiled_out, sources = run_and_get_code(
torch.compile(f, dynamic=dynamic), x, y
)

        # float 1.0 (literal or symbolic)
        # should not be added to equal_to_1
self.assertTrue("equal_to_1=()" in sources[0])
self.assertEqual(compiled_out, eager_out)

@requires_cuda
@skipIfRocm
def test_triton_kernel_with_imported_symbol(self):
2 changes: 2 additions & 0 deletions torch/_inductor/codegen/cpp_wrapper_cuda.py
@@ -198,6 +198,8 @@ def generate_args_decl(self, call_args):
var_name = f"var_{next(self.arg_var_id)}"
if isinstance(arg, (sympy.Integer, sympy.Symbol, SymbolicCallArg)):
self.writeline(f"auto {var_name} = {arg};")
elif isinstance(arg, sympy.Float):
self.writeline(f"float {var_name} = {self.expr_printer(arg)};")
elif isinstance(arg, sympy.Expr):
self.writeline(f"auto {var_name} = {self.expr_printer(arg)};")
elif is_int(arg):
6 changes: 4 additions & 2 deletions torch/_inductor/codegen/triton_utils.py
@@ -1,5 +1,7 @@
from typing import Any, Dict, List, Optional

import sympy

import torch

from .. import config
@@ -36,7 +38,7 @@ def signature_of(arg: KernelArgType, *, size_dtype: str) -> str:
# From triton/runtime/jit.py
# `None` is nullptr. Implicitly convert to *i8.
return "*i8"
elif isinstance(arg.expr, float):
elif isinstance(arg.expr, (float, sympy.Float)):
return "fp32"
if size_dtype == "tl.int32":
return "i32"
@@ -118,7 +120,7 @@ def is_aligned(x: KernelArgType, alignment: int, include_tensor: bool) -> bool:
i
for i, arg in zip(indices, args)
if isinstance(arg, SizeArg)
and arg.expr is not None
and isinstance(arg.expr, (int, sympy.Integer))
and V.graph.sizevars.statically_known_equals(arg.expr, 1) # type: ignore[arg-type]
)
# ids_of_folded_args is set from equal_to_1
6 changes: 5 additions & 1 deletion torch/_inductor/codegen/wrapper.py
@@ -1068,7 +1068,11 @@ def define_user_defined_triton_kernel(self, kernel, configs, kwargs):
)
else:
signature.append(SizeArg(key, arg))
if arg is not None and V.graph.sizevars.statically_known_equals(arg, 1): # type: ignore[arg-type]
if isinstance(
arg, (int, sympy.Integer)
) and V.graph.sizevars.statically_known_equals(
arg, 1 # type: ignore[arg-type]
):
equal_to_1_arg_idx.append(idx)
index_dtype = "tl.int32"
triton_meta = {
18 changes: 18 additions & 0 deletions torch/testing/_internal/triton_utils.py
@@ -116,6 +116,24 @@ def add_kernel_2d_autotuned(
tmp2 = tmp0 + tmp1
tl.store(out_ptr + (x1 + (x_elements * y0)), tmp2, xmask & ymask)

@triton.jit
def add_kernel_with_scaling(
in_ptr0,
in_ptr1,
out_ptr,
n_elements,
scaling_factor,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
y = tl.load(in_ptr1 + offsets, mask=mask)
output = (x + y) * scaling_factor
tl.store(out_ptr + offsets, output, mask=mask)

@triton.jit
def mul2_kernel(
in_ptr0,
