print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")

######################################################################
# Composability
# -------------------------------------------------------------------
#
# User-defined Triton kernels do not automatically support all PyTorch
# subsystems. This can be seen in the following use cases:
#
# - Adding a CPU fallback
# - Adding a ``FlopCounter`` formula
# - Composing with Tensor Subclasses
#
# To compose with additional PyTorch subsystems, use ``torch.library.triton_op``.
#
# ``triton_op`` is a structured way of defining a custom operator that is backed by one
# or more Triton kernels: like regular custom operators (``torch.library.custom_op``),
# you can specify the interactions with PyTorch subsystems via ``torch.library``.
# However, unlike ``torch.library.custom_op``, which creates opaque callables with respect to
# ``torch.compile``, ``torch.compile`` traces into ``triton_op`` to apply optimizations.
#
# Here’s a chart of which API to use when integrating Triton kernels with PyTorch.
#
# .. list-table::
# :header-rows: 1
#
# * -
# - Triton kernel (no explicit ``torch.library`` wrapper)
# - ``torch.library.triton_op``
# - ``torch.library.custom_op``
# * - Supports inference
# - Yes
# - Yes
# - Yes
# * - Supports training
# - In the majority of cases
# - Yes
# - Yes
# * - Supports ``torch.compile``
# - Yes
# - Yes
# - Yes
# * - Supports ``torch.compile(fullgraph=True)``
# - In the majority of cases
# - In the majority of cases
# - In all cases
# * - Does ``torch.compile`` trace into the implementation?
# - Yes
# - Yes
# - No
# * - Supports AOTInductor
# - Yes
# - Yes
# - No
# * - Supports PyTorch subsystems like ``FlopCounterMode``, CPU fallbacks, and tensor subclasses
# - No
# - Yes
# - Yes
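#
# For comparison, here is a minimal sketch of the opaque alternative shown in the
# last column of the chart. The ``mylib::mysin_opaque`` name is illustrative, and
# ``sin_kernel`` is the Triton kernel defined in the next section; unlike ``triton_op``,
# ``torch.compile`` treats a ``custom_op`` like this as a black box and does not trace
# into it.
#
# .. code-block:: python
#
#    @torch.library.custom_op("mylib::mysin_opaque", mutates_args=())
#    def mysin_opaque(x: torch.Tensor) -> torch.Tensor:
#        out = torch.empty_like(x)
#        n_elements = x.numel()
#        sin_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
#        return out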

######################################################################
# Wrapping Triton kernels with ``triton_op``
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Use ``torch.library.triton_op`` to wrap a function that may invoke one or more Triton kernels.
# Use ``torch.library.wrap_triton`` to wrap the calls to the Triton kernel.

from torch.library import triton_op, wrap_triton

@triton_op("mylib::mysin", mutates_args={})
def mysin(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = x.numel()
    wrap_triton(sin_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
    return out

@triton.jit
def sin_kernel(
    in_ptr0,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    output = tl.sin(x)
    tl.store(out_ptr + offsets, output, mask=mask)

def sin_triton(x):
    out = torch.empty_like(x)
    n_elements = x.numel()
    sin_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
    return out

######################################################################
# You can invoke the ``triton_op`` in one of the following two ways.

x = torch.randn(3, device="cuda")
y = mysin(x)
z = torch.ops.mylib.mysin.default(x)

assert torch.allclose(y, x.sin())
assert torch.allclose(z, x.sin())

######################################################################
# The resulting ``triton_op`` works with ``torch.compile`` and ``AOTInductor``.

y = torch.compile(mysin)(x)
assert torch.allclose(y, x.sin())
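
######################################################################
# As a quick check (a sketch; this simple kernel is expected to compile without
# graph breaks), the same call also works under ``fullgraph=True``:

y = torch.compile(mysin, fullgraph=True)(x)
assert torch.allclose(y, x.sin())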

######################################################################
# Adding training support
# ^^^^^^^^^^^^^^^^^^^^^^^
#
# Use ``register_autograd`` to add an autograd formula for the ``triton_op``.
# Prefer this to using ``torch.autograd.Function`` (which has various composability footguns
# with ``torch.compile``).

def backward(ctx, grad_output):
    x, = ctx.saved_tensors
    return grad_output * x.cos()

def setup_context(ctx, inputs, output):
    x, = inputs
    ctx.save_for_backward(x)

mysin.register_autograd(backward, setup_context=setup_context)
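
######################################################################
# A quick gradient check (a sketch; assumes a CUDA device, as above):

x = torch.randn(3, device="cuda", requires_grad=True)
y = mysin(x)
y.sum().backward()
assert torch.allclose(x.grad, x.cos())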

######################################################################
# Note that the backward must be a composition of PyTorch-understood operators.
# If you want the backward to call Triton kernels, then those must be wrapped in ``triton_op`` as well:

@triton.jit
def cos_kernel(
    in_ptr0,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    output = tl.cos(x)
    tl.store(out_ptr + offsets, output, mask=mask)

@triton_op("mylib::mycos", mutates_args={})
def mycos(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = x.numel()
    wrap_triton(cos_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
    return out

def backward(ctx, grad_output):
    x, = ctx.saved_tensors
    return grad_output * mycos(x)

def setup_context(ctx, inputs, output):
    x, = inputs
    ctx.save_for_backward(x)

mysin.register_autograd(backward, setup_context=setup_context)
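
######################################################################
# The same gradient check (a sketch) now exercises the Triton-backed backward,
# here under ``torch.compile``:

x = torch.randn(3, device="cuda", requires_grad=True)
y = torch.compile(mysin)(x)
y.sum().backward()
assert torch.allclose(x.grad, x.cos())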

######################################################################
# Adding a CPU Fallback
# ^^^^^^^^^^^^^^^^^^^^^
# Triton kernels don’t run on CPU. Use ``register_kernel`` to add a CPU (or any other device) fallback for the ``triton_op``:

@mysin.register_kernel("cpu")
def _(x):
    return torch.sin(x)

x = torch.randn(3)
y = mysin(x)
assert torch.allclose(y, x.sin())

######################################################################
# The fallback must be composed of PyTorch operators.

######################################################################
# Adding a FlopCounter Formula
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# To specify how many FLOPs the Triton kernel reports under PyTorch's flop counter
# (``FlopCounterMode``), use ``register_flop_formula``.

from torch.utils.flop_counter import FlopCounterMode, register_flop_formula

@register_flop_formula(torch.ops.mylib.mysin)
def _(x_shape):
    numel = 1
    for s in x_shape:
        numel *= s
    return numel

x = torch.randn(3, device="cuda")

#########################################################
# ``FlopCounterMode`` requires `tabulate <https://pypi.org/project/tabulate/>`__.
# Before running the code below, make sure you have ``tabulate`` installed, or install
# it by running ``pip install tabulate``.
#
# >>> with FlopCounterMode() as flop_counter:
# >>>     y = mysin(x)
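#
# To read the count programmatically (a sketch; ``get_total_flops`` is the accessor on
# ``FlopCounterMode``, and the formula registered above reports ``x.numel()`` FLOPs):
#
# >>> with FlopCounterMode(display=False) as flop_counter:
# >>>     y = mysin(x)
# >>> assert flop_counter.get_total_flops() == x.numel()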

######################################################################
# Limitations
# --------------------------------------------------------------------
#
# As of PyTorch 2.3, the support for user-defined Triton kernels in ``torch.compile``
# includes dynamic shapes, ``torch.autograd.Function``, JIT inductor, and AOT inductor.
# You can use these features together to build complex, high-performance models.
#
# PyTorch 2.6 added ``torch.library.triton_op``, which adds support for using
# user-defined Triton kernels with tensor subclasses and other advanced subsystems.
#
# However, there are certain limitations to be aware of:
#
# * **Tensor Subclasses:** Without a ``triton_op`` wrapper, there is no support for
#   tensor subclasses and other advanced features.
# * **Triton Features:** While ``triton.heuristics`` can be used either standalone or
#   before ``triton.autotune``, it cannot be used after ``triton.autotune``. This
#   implies that if ``triton.heuristics`` and ``triton.autotune`` are to be used
#   together, ``triton.heuristics`` must be used first.