diff --git a/test/end2end/exported_module.py b/test/end2end/exported_module.py index a9e3c81794..3e904b8b58 100644 --- a/test/end2end/exported_module.py +++ b/test/end2end/exported_module.py @@ -63,6 +63,7 @@ def export( ignore_to_out_var_failure: bool = False, dynamic_memory_planning_mode: DynamicMemoryPlanningMode = DynamicMemoryPlanningMode.UPPER_BOUND, capture_config=None, + extract_constant_segment: bool = True, ) -> "ExportedModule": """ Creates a new ExportedModule for the specified module class. @@ -166,6 +167,7 @@ def return_wrapper(): dynamic_memory_planning_mode=dynamic_memory_planning_mode, memory_planning_pass=memory_planning_pass, to_out_var_pass=ToOutVarPass(ignore_to_out_var_failure), + extract_constant_segment=extract_constant_segment, ) ) ) diff --git a/test/models/export_program.py b/test/models/export_program.py index c0744146c9..edec584bc2 100644 --- a/test/models/export_program.py +++ b/test/models/export_program.py @@ -155,7 +155,9 @@ def get_method_names_to_export() -> List[str]: # -def export_module_to_program(module_class: Type[nn.Module]): +def export_module_to_program( + module_class: Type[nn.Module], extract_constant_segment: bool +): """Exports the module and returns the serialized program data.""" # Look for an optional @staticmethod that defines custom trace params. export_kwargs: Dict[str, Any] = {} @@ -167,7 +169,12 @@ def export_module_to_program(module_class: Type[nn.Module]): methods = module_class.get_method_names_to_export() else: methods = ["forward"] - module = ExportedModule.export(module_class, methods, **export_kwargs) + module = ExportedModule.export( + module_class, + methods, + extract_constant_segment=extract_constant_segment, + **export_kwargs, + ) return module.executorch_program.buffer @@ -205,10 +212,16 @@ def main() -> None: # Export and write to the output files. os.makedirs(args.outdir, exist_ok=True) for module_name, module_class in module_names_to_classes.items(): - outfile = os.path.join(args.outdir, f"{module_name}.pte") - with open(outfile, "wb") as fp: - fp.write(export_module_to_program(module_class)) - print(f"Exported {module_name} and wrote program data to {outfile}") + for extract_constant_segment in (True, False): + suffix = "" if extract_constant_segment else "-no-constant-segment" + outfile = os.path.join(args.outdir, f"{module_name}{suffix}.pte") + with open(outfile, "wb") as fp: + fp.write( + export_module_to_program( + module_class, extract_constant_segment=extract_constant_segment + ) + ) + print(f"Exported {module_name} and wrote program data to {outfile}") if __name__ == "__main__": diff --git a/test/models/targets.bzl b/test/models/targets.bzl index 8401efcddb..80c9cb9a6c 100644 --- a/test/models/targets.bzl +++ b/test/models/targets.bzl @@ -72,7 +72,11 @@ def define_common_targets(): runtime.genrule( name = "exported_programs", cmd = "$(exe :export_program) --modules " + ",".join(MODULES_TO_EXPORT) + " --outdir $OUT", - outs = {fname + ".pte": [fname + ".pte"] for fname in MODULES_TO_EXPORT}, + outs = { + fname + seg_suffix + ".pte": [fname + seg_suffix + ".pte"] + for fname in MODULES_TO_EXPORT + for seg_suffix in ["", "-no-constant-segment"] + }, default_outs = ["."], visibility = [ "//executorch/...",