Skip to content

Commit

Permalink
Prepare for merging _export/exported_program.py and export/exported_program.py (pytorch#4338)
Browse files Browse the repository at this point in the history

Summary:
Pull Request resolved: pytorch#4338

- Prepare for merging _export/exported_program.py and export/exported_program.py
- Change the callers to _export/exported_program.py
- remove duplicated no-op function in `fbcode/executorch/exir/serde/upgrade.py`

Reviewed By: zhxchen17

Differential Revision: D60052318

fbshipit-source-id: 778dc303f08207132953028d1ada058a326c37b4
  • Loading branch information
yushangdi authored and facebook-github-bot committed Jul 23, 2024
1 parent ae1f098 commit 93c56cb
Show file tree
Hide file tree
Showing 5 changed files with 4 additions and 36 deletions.
2 changes: 1 addition & 1 deletion backends/apple/mps/mps_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
CompileSpec,
PreprocessResult,
)
from torch._export.exported_program import ExportedProgram
from torch.export.exported_program import ExportedProgram

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)
Expand Down
2 changes: 1 addition & 1 deletion examples/apple/mps/scripts/bench_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import time

import torch
from torch._export.exported_program import ExportedProgram
from torch.export.exported_program import ExportedProgram


def assert_outputs_equal(model_output, ref_output):
Expand Down
2 changes: 1 addition & 1 deletion exir/capture/_unlift.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def _unlift(gm, inp_pos_to_param_buffer_name, in_spec, out_spec, state_dict):


def unlift_exported_program_lifted_states(
ep: torch._export.exported_program.ExportedProgram,
ep: torch.export.exported_program.ExportedProgram,
):
new_gm = copy.deepcopy(ep.graph_module)

Expand Down
2 changes: 1 addition & 1 deletion exir/serde/export_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
import sympy

import torch
import torch._export.exported_program
import torch.export.exported_program
import torch.export.exported_program as ep
from torch._export.serde.schema import SchemaVersion
from torch._export.verifier import load_verifier
Expand Down
32 changes: 0 additions & 32 deletions exir/serde/upgrade.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,11 @@
from typing import Dict, List, Optional, Tuple

import torch
import torch._export.exported_program as ep
from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
from torch._export.pass_infra.node_metadata import NodeMetadata
from torch._export.pass_infra.proxy_value import ProxyValue
from torch._subclasses import FakeTensor
from torch.export import export
from torch.fx.node import Argument, Target
from torch.library import Library
from torch.utils._pytree import tree_unflatten

lib = Library("aten", "FRAGMENT")
impl_lib = Library("aten", "IMPL")
Expand Down Expand Up @@ -212,33 +208,5 @@ def register_old_op(name: str, schema: str, impl_str: str):

return upgrader_passes

def upgrade(self, exported_program: ep.ExportedProgram) -> ep.ExportedProgram:
"""Run each upgrader pass and then retrace to decompose it. Each upgrader pass replaces the old version of
operators with a custom operator. The custom operator contains a CompositeImplicitAutograd kernel (the
upgrading function itself). After retrace, this custom operator will be decomposed into the ops used in the
upgrader. After all passes are applied, the exported program will be upgraded to the target version.
"""
if not self.upgrader_passes:
return exported_program

args = [
n.meta.get("val", None)
for n in exported_program.graph.nodes
if n.op == "placeholder"
]
args_real_tensors = [
(
torch.ones(tuple(arg.size()), dtype=arg.dtype)
if isinstance(arg, FakeTensor)
else arg
)
for arg in args
]
assert exported_program.call_spec.in_spec is not None
args, kwargs = tree_unflatten(
args_real_tensors, exported_program.call_spec.in_spec
)
assert kwargs == {}

def upgrade(self, exported_program):
return exported_program

0 comments on commit 93c56cb

Please sign in to comment.