Skip to content

Commit

Permalink
Update ExportedProgram ctor callsites. (pytorch#4347)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#4347

as title

Reviewed By: angelayi

Differential Revision: D60074221

fbshipit-source-id: 9987d65227f7ff4819827fd1d5772900bf609467
  • Loading branch information
zhxchen17 authored and facebook-github-bot committed Jul 23, 2024
1 parent 09cfc92 commit 6153b1b
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 20 deletions.
8 changes: 5 additions & 3 deletions backends/xnnpack/xnnpack_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,12 @@ def preprocess(
range_constraints=edge_program.range_constraints,
module_call_graph=edge_program.module_call_graph,
example_inputs=edge_program.example_inputs,
verifier=EXIREdgeDialectVerifier(
edge_compile_config=xnnpack_edge_compile_config, class_only=True
),
constants=edge_program.constants,
verifiers=[
EXIREdgeDialectVerifier(
edge_compile_config=xnnpack_edge_compile_config, class_only=True
)
],
)

passes = []
Expand Down
2 changes: 1 addition & 1 deletion exir/backend/backend_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,6 @@ def to_backend(
range_constraints=copy.deepcopy(edge_program.range_constraints),
module_call_graph=copy.deepcopy(edge_program.module_call_graph),
example_inputs=None,
verifier=edge_program.verifier,
constants=new_constants,
verifiers=[edge_program.verifier],
)
2 changes: 1 addition & 1 deletion exir/capture/_capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def convert_to_fake(x):
)
],
example_inputs=None,
verifier=EXIRATenDialectVerifierBase,
verifiers=[EXIRATenDialectVerifierBase],
)
return ExirExportedProgram(ep, False)

Expand Down
6 changes: 3 additions & 3 deletions exir/lowered_backend_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ def __deepcopy__(self, memo: Optional[Dict[int, Any]]) -> "LoweredBackendModule"
module_call_graph=copy.deepcopy(
self._original_exported_program.module_call_graph
),
verifier=copy.deepcopy(self._original_exported_program.verifier),
constants=self._original_exported_program.constants,
verifiers=[copy.deepcopy(self._original_exported_program.verifier)],
)

res = LoweredBackendModule(
Expand Down Expand Up @@ -322,7 +322,7 @@ def program(
range_constraints=lowered_exported_program.range_constraints,
module_call_graph=lowered_exported_program.module_call_graph,
example_inputs=None,
verifier=lowered_exported_program.verifier,
verifiers=[lowered_exported_program.verifier],
)
if memory_planning is None:
memory_planning = MemoryPlanningPass("greedy")
Expand Down Expand Up @@ -616,8 +616,8 @@ def create_exported_program_from_submodule(
),
)
],
verifier=owning_program.verifier,
constants=subgraph_constants,
verifiers=[owning_program.verifier],
)


Expand Down
2 changes: 1 addition & 1 deletion exir/program/_fake_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ def get_fake_program(real_exported_program: ExportedProgram) -> ExportedProgram:
state_dict=new_state_dict,
range_constraints=copy.deepcopy(real_exported_program.range_constraints),
module_call_graph=copy.deepcopy(real_exported_program.module_call_graph),
verifier=real_exported_program.verifier,
constants=real_exported_program.constants,
verifiers=[real_exported_program.verifier],
)
return fake_exported_program

Expand Down
26 changes: 15 additions & 11 deletions exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,8 @@ def _transform(self, *passes: PassType) -> "ExportedProgram":
range_constraints=_get_updated_range_constraints(transformed_gm),
module_call_graph=copy.deepcopy(self._module_call_graph),
example_inputs=self.example_inputs,
verifier=self.verifier,
constants=self.constants,
verifiers=[self.verifier],
)
transformed_ep.graph_module.meta.update(self.graph_module.meta)
transformed_ep.graph_module.meta.update(res.graph_module.meta)
Expand Down Expand Up @@ -590,8 +590,8 @@ def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram":
range_constraints=ep.exported_program.range_constraints,
module_call_graph=ep.exported_program.module_call_graph,
example_inputs=ep.exported_program.example_inputs,
verifier=get_aten_verifier(enable=config._check_ir_validity),
constants=ep.exported_program.constants,
verifiers=[get_aten_verifier(enable=config._check_ir_validity)],
),
False,
)
Expand Down Expand Up @@ -626,11 +626,13 @@ def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram":
range_constraints=new_ep.exported_program.range_constraints,
module_call_graph=new_ep.exported_program.module_call_graph,
example_inputs=new_ep.exported_program.example_inputs,
verifier=EXIREdgeDialectVerifier(
edge_compile_config=config,
class_only=True,
),
constants=new_ep.exported_program.constants,
verifiers=[
EXIREdgeDialectVerifier(
edge_compile_config=config,
class_only=True,
)
],
)
new_ep.after_to_edge_passes = True
return new_ep
Expand Down Expand Up @@ -708,12 +710,14 @@ def _generate_edge_program(
range_constraints=program.range_constraints,
module_call_graph=program.module_call_graph,
example_inputs=program.example_inputs,
verifier=EXIREdgeDialectVerifier(
edge_compile_config=config,
class_only=True,
exception_list=ops_set_to_not_decompose,
),
constants=program.constants,
verifiers=[
EXIREdgeDialectVerifier(
edge_compile_config=config,
class_only=True,
exception_list=ops_set_to_not_decompose,
)
],
)
# Lift the tensor constants created in ScalarToTensorPass
edge_program = lift_constant_tensor_pass(edge_program)
Expand Down

0 comments on commit 6153b1b

Please sign in to comment.