diff --git a/.jenkins/validate_tutorials_built.py b/.jenkins/validate_tutorials_built.py index 642f2a4665..c4f483b28a 100644 --- a/.jenkins/validate_tutorials_built.py +++ b/.jenkins/validate_tutorials_built.py @@ -49,7 +49,6 @@ "intermediate_source/flask_rest_api_tutorial", "intermediate_source/text_to_speech_with_torchaudio", "intermediate_source/tensorboard_profiler_tutorial", # reenable after 2.0 release. - "intermediate_source/torch_export_tutorial" # reenable after 2940 is fixed. ] def tutorial_source_dirs() -> List[Path]: diff --git a/intermediate_source/torch_export_tutorial.py b/intermediate_source/torch_export_tutorial.py index dc5e226f86..b38277b88f 100644 --- a/intermediate_source/torch_export_tutorial.py +++ b/intermediate_source/torch_export_tutorial.py @@ -163,22 +163,6 @@ def forward(self, x): except Exception: tb.print_exc() -###################################################################### -# - unsupported Python language features (e.g. throwing exceptions, match statements) - -class Bad4(torch.nn.Module): - def forward(self, x): - try: - x = x + 1 - raise RuntimeError("bad") - except: - x = x + 2 - return x - -try: - export(Bad4(), (torch.randn(3, 3),)) -except Exception: - tb.print_exc() ###################################################################### # Non-Strict Export @@ -197,16 +181,6 @@ def forward(self, x): # ``strict=False`` flag. 
# # Looking at some of the previous examples which resulted in graph breaks: -# -# - Accessing tensor data with ``.data`` now works correctly - -class Bad2(torch.nn.Module): - def forward(self, x): - x.data[0, 0] = 3 - return x - -bad2_nonstrict = export(Bad2(), (torch.randn(3, 3),), strict=False) -print(bad2_nonstrict.module()(torch.ones(3, 3))) ###################################################################### # - Calling unsupported functions (such as many built-in functions) traces @@ -223,22 +197,6 @@ def forward(self, x): print(bad3_nonstrict) print(bad3_nonstrict.module()(torch.ones(3, 3))) -###################################################################### -# - Unsupported Python language features (such as throwing exceptions, match -# statements) now also get traced through. - -class Bad4(torch.nn.Module): - def forward(self, x): - try: - x = x + 1 - raise RuntimeError("bad") - except: - x = x + 2 - return x - -bad4_nonstrict = export(Bad4(), (torch.randn(3, 3),), strict=False) -print(bad4_nonstrict.module()(torch.ones(3, 3))) - ###################################################################### # However, there are still some features that require rewrites to the original @@ -349,7 +307,7 @@ def forward(self, x, y): # ``inp1`` has an unconstrained first dimension, but the size of the second # dimension must be in the interval [4, 18]. 
-from torch.export import Dim +from torch.export.dynamic_shapes import Dim inp1 = torch.randn(10, 10, 2) @@ -358,7 +316,7 @@ def forward(self, x): x = x[:, 2:] return torch.relu(x) -inp1_dim0 = Dim("inp1_dim0") +inp1_dim0 = Dim("inp1_dim0", max=50) inp1_dim1 = Dim("inp1_dim1", min=4, max=18) dynamic_shapes1 = { "x": {0: inp1_dim0, 1: inp1_dim1}, @@ -479,9 +437,7 @@ def forward(self, z, y): class DynamicShapesExample3(torch.nn.Module): def forward(self, x, y): - if x.shape[0] <= 16: - return x @ y[:, :16] - return y + return x @ y dynamic_shapes3 = { "x": {i: Dim(f"inp4_dim{i}") for i in range(inp4.dim())}, @@ -536,6 +492,28 @@ def suggested_fixes(): print(exported_dynamic_shapes_example3.range_constraints) +###################################################################### +# In PyTorch v2.5, we also introduced an automatic way of determining dynamic +# shapes. In the case where you don't know the dynamism of tensors, or the +# relationship of dynamic shapes between input tensors, we can mark dimensions +# with `Dim.AUTO`, and export will determine the dynamism of the input dimensions. 
+# Going back to the previous example, we can rewrite it as follows: + +inp4 = torch.randn(8, 16) +inp5 = torch.randn(16, 32) + +class DynamicShapesExample3(torch.nn.Module): + def forward(self, x, y): + return x @ y + +dynamic_shapes3_2 = { + "x": {i: Dim.AUTO for i in range(inp4.dim())}, + "y": {i: Dim.AUTO for i in range(inp5.dim())}, +} + +exported_dynamic_shapes_example_3_2 = export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3_2) +print(exported_dynamic_shapes_example_3_2) + ###################################################################### # Custom Ops # ---------- @@ -548,7 +526,7 @@ def suggested_fixes(): # as with any other custom op @torch.library.custom_op("my_custom_library::custom_op", mutates_args={}) -def custom_op(input: torch.Tensor) -> torch.Tensor: +def custom_op(x: torch.Tensor) -> torch.Tensor: print("custom_op called!") return torch.relu(x)