Fix export tutorial #3130

Open · wants to merge 1 commit into main
1 change: 0 additions & 1 deletion .jenkins/validate_tutorials_built.py
@@ -49,7 +49,6 @@
"intermediate_source/flask_rest_api_tutorial",
"intermediate_source/text_to_speech_with_torchaudio",
"intermediate_source/tensorboard_profiler_tutorial", # reenable after 2.0 release.
"intermediate_source/torch_export_tutorial" # reenable after 2940 is fixed.
]

def tutorial_source_dirs() -> List[Path]:
74 changes: 26 additions & 48 deletions intermediate_source/torch_export_tutorial.py
@@ -163,22 +163,6 @@ def forward(self, x):
except Exception:
    tb.print_exc()

-######################################################################
-# - unsupported Python language features (e.g. throwing exceptions, match statements)
-
-class Bad4(torch.nn.Module):
-    def forward(self, x):
-        try:
-            x = x + 1
-            raise RuntimeError("bad")
-        except:
-            x = x + 2
-        return x
-
-try:
-    export(Bad4(), (torch.randn(3, 3),))
-except Exception:
-    tb.print_exc()

######################################################################
# Non-Strict Export
@@ -197,16 +181,6 @@
# ``strict=False`` flag.
#
# Looking at some of the previous examples which resulted in graph breaks:
-#
-# - Accessing tensor data with ``.data`` now works correctly
-
-class Bad2(torch.nn.Module):
-    def forward(self, x):
-        x.data[0, 0] = 3
-        return x
-
-bad2_nonstrict = export(Bad2(), (torch.randn(3, 3),), strict=False)
-print(bad2_nonstrict.module()(torch.ones(3, 3)))

######################################################################
# - Calling unsupported functions (such as many built-in functions) traces
@@ -223,22 +197,6 @@
print(bad3_nonstrict)
print(bad3_nonstrict.module()(torch.ones(3, 3)))
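
Note: the ``Bad3`` definition is in the collapsed context above. A plausible
sketch, assumed rather than taken verbatim from this diff, is a module that
calls an unsupported Python built-in, which ``strict=False`` traces through:

class Bad3(torch.nn.Module):
    def forward(self, x):
        x = x + 1
        return x + id(x)  # id() is a built-in that strict export graph-breaks on

bad3_nonstrict = export(Bad3(), (torch.randn(3, 3),), strict=False)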

-######################################################################
-# - Unsupported Python language features (such as throwing exceptions, match
-#   statements) now also get traced through.
-
-class Bad4(torch.nn.Module):
-    def forward(self, x):
-        try:
-            x = x + 1
-            raise RuntimeError("bad")
-        except:
-            x = x + 2
-        return x
-
-bad4_nonstrict = export(Bad4(), (torch.randn(3, 3),), strict=False)
-print(bad4_nonstrict.module()(torch.ones(3, 3)))
-

######################################################################
# However, there are still some features that require rewrites to the original
@@ -349,7 +307,7 @@ def forward(self, x, y):
# ``inp1`` has an unconstrained first dimension, but the size of the second
# dimension must be in the interval [4, 18].

-from torch.export import Dim
+from torch.export.dynamic_shapes import Dim

inp1 = torch.randn(10, 10, 2)

@@ -358,7 +316,7 @@ def forward(self, x):
        x = x[:, 2:]
        return torch.relu(x)

inp1_dim0 = Dim("inp1_dim0")
inp1_dim0 = Dim("inp1_dim0", max=50)
inp1_dim1 = Dim("inp1_dim1", min=4, max=18)
dynamic_shapes1 = {
"x": {0: inp1_dim0, 1: inp1_dim1},
@@ -479,9 +437,7 @@ def forward(self, z, y):

class DynamicShapesExample3(torch.nn.Module):
    def forward(self, x, y):
-        if x.shape[0] <= 16:
-            return x @ y[:, :16]
-        return y
+        return x @ y

dynamic_shapes3 = {
"x": {i: Dim(f"inp4_dim{i}") for i in range(inp4.dim())},
@@ -536,6 +492,28 @@ def suggested_fixes():

print(exported_dynamic_shapes_example3.range_constraints)

+######################################################################
+# In PyTorch v2.5, we also introduced an automatic way of determining dynamic
+# shapes. In cases where you don't know the dynamism of tensors, or the
+# relationship of dynamic shapes between input tensors, we can mark dimensions
+# with ``Dim.AUTO``, and export will determine the dynamism of the input
+# dimensions. Going back to the previous example, we can rewrite it as follows:
+
+inp4 = torch.randn(8, 16)
+inp5 = torch.randn(16, 32)
+
+class DynamicShapesExample3(torch.nn.Module):
+    def forward(self, x, y):
+        return x @ y
+
+dynamic_shapes3_2 = {
+    "x": {i: Dim.AUTO for i in range(inp4.dim())},
+    "y": {i: Dim.AUTO for i in range(inp5.dim())},
+}
+
+exported_dynamic_shapes_example_3_2 = export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3_2)
+print(exported_dynamic_shapes_example_3_2)
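
Note: because every dimension is marked ``Dim.AUTO``, the exported program
should accept any input shapes that satisfy the matmul constraint. A
hypothetical check, not part of this diff:

out = exported_dynamic_shapes_example_3_2.module()(torch.randn(12, 24), torch.randn(24, 48))
print(out.shape)  # expected: torch.Size([12, 48])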

######################################################################
# Custom Ops
# ----------
@@ -548,7 +526,7 @@
# as with any other custom op

@torch.library.custom_op("my_custom_library::custom_op", mutates_args={})
-def custom_op(input: torch.Tensor) -> torch.Tensor:
+def custom_op(x: torch.Tensor) -> torch.Tensor:
    print("custom_op called!")
    return torch.relu(x)
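
Note: for export to trace a custom op, a FakeTensor kernel is typically also
registered; the collapsed context presumably does so. A minimal sketch, with
the meta-function name assumed:

@custom_op.register_fake
def custom_op_meta(x):
    # Describe the output's metadata (shape/dtype) without running the kernel.
    return torch.empty_like(x)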
