From 3fcf0bd496fec427dedc600afbe670715f76342f Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Fri, 13 Dec 2024 14:23:31 -0800 Subject: [PATCH] save api Differential Revision: D66523473 Pull Request resolved: https://github.com/pytorch/executorch/pull/7097 --- exir/program/_program.py | 14 ++++++++++++++ exir/program/test/test_program.py | 10 +++++++++- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/exir/program/_program.py b/exir/program/_program.py index fd1d0aca3d..e428111844 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -1532,3 +1532,17 @@ def write_to_file(self, open_file: io.BufferedIOBase) -> None: reducing the peak memory usage. """ self._pte_data.write_to_file(open_file) + + def save(self, path: str) -> None: + """ + Saves the serialized ExecuTorch binary to the file at `path`. + """ + if path[-4:] != ".pte": + logging.error(f"Path {path} does not end with .pte") + raise ValueError(f"Path {path} does not end with .pte") + try: + with open(path, "wb") as file: + self.write_to_file(file) + logging.info(f"Saved exported program to {path}") + except Exception as e: + logging.error(f"Error while saving to {path}: {e}") diff --git a/exir/program/test/test_program.py b/exir/program/test/test_program.py index 6ecb71762e..d07972f971 100644 --- a/exir/program/test/test_program.py +++ b/exir/program/test/test_program.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pye-strict +# pyre-unsafe import copy import unittest @@ -803,3 +803,11 @@ def test_to_edge_with_preserved_ops_not_in_model(self): self._test_to_edge_with_preserved_ops( program, ops_not_to_decompose, expected_non_decomposed_edge_ops ) + + def test_save_fails(self): + model = TestLinear() + program = torch.export.export(model, model._get_random_inputs()) + edge = to_edge(program) + et = edge.to_executorch() + with self.assertRaises(ValueError): + _ = et.save("/tmp/test_save.pt")