[RELAND][export] Exempt autograd ops for predispatch export (pytorch#117448)

Summary: Reland of https://github.com/pytorch/pytorch/pull/116527/files

Test Plan: CI

Differential Revision: D52675324

Pull Request resolved: pytorch#117448
Approved by: https://github.com/ydwu4
tugsbayasgalan authored and pytorchmergebot committed Jan 16, 2024
1 parent 99e5474 commit 28be47c
Showing 5 changed files with 66 additions and 1 deletion.
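
In short, this change lets pre-dispatch export trace grad-mode context managers (`torch.no_grad`, `torch.enable_grad`, `torch.set_grad_enabled`) into the graph as `torch._C._set_grad_enabled` nodes instead of rejecting them. A minimal sketch of the newly exercised path, mirroring the tests below; note that `_export` from `torch.export._trace` is a private entry point and may change:

```python
import torch
from torch.export._trace import _export  # private API, as used in the tests below


def f(x):
    # Toggling global autograd state used to be rejected during export.
    with torch.no_grad():
        return x + x


# With pre_dispatch=True the grad toggle is captured in the graph as
# torch._C._set_grad_enabled call_function nodes rather than raising.
ep = _export(f, (torch.ones(10),), pre_dispatch=True)
print(ep.graph)
```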
1 change: 1 addition & 0 deletions test/export/test_export.py
@@ -2,6 +2,7 @@
 # flake8: noqa
 import copy
 import dataclasses
+import io
 import unittest
 from contextlib import contextmanager
 from dataclasses import dataclass
37 changes: 37 additions & 0 deletions test/export/test_safeguard.py
@@ -82,6 +82,43 @@ def f3(a):
         with self.assertRaises(RuntimeError):
             export(f3, (torch.randn(10, requires_grad=False),))
 
+    def test_global_autograd_exempt_predispatch(self):
+        def f1(a):
+            with torch.no_grad():
+                b = a + a
+            return b
+
+        def f2(a):
+            with torch.enable_grad():
+                b = a + a
+            return b
+
+        def f3(a):
+            with torch.set_grad_enabled(False):
+                b = a + a
+            return b
+
+        def f4(a):
+            with torch.set_grad_enabled(True):
+                b = a + a
+            return b
+
+        a = torch.randn(10)
+
+        from torch.export._trace import _export
+
+        with torch.no_grad():
+            _export(f1, (a,), pre_dispatch=True)
+            _export(f2, (a,), pre_dispatch=True)
+            _export(f3, (a,), pre_dispatch=True)
+            _export(f4, (a,), pre_dispatch=True)
+
+        with torch.enable_grad():
+            _export(f1, (a,), pre_dispatch=True)
+            _export(f2, (a,), pre_dispatch=True)
+            _export(f3, (a,), pre_dispatch=True)
+            _export(f4, (a,), pre_dispatch=True)
+
 
 if __name__ == "__main__":
     run_tests()
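
Note the contrast with the preceding safeguard test: the default export path still raises RuntimeError when a traced function flips global grad state, while with pre_dispatch=True all four variants export successfully regardless of the grad mode that is active at export time.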
16 changes: 16 additions & 0 deletions test/export/test_serialize.py
@@ -49,6 +49,22 @@ def get_filtered_export_db_tests():
 
 @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
 class TestSerialize(TestCase):
+    def test_predispatch_export_with_autograd_op(self):
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                with torch.enable_grad():
+                    return x + x
+
+        with torch.no_grad():
+            from torch.export._trace import _export
+            ep = _export(Foo(), (torch.ones(10),), pre_dispatch=True)
+
+        with self.assertRaisesRegex(SerializeError, "Failed serializing node _set_grad_enabled"):
+            torch.export.save(ep, io.BytesIO())
+
     def test_serialize_multiple_returns_from_node(self) -> None:
         class MyModule(torch.nn.Module):
             def __init__(self):
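
This test pins down the current limitation: the verifier now accepts _set_grad_enabled nodes (see the verifier change below), but the serializer does not yet know how to encode them, so torch.export.save is expected to fail with SerializeError until these ops are modeled as higher-order ops.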
5 changes: 5 additions & 0 deletions torch/_export/verifier.py
@@ -159,6 +159,11 @@ def _allowed_op_types() -> Tuple[Type[Any], ...]:
             torch.sym_min,
             torch.sym_not,
             torch.sym_sqrt,
+            # TODO (tmanlaibaatar)
+            # Predispatch export is able to contain autograd ops.
+            # These will be modeled as HOO later
+            torch._C._set_grad_enabled
+
         )
 
         if not isinstance(op, _allowed_op_types()):
8 changes: 7 additions & 1 deletion torch/fx/experimental/proxy_tensor.py
@@ -683,7 +683,13 @@ def __init__(self, tracer):
     def __torch_function__(self, func, types, args=(), kwargs=None):
         kwargs = kwargs or {}
         if func in _side_effectful_need_to_be_preserved_pre_dispatch:
-            return self.tracer.create_node("call_function", func, args, {})
+            # Set meta['val'] so the export verifier, which checks meta['val'],
+            # accepts this node.
+            # TODO(tmanlaibaatar): we should systematically couple this with the
+            # export verifier instead of hardcoding it here.
+            node = self.tracer.create_node("call_function", func, args, {})
+            if func is torch._C._set_grad_enabled:
+                node.meta['val'] = None
+            return node
         # Don't actually run the function! We just want to trace the calls
         # into a graph. We don't actually want to change global autograd state.
         return func(*args, **kwargs)
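
Since `__torch_function__` only records the call, tracing never actually flips global autograd state. A hedged sketch of inspecting the recorded node and the `meta['val'] = None` workaround, again via the private `_export`:

```python
import torch
from torch.export._trace import _export  # private API, as in the tests above


def f(x):
    with torch.enable_grad():
        return x + x


with torch.no_grad():
    ep = _export(f, (torch.ones(10),), pre_dispatch=True)

for node in ep.graph.nodes:
    if node.op == "call_function" and node.target is torch._C._set_grad_enabled:
        # meta['val'] is None here, set only to satisfy the export verifier.
        print(node.target, node.args, node.meta.get("val"))
```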
