From 28be47c2677e252c7dba3572a6b206dedd88e57c Mon Sep 17 00:00:00 2001 From: "Tugsbayasgalan (Tugsuu) Manlaibaatar" Date: Tue, 16 Jan 2024 19:32:15 +0000 Subject: [PATCH] [RELAND][export] Exempt autograd ops for predispatch export (#117448) Summary: Reland of https://github.com/pytorch/pytorch/pull/116527/files Test Plan: CI Differential Revision: D52675324 Pull Request resolved: https://github.com/pytorch/pytorch/pull/117448 Approved by: https://github.com/ydwu4 --- test/export/test_export.py | 1 + test/export/test_safeguard.py | 37 +++++++++++++++++++++++++++ test/export/test_serialize.py | 16 ++++++++++++ torch/_export/verifier.py | 5 ++++ torch/fx/experimental/proxy_tensor.py | 8 +++++- 5 files changed, 66 insertions(+), 1 deletion(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index bc96c38a4e4004..6042273499dbf5 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -2,6 +2,7 @@ # flake8: noqa import copy import dataclasses +import io import unittest from contextlib import contextmanager from dataclasses import dataclass diff --git a/test/export/test_safeguard.py b/test/export/test_safeguard.py index ea2a3668e1f4ce..c31b09b45a5546 100644 --- a/test/export/test_safeguard.py +++ b/test/export/test_safeguard.py @@ -82,6 +82,43 @@ def f3(a): with self.assertRaises(RuntimeError): export(f3, (torch.randn(10, requires_grad=False),)) + def test_global_autograd_exempt_predispatch(self): + def f1(a): + with torch.no_grad(): + b = a + a + return b + + def f2(a): + with torch.enable_grad(): + b = a + a + return b + + def f3(a): + with torch.set_grad_enabled(False): + b = a + a + return b + + def f4(a): + with torch.set_grad_enabled(True): + b = a + a + return b + + a = torch.randn(10) + + from torch.export._trace import _export + + with torch.no_grad(): + _export(f1, (a,), pre_dispatch=True) + _export(f2, (a,), pre_dispatch=True) + _export(f3, (a,), pre_dispatch=True) + _export(f4, (a,), pre_dispatch=True) + + with 
torch.enable_grad(): + _export(f1, (a,), pre_dispatch=True) + _export(f2, (a,), pre_dispatch=True) + _export(f3, (a,), pre_dispatch=True) + _export(f4, (a,), pre_dispatch=True) + if __name__ == "__main__": run_tests() diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py index 6c5b53103227fd..ce718277b1582c 100644 --- a/test/export/test_serialize.py +++ b/test/export/test_serialize.py @@ -49,6 +49,22 @@ def get_filtered_export_db_tests(): @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support") class TestSerialize(TestCase): + def test_predispatch_export_with_autograd_op(self): + class Foo(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + with torch.enable_grad(): + return x + x + + with torch.no_grad(): + from torch.export._trace import _export + ep = _export(Foo(), (torch.ones(10),), pre_dispatch=True) + + with self.assertRaisesRegex(SerializeError, "Failed serializing node _set_grad_enabled"): + torch.export.save(ep, io.BytesIO()) + def test_serialize_multiple_returns_from_node(self) -> None: class MyModule(torch.nn.Module): def __init__(self): diff --git a/torch/_export/verifier.py b/torch/_export/verifier.py index 9b7c5fa6a07939..d026267a79d331 100644 --- a/torch/_export/verifier.py +++ b/torch/_export/verifier.py @@ -159,6 +159,11 @@ def _allowed_op_types() -> Tuple[Type[Any], ...]: torch.sym_min, torch.sym_not, torch.sym_sqrt, + # TODO (tmanlaibaatar) + # Predispatch export is able to contain autograd ops. 
+ # These will be modeled as HOO later + torch._C._set_grad_enabled + ) if not isinstance(op, _allowed_op_types()): diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 57ec8a20e1b642..ec9b34f6df3a57 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -683,7 +683,13 @@ def __init__(self, tracer): def __torch_function__(self, func, types, args=(), kwargs=None): kwargs = kwargs or {} if func in _side_effectful_need_to_be_preserved_pre_dispatch: - return self.tracer.create_node("call_function", func, args, {}) + # It's for passing the export verifier which needs to verify the meta['val'] + # TODO(tmanlaibaatar): we should systematically couple it with export verifier, + # instead of hardcoding it here. + node = self.tracer.create_node("call_function", func, args, {}) + if func is torch._C._set_grad_enabled: + node.meta['val'] = None + return node # Don't actually run the function! We just want to trace the calls # into a graph. We don't actualy want to change global autograd state. return func(*args, **kwargs)