[Dynamo] Trace autograd.function in dynamo when inputs require grad (pytorch#116358) (pytorch#116897)

For training graphs (where inputs require grad), we previously speculated the forward and backward graphs to determine whether there were any graph breaks, side effects, etc., but did not actually use these speculated graphs. We would just insert a call_function node into the graph and later rely on autograd's tracing.

This approach does not work for more general graphs, such as graphs that include user-defined Triton kernels, because autograd is not able to do the higher order function conversion.

This PR speculates the forward and backward functions and emits them into a higher order function (HOF) that is later used via the templating mechanism.

While working on this PR, I exposed some bugs in the current tracing: trampoline functions were losing source information, resulting in incorrect graphs being produced. I fixed these source-information bugs and removed the trampolines.
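
As a rough, user-level sketch of the scenario this change targets (the names below are illustrative and not taken from the diff): a custom autograd.Function called under torch.compile with inputs that require grad, where dynamo now captures the forward and backward bodies instead of leaving a single opaque call in the graph for autograd to trace later.

import torch

class MyScale(torch.autograd.Function):  # illustrative example, not from this PR
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors  # saved only to mirror the usual pattern
        return grad_output * 2

@torch.compile(fullgraph=True, backend="eager")
def f(x):
    return MyScale.apply(x)

x = torch.randn(4, requires_grad=True)
f(x).sum().backward()  # forward and backward bodies are now both speculated by dynamo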

Pull Request resolved: pytorch#116897
Approved by: https://github.com/Skylion007, https://github.com/jansel, https://github.com/voznesenskym
oulgen authored and pytorchmergebot committed Jan 16, 2024
1 parent f20eaad commit 28bb31e
Showing 8 changed files with 335 additions and 202 deletions.
76 changes: 75 additions & 1 deletion test/dynamo/test_autograd_function.py
@@ -8,6 +8,12 @@
import torch._dynamo.test_case
import torch._dynamo.testing
import torch._dynamo.utils
from torch.testing._internal.common_utils import skipIfRocm
from torch.testing._internal.triton_utils import HAS_CUDA, requires_cuda

if HAS_CUDA:
    import triton
    from torch.testing._internal.triton_utils import add_kernel


class CustomFunc1(torch.autograd.Function):
@@ -275,7 +281,7 @@ def test_stride_in_bwd(self):
        x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
        with self.assertRaisesRegex(
            torch._dynamo.exc.Unsupported,
-           "Illegal getattr invocation stride in strict mod",
+           ".*HigherOrderOperator body's output must consist of tensors only",
        ):
            opt_model(x)

@@ -856,6 +862,74 @@ def foo(x):
        foo(torch.randn(2))
        foo(torch.randn(2, requires_grad=True))

    @requires_cuda()
    @skipIfRocm
    def test_triton_kernel_basic(self):
        class Add(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, y):
                ctx.save_for_backward(x, y)
                output = torch.zeros_like(x)
                n_elements = output.numel()
                grid = lambda meta: (  # noqa: E731
                    triton.cdiv(n_elements, meta["BLOCK_SIZE"]),
                )
                add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
                return output

            @staticmethod
            def backward(ctx, grad_output):
                x, y = ctx.saved_tensors
                return x * grad_output, y * grad_output

        @torch.compile(fullgraph=True, backend="inductor")
        def f(x, y):
            z = Add.apply(x, y)
            return z

        x = torch.randn(10, device="cuda", requires_grad=True)
        y = torch.randn(10, device="cuda", requires_grad=True)
        z = f(x, y)
        loss = z.sum()
        loss.backward()
        self.assertEqual(x + y, z)

    @requires_cuda()
    @skipIfRocm
    def test_triton_kernel_multiple_out(self):
        class Add(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, y):
                ctx.save_for_backward(x, y)
                ctx.t1 = x
                ctx.t2 = y
                output = torch.zeros_like(x)
                n_elements = output.numel()
                grid = lambda meta: (  # noqa: E731
                    triton.cdiv(n_elements, meta["BLOCK_SIZE"]),
                )
                add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
                return output, x

            @staticmethod
            def backward(ctx, grad_output, old_x):
                x, y = ctx.saved_tensors
                x1 = ctx.t1
                y1 = ctx.t2
                return old_x * x * x1 * grad_output, y * y1 * grad_output

        @torch.compile(fullgraph=True, backend="inductor")
        def f(x, y):
            z = Add.apply(x, y)
            return z

        x = torch.randn(10, device="cuda", requires_grad=True)
        y = torch.randn(10, device="cuda", requires_grad=True)
        z, _ = f(x, y)
        loss = z.sum()
        loss.backward()
        self.assertEqual(x + y, z)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests
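For context on the tests above: add_kernel comes from torch.testing._internal.triton_utils and is, roughly, a standard elementwise Triton add kernel. The sketch below is an assumption reconstructed from how the tests call it (parameter names are illustrative); it is not part of this diff.

import triton
import triton.language as tl

@triton.jit
def add_kernel(in_ptr0, in_ptr1, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one BLOCK_SIZE-wide slice of the inputs.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    y = tl.load(in_ptr1 + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)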
8 changes: 1 addition & 7 deletions torch/_dynamo/symbolic_convert.py
@@ -2238,18 +2238,12 @@ def check_inlineable(func):

        result = skipfiles.check_verbose(func, is_inlined_call=True)
        if result.skipped:
-           from torch._dynamo.variables.misc import (
-               produce_trampoline_autograd_apply,
-               produce_trampoline_autograd_bwd,
-               produce_trampoline_autograd_fwd,
-           )
+           from torch._dynamo.variables.misc import produce_trampoline_autograd_apply

            # _origin marks this as coming from an internal dynamo known function that is safe to
            # trace through.
            if hasattr(func.fn, "_origin") and func.fn._origin in [
-               produce_trampoline_autograd_fwd,
                produce_trampoline_autograd_apply,
-               produce_trampoline_autograd_bwd,
            ]:
                # Known sound
                return skipfiles.SkipResult(False, "allowlist in dynamo known function")
4 changes: 3 additions & 1 deletion torch/_dynamo/variables/builder.py
@@ -557,7 +557,9 @@ def build_key_value(k, v):
            # handle aliased autograd function `apply` calls
            self.install_guards(GuardBuilder.FUNCTION_MATCH)
            return GetAttrVariable(
-               AutogradFunctionVariable(value.__self__, source=self.source),
+               AutogradFunctionVariable(
+                   value.__self__, source=AttrSource(self.source, member="__self__")
+               ),
                "apply",
            )
        elif np and isinstance(value, np.number):
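The builder change above affects how dynamo rebuilds the source for an aliased apply call, i.e. when the bound apply method of an autograd.Function is stored in a variable before being invoked inside the compiled region. A plausible shape of code that exercises this path (all names here are illustrative, not taken from the PR):

import torch

class MySin(torch.autograd.Function):  # illustrative
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.sin()

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * x.cos()

aliased_apply = MySin.apply  # aliased bound `apply`; its __self__ is the MySin class

@torch.compile(backend="eager")
def f(x):
    # dynamo sees `aliased_apply` and must reach the autograd.Function class
    # through __self__, which is what the AttrSource(..., "__self__") fix tracks.
    return aliased_apply(x)

f(torch.randn(3, requires_grad=True))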
