Stop using clear_frame as decorator (pytorch#132778)
See pytorch#132073 for motivation

Signed-off-by: Edward Z. Yang <[email protected]>

Pull Request resolved: pytorch#132778
Approved by: https://github.com/albanD
ghstack dependencies: pytorch#132774
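
The change itself is mechanical: TracingContext.clear_frame() is a context manager, so instead of applying it as a decorator on compile_and_call_fx_graph, the method body is wrapped in an explicit with block one level deeper. A minimal sketch of the before/after shape (illustrative only; the real body is the full method in the diff below):

import torch

class Sketch:
    # Before (now flagged by the CONTEXT_DECORATOR lint rule):
    #
    # @torch._guards.TracingContext.clear_frame()
    # def compile_and_call_fx_graph(self, tx, rv, root):
    #     ...
    #
    # After: same scope, expressed as an explicit context manager.
    def compile_and_call_fx_graph(self, tx, rv, root):
        """Generate code from self.graph (docstring stays above the with block)."""
        with torch._guards.TracingContext.clear_frame():
            ...  # original body, re-indented one level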
ezyang authored and pytorchmergebot committed Aug 7, 2024
1 parent bb99008 commit 7c79e89
Showing 2 changed files with 92 additions and 92 deletions.
.lintrunner.toml (2 changes: 1 addition & 1 deletion)
@@ -951,7 +951,7 @@ include_patterns = [
 command = [
     'python3',
     'tools/linter/adapters/grep_linter.py',
-    '--pattern=@.*(dynamo_timed|preserve_rng_state)',
+    '--pattern=@.*(dynamo_timed|preserve_rng_state|clear_frame)',
     '--linter-name=CONTEXT_DECORATOR',
     '--error-name=avoid context decorator',
     """--error-description=\
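The lint rule is a plain regex grep over the source; a quick standalone check (not the linter's own harness) of what the updated pattern does and does not match:

import re

# Pattern taken from the CONTEXT_DECORATOR rule above.
PATTERN = re.compile(r"@.*(dynamo_timed|preserve_rng_state|clear_frame)")

decorator_usage = "    @torch._guards.TracingContext.clear_frame()"
with_block_usage = "    with torch._guards.TracingContext.clear_frame():"

print(bool(PATTERN.search(decorator_usage)))   # True  -> flagged by the linter
print(bool(PATTERN.search(with_block_usage)))  # False -> passes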
torch/_dynamo/output_graph.py (182 changes: 91 additions & 91 deletions)
@@ -1281,116 +1281,116 @@ def run_compiler_collective(self, tx):
             tx.speculation_log.clear()
             raise exc.CompileCollectiveRestartAnalysis
 
-    @torch._guards.TracingContext.clear_frame()
     def compile_and_call_fx_graph(self, tx, rv, root):
         """
         Generate code from self.graph and return the Instruction()s to
         call that generated code.
         """
-        from .decorators import disable
+        with torch._guards.TracingContext.clear_frame():
+            from .decorators import disable
 
-        assert self.should_exit
+            assert self.should_exit
 
-        self.run_compiler_collective(tx)
+            self.run_compiler_collective(tx)
 
-        name = unique_id("__compiled_fn")
+            name = unique_id("__compiled_fn")
 
-        assert isinstance(rv, list)
-        assert isinstance(root, FakeRootModule)
-        output_node = self.create_node(
-            "output",
-            "output",
-            (self.current_tracer.create_arg(tuple(x.as_proxy() for x in rv)),),
-            {},
-        )
-        tx.output.current_tracer._maybe_preserve_original_meta(tx, output_node)
-        if not config.do_not_emit_runtime_asserts:
-            insert_deferred_runtime_asserts(
-                fx.GraphModule(root, self.graph),
-                self.shape_env,
-                name,
-            )
-        # NB: deferred runtime asserts can keep graphargs live, so make sure
-        # those are inserted before pruning
-        self.remove_unused_graphargs()
-        ncalls = count_calls(self.graph)
-        counters["stats"]["calls_captured"] += ncalls
+            assert isinstance(rv, list)
+            assert isinstance(root, FakeRootModule)
+            output_node = self.create_node(
+                "output",
+                "output",
+                (self.current_tracer.create_arg(tuple(x.as_proxy() for x in rv)),),
+                {},
+            )
+            tx.output.current_tracer._maybe_preserve_original_meta(tx, output_node)
+            if not config.do_not_emit_runtime_asserts:
+                insert_deferred_runtime_asserts(
+                    fx.GraphModule(root, self.graph),
+                    self.shape_env,
+                    name,
+                )
+            # NB: deferred runtime asserts can keep graphargs live, so make sure
+            # those are inserted before pruning
+            self.remove_unused_graphargs()
+            ncalls = count_calls(self.graph)
+            counters["stats"]["calls_captured"] += ncalls
 
-        # free a bit of memory
-        self.real_value_cache.clear()
+            # free a bit of memory
+            self.real_value_cache.clear()
 
-        gm = _make_graph_module(root, self.graph)
-        for register_finalizer in self.register_finalizer_fns:
-            register_finalizer(gm)
+            gm = _make_graph_module(root, self.graph)
+            for register_finalizer in self.register_finalizer_fns:
+                register_finalizer(gm)
 
-        gm.compile_subgraph_reason = self.compile_subgraph_reason
-        gm.meta[
-            "dynamo_flat_name_to_original_fqn"
-        ] = self.dynamo_flat_name_to_original_fqn.copy()
+            gm.compile_subgraph_reason = self.compile_subgraph_reason
+            gm.meta[
+                "dynamo_flat_name_to_original_fqn"
+            ] = self.dynamo_flat_name_to_original_fqn.copy()
 
-        graph_code_log.debug(
-            "%s",
-            lazy_format_graph_code(
-                name, gm, include_stride=True, include_device=True, colored=True
-            ),
-        )
-        torch._logging.trace_structured(
-            "dynamo_output_graph",
-            lambda: {"sizes": self.get_graph_sizes_structured()},
-            payload_fn=lambda: gm.print_readable(
-                print_output=False, include_stride=True, include_device=True
-            ),
-        )
-        self.call_cleanup_hooks()
-        old_fake_mode = self.tracing_context.fake_mode
-        if not self.export:
-            import torch._functorch.config as _config
+            graph_code_log.debug(
+                "%s",
+                lazy_format_graph_code(
+                    name, gm, include_stride=True, include_device=True, colored=True
+                ),
+            )
+            torch._logging.trace_structured(
+                "dynamo_output_graph",
+                lambda: {"sizes": self.get_graph_sizes_structured()},
+                payload_fn=lambda: gm.print_readable(
+                    print_output=False, include_stride=True, include_device=True
+                ),
+            )
+            self.call_cleanup_hooks()
+            old_fake_mode = self.tracing_context.fake_mode
+            if not self.export:
+                import torch._functorch.config as _config
 
-            with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
-                # TODO(voz): The way export uses gm, and fake tensors, is not supported with us resetting
-                backend_fake_mode = torch._subclasses.FakeTensorMode(
-                    shape_env=old_fake_mode.shape_env,
-                )
-            # TODO(voz): Ostensibily, this should be scoped and
-            # restore back to old_fake_mode, but doing so currently violates
-            # a lot of fake_tensor ownership assumptions and runs afoul of detect_fake_mode
-            self.tracing_context.fake_mode = backend_fake_mode
+                with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
+                    # TODO(voz): The way export uses gm, and fake tensors, is not supported with us resetting
+                    backend_fake_mode = torch._subclasses.FakeTensorMode(
+                        shape_env=old_fake_mode.shape_env,
+                    )
+                # TODO(voz): Ostensibily, this should be scoped and
+                # restore back to old_fake_mode, but doing so currently violates
+                # a lot of fake_tensor ownership assumptions and runs afoul of detect_fake_mode
+                self.tracing_context.fake_mode = backend_fake_mode
 
-        with self.restore_global_state():
-            compiled_fn = self.call_user_compiler(gm)
+            with self.restore_global_state():
+                compiled_fn = self.call_user_compiler(gm)
 
-        from torch.fx._lazy_graph_module import _LazyGraphModule
+            from torch.fx._lazy_graph_module import _LazyGraphModule
 
-        if isinstance(compiled_fn, _LazyGraphModule) or (
-            isinstance(getattr(compiled_fn, "__self__", None), _LazyGraphModule)
-            and compiled_fn.__name__ == "_lazy_forward"  # type: ignore[attr-defined]
-        ):
-            # Since dynamo will run the forward method for the GraphModule shortly
-            # anyways, it does not hurt to do the real recompilation here if
-            # this is a _LazyGraphModule. This makes it easier for dynamo to
-            # optimize a _LazyGraphModule.
+            if isinstance(compiled_fn, _LazyGraphModule) or (
+                isinstance(getattr(compiled_fn, "__self__", None), _LazyGraphModule)
+                and compiled_fn.__name__ == "_lazy_forward"  # type: ignore[attr-defined]
+            ):
+                # Since dynamo will run the forward method for the GraphModule shortly
+                # anyways, it does not hurt to do the real recompilation here if
+                # this is a _LazyGraphModule. This makes it easier for dynamo to
+                # optimize a _LazyGraphModule.
 
-            lazy_gm = (
-                compiled_fn
-                if isinstance(compiled_fn, _LazyGraphModule)
-                else compiled_fn.__self__  # type: ignore[attr-defined]
-            )
+                lazy_gm = (
+                    compiled_fn
+                    if isinstance(compiled_fn, _LazyGraphModule)
+                    else compiled_fn.__self__  # type: ignore[attr-defined]
+                )
 
-            _LazyGraphModule.force_recompile(lazy_gm)
+                _LazyGraphModule.force_recompile(lazy_gm)
 
-            if not isinstance(compiled_fn, _LazyGraphModule):
-                # replace compiled_fn with the real forward method
-                compiled_fn = lazy_gm.forward
+                if not isinstance(compiled_fn, _LazyGraphModule):
+                    # replace compiled_fn with the real forward method
+                    compiled_fn = lazy_gm.forward
 
-        compiled_fn = disable(compiled_fn)
+            compiled_fn = disable(compiled_fn)
 
-        counters["stats"]["unique_graphs"] += 1
-        # This is safe because we pre-process name to be unique
-        self.install_global_unsafe(name, compiled_fn)
+            counters["stats"]["unique_graphs"] += 1
+            # This is safe because we pre-process name to be unique
+            self.install_global_unsafe(name, compiled_fn)
 
-        cg = PyCodegen(tx)
-        cg.make_call_generated_code(name)
-        return cg.get_instructions()
+            cg = PyCodegen(tx)
+            cg.make_call_generated_code(name)
+            return cg.get_instructions()
 
     @property
     def placeholders(self) -> List[fx.Node]:
