Skip to content

Commit

Permalink
Support to init BundledProgram from pte file
Browse files Browse the repository at this point in the history
Differential Revision: D67013542

Pull Request resolved: pytorch#7278
  • Loading branch information
YIWENX14 authored Dec 11, 2024
1 parent 62d2e37 commit 8fc3f8c
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 7 deletions.
31 changes: 25 additions & 6 deletions devtools/bundled_program/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Dict, List, Optional, Sequence, Type, Union

import executorch.devtools.bundled_program.schema as bp_schema
from pyre_extensions import none_throws

import executorch.exir.schema as core_schema

Expand All @@ -19,7 +20,7 @@
from executorch.devtools.bundled_program.version import BUNDLED_PROGRAM_SCHEMA_VERSION

from executorch.exir import ExecutorchProgram, ExecutorchProgramManager
from executorch.exir._serialize import _serialize_pte_binary
from executorch.exir._serialize import _deserialize_pte_binary, _serialize_pte_binary
from executorch.exir.tensor import get_scalar_type, scalar_type_enum, TensorSpec

# pyre-ignore
Expand All @@ -43,23 +44,35 @@ class BundledProgram:

def __init__(
self,
executorch_program: Union[
executorch_program: Optional[Union[
ExecutorchProgram,
ExecutorchProgramManager,
],
]],
method_test_suites: Sequence[MethodTestSuite],
pte_file_path: Optional[str] = None,
):
"""Create BundledProgram by bundling the given program and method_test_suites together.
Args:
executorch_program: The program to be bundled.
method_test_suites: The testcases for certain methods to be bundled.
pte_file_path: The path to pte file to deserialize program if executorch_program is not provided.
"""
if not executorch_program and not pte_file_path:
raise RuntimeError("Either executorch_program or pte_file_path must be provided")

if executorch_program and pte_file_path:
raise RuntimeError("Only one of executorch_program or pte_file_path can be used")

method_test_suites = sorted(method_test_suites, key=lambda x: x.method_name)
self._assert_valid_bundle(executorch_program, method_test_suites)
if executorch_program:
self._assert_valid_bundle(executorch_program, method_test_suites)
self.executorch_program: Optional[Union[
ExecutorchProgram,
ExecutorchProgramManager,
]] = executorch_program
self._pte_file_path: Optional[str] = pte_file_path

self.executorch_program = executorch_program
self.method_test_suites = method_test_suites

# This is the cache for bundled program in schema type.
Expand All @@ -72,7 +85,13 @@ def serialize_to_schema(self) -> bp_schema.BundledProgram:
if self._bundled_program_in_schema is not None:
return self._bundled_program_in_schema

program = self._extract_program(self.executorch_program)
if self.executorch_program:
program = self._extract_program(self.executorch_program)
else:
with open(none_throws(self._pte_file_path), "rb") as f:
p_bytes = f.read()
program = _deserialize_pte_binary(p_bytes)

bundled_method_test_suites: List[bp_schema.BundledMethodTestSuite] = []

# Emit data and metadata of bundled tensor
Expand Down
39 changes: 38 additions & 1 deletion devtools/bundled_program/test/test_bundle_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import unittest
from typing import List

import tempfile
import executorch.devtools.bundled_program.schema as bp_schema

import torch
Expand Down Expand Up @@ -73,6 +73,43 @@ def test_bundled_program(self) -> None:
bundled_program.serialize_to_schema().program,
bytes(_serialize_pte_binary(executorch_program.executorch_program)),
)

def test_bundled_program_from_pte(self) -> None:
executorch_program, method_test_suites = get_common_executorch_program()

with tempfile.TemporaryDirectory() as tmp_dir:
executorch_model_path = f"{tmp_dir}/executorch_model.pte"
with open(executorch_model_path, "wb") as f:
f.write(executorch_program.buffer)

bundled_program = BundledProgram(executorch_program=None, method_test_suites=method_test_suites, pte_file_path=executorch_model_path)

method_test_suites = sorted(method_test_suites, key=lambda t: t.method_name)

for plan_id in range(len(executorch_program.executorch_program.execution_plan)):
bundled_plan_test = (
bundled_program.serialize_to_schema().method_test_suites[plan_id]
)
method_test_suite = method_test_suites[plan_id]

self.assertEqual(
len(bundled_plan_test.test_cases), len(method_test_suite.test_cases)
)
for bundled_program_ioset, method_test_case in zip(
bundled_plan_test.test_cases, method_test_suite.test_cases
):
self.assertIOsetDataEqual(
bundled_program_ioset.inputs, method_test_case.inputs
)
self.assertIOsetDataEqual(
bundled_program_ioset.expected_outputs,
method_test_case.expected_outputs,
)

self.assertEqual(
bundled_program.serialize_to_schema().program,
bytes(_serialize_pte_binary(executorch_program.executorch_program)),
)

def test_bundled_miss_methods(self) -> None:
executorch_program, method_test_suites = get_common_executorch_program()
Expand Down

0 comments on commit 8fc3f8c

Please sign in to comment.