From 6153b1bf7b1a6547454c57e40b7d8a2beea6fcde Mon Sep 17 00:00:00 2001 From: Zhengxu Chen Date: Tue, 23 Jul 2024 09:24:06 -0700 Subject: [PATCH] Update ExportedProgram ctor callsites. (#4347) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/4347 as title Reviewed By: angelayi Differential Revision: D60074221 fbshipit-source-id: 9987d65227f7ff4819827fd1d5772900bf609467 --- backends/xnnpack/xnnpack_preprocess.py | 8 +++++--- exir/backend/backend_api.py | 2 +- exir/capture/_capture.py | 2 +- exir/lowered_backend_module.py | 6 +++--- exir/program/_fake_program.py | 2 +- exir/program/_program.py | 26 +++++++++++++++----------- 6 files changed, 26 insertions(+), 20 deletions(-) diff --git a/backends/xnnpack/xnnpack_preprocess.py b/backends/xnnpack/xnnpack_preprocess.py index 19e0528afd..5d4c05a535 100644 --- a/backends/xnnpack/xnnpack_preprocess.py +++ b/backends/xnnpack/xnnpack_preprocess.py @@ -105,10 +105,12 @@ def preprocess( range_constraints=edge_program.range_constraints, module_call_graph=edge_program.module_call_graph, example_inputs=edge_program.example_inputs, - verifier=EXIREdgeDialectVerifier( - edge_compile_config=xnnpack_edge_compile_config, class_only=True - ), constants=edge_program.constants, + verifiers=[ + EXIREdgeDialectVerifier( + edge_compile_config=xnnpack_edge_compile_config, class_only=True + ) + ], ) passes = [] diff --git a/exir/backend/backend_api.py b/exir/backend/backend_api.py index 64391fc18a..2e7f1b3cdf 100644 --- a/exir/backend/backend_api.py +++ b/exir/backend/backend_api.py @@ -399,6 +399,6 @@ def to_backend( range_constraints=copy.deepcopy(edge_program.range_constraints), module_call_graph=copy.deepcopy(edge_program.module_call_graph), example_inputs=None, - verifier=edge_program.verifier, constants=new_constants, + verifiers=[edge_program.verifier], ) diff --git a/exir/capture/_capture.py b/exir/capture/_capture.py index 00ec8be286..6ae5bd40d8 100644 --- a/exir/capture/_capture.py +++ b/exir/capture/_capture.py @@ -355,7 +355,7 @@ def convert_to_fake(x): ) ], example_inputs=None, - verifier=EXIRATenDialectVerifierBase, + verifiers=[EXIRATenDialectVerifierBase], ) return ExirExportedProgram(ep, False) diff --git a/exir/lowered_backend_module.py b/exir/lowered_backend_module.py index d10d8e242e..54d7380203 100644 --- a/exir/lowered_backend_module.py +++ b/exir/lowered_backend_module.py @@ -91,8 +91,8 @@ def __deepcopy__(self, memo: Optional[Dict[int, Any]]) -> "LoweredBackendModule" module_call_graph=copy.deepcopy( self._original_exported_program.module_call_graph ), - verifier=copy.deepcopy(self._original_exported_program.verifier), constants=self._original_exported_program.constants, + verifiers=[copy.deepcopy(self._original_exported_program.verifier)], ) res = LoweredBackendModule( @@ -322,7 +322,7 @@ def program( range_constraints=lowered_exported_program.range_constraints, module_call_graph=lowered_exported_program.module_call_graph, example_inputs=None, - verifier=lowered_exported_program.verifier, + verifiers=[lowered_exported_program.verifier], ) if memory_planning is None: memory_planning = MemoryPlanningPass("greedy") @@ -616,8 +616,8 @@ def create_exported_program_from_submodule( ), ) ], - verifier=owning_program.verifier, constants=subgraph_constants, + verifiers=[owning_program.verifier], ) diff --git a/exir/program/_fake_program.py b/exir/program/_fake_program.py index ce3eaf86ca..da1cee7c1a 100644 --- a/exir/program/_fake_program.py +++ b/exir/program/_fake_program.py @@ -49,8 +49,8 @@ def get_fake_program(real_exported_program: ExportedProgram) -> ExportedProgram: state_dict=new_state_dict, range_constraints=copy.deepcopy(real_exported_program.range_constraints), module_call_graph=copy.deepcopy(real_exported_program.module_call_graph), - verifier=real_exported_program.verifier, constants=real_exported_program.constants, + verifiers=[real_exported_program.verifier], ) return fake_exported_program diff --git a/exir/program/_program.py b/exir/program/_program.py index 3f0f0155ec..3f963487fe 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -193,8 +193,8 @@ def _transform(self, *passes: PassType) -> "ExportedProgram": range_constraints=_get_updated_range_constraints(transformed_gm), module_call_graph=copy.deepcopy(self._module_call_graph), example_inputs=self.example_inputs, - verifier=self.verifier, constants=self.constants, + verifiers=[self.verifier], ) transformed_ep.graph_module.meta.update(self.graph_module.meta) transformed_ep.graph_module.meta.update(res.graph_module.meta) @@ -590,8 +590,8 @@ def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram": range_constraints=ep.exported_program.range_constraints, module_call_graph=ep.exported_program.module_call_graph, example_inputs=ep.exported_program.example_inputs, - verifier=get_aten_verifier(enable=config._check_ir_validity), constants=ep.exported_program.constants, + verifiers=[get_aten_verifier(enable=config._check_ir_validity)], ), False, ) @@ -626,11 +626,13 @@ def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram": range_constraints=new_ep.exported_program.range_constraints, module_call_graph=new_ep.exported_program.module_call_graph, example_inputs=new_ep.exported_program.example_inputs, - verifier=EXIREdgeDialectVerifier( - edge_compile_config=config, - class_only=True, - ), constants=new_ep.exported_program.constants, + verifiers=[ + EXIREdgeDialectVerifier( + edge_compile_config=config, + class_only=True, + ) + ], ) new_ep.after_to_edge_passes = True return new_ep @@ -708,12 +710,14 @@ def _generate_edge_program( range_constraints=program.range_constraints, module_call_graph=program.module_call_graph, example_inputs=program.example_inputs, - verifier=EXIREdgeDialectVerifier( - edge_compile_config=config, - class_only=True, - exception_list=ops_set_to_not_decompose, - ), constants=program.constants, + verifiers=[ + EXIREdgeDialectVerifier( + edge_compile_config=config, + class_only=True, + exception_list=ops_set_to_not_decompose, + ) + ], ) # Lift the tensor constants created in ScalarToTensorPass edge_program = lift_constant_tensor_pass(edge_program)