Skip to content

Commit

Permalink
save api
Browse files Browse the repository at this point in the history
Differential Revision: D66523473

Pull Request resolved: pytorch#7097
  • Loading branch information
JacobSzwejbka authored Dec 13, 2024
1 parent 8460d42 commit 3fcf0bd
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
14 changes: 14 additions & 0 deletions exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -1532,3 +1532,17 @@ def write_to_file(self, open_file: io.BufferedIOBase) -> None:
reducing the peak memory usage.
"""
self._pte_data.write_to_file(open_file)

def save(self, path: str) -> None:
"""
Saves the serialized ExecuTorch binary to the file at `path`.
"""
if path[-4:] != ".pte":
logging.error(f"Path {path} does not end with .pte")
raise ValueError(f"Path {path} does not end with .pte")
try:
with open(path, "wb") as file:
self.write_to_file(file)
logging.info(f"Saved exported program to {path}")
except Exception as e:
logging.error(f"Error while saving to {path}: {e}")
10 changes: 9 additions & 1 deletion exir/program/test/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pye-strict
# pyre-unsafe

import copy
import unittest
Expand Down Expand Up @@ -803,3 +803,11 @@ def test_to_edge_with_preserved_ops_not_in_model(self):
self._test_to_edge_with_preserved_ops(
program, ops_not_to_decompose, expected_non_decomposed_edge_ops
)

def test_save_fails(self):
model = TestLinear()
program = torch.export.export(model, model._get_random_inputs())
edge = to_edge(program)
et = edge.to_executorch()
with self.assertRaises(ValueError):
_ = et.save("/tmp/test_save.pt")

0 comments on commit 3fcf0bd

Please sign in to comment.