Skip to content

Commit

Permalink
Remove none_throws usage from bundled_program
Browse files Browse the repository at this point in the history
Differential Revision: D67123067

Pull Request resolved: pytorch#7296
  • Loading branch information
tarun292 authored Dec 12, 2024
1 parent 957259e commit de74961
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 16 deletions.
32 changes: 20 additions & 12 deletions devtools/bundled_program/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from typing import Dict, List, Optional, Sequence, Type, Union

import executorch.devtools.bundled_program.schema as bp_schema
from pyre_extensions import none_throws

import executorch.exir.schema as core_schema

Expand Down Expand Up @@ -44,10 +43,12 @@ class BundledProgram:

def __init__(
self,
executorch_program: Optional[Union[
ExecutorchProgram,
ExecutorchProgramManager,
]],
executorch_program: Optional[
Union[
ExecutorchProgram,
ExecutorchProgramManager,
]
],
method_test_suites: Sequence[MethodTestSuite],
pte_file_path: Optional[str] = None,
):
Expand All @@ -59,18 +60,24 @@ def __init__(
pte_file_path: The path to pte file to deserialize program if executorch_program is not provided.
"""
if not executorch_program and not pte_file_path:
raise RuntimeError("Either executorch_program or pte_file_path must be provided")
raise RuntimeError(
"Either executorch_program or pte_file_path must be provided"
)

if executorch_program and pte_file_path:
raise RuntimeError("Only one of executorch_program or pte_file_path can be used")
raise RuntimeError(
"Only one of executorch_program or pte_file_path can be used"
)

method_test_suites = sorted(method_test_suites, key=lambda x: x.method_name)
if executorch_program:
self._assert_valid_bundle(executorch_program, method_test_suites)
self.executorch_program: Optional[Union[
ExecutorchProgram,
ExecutorchProgramManager,
]] = executorch_program
self.executorch_program: Optional[
Union[
ExecutorchProgram,
ExecutorchProgramManager,
]
] = executorch_program
self._pte_file_path: Optional[str] = pte_file_path

self.method_test_suites = method_test_suites
Expand All @@ -88,7 +95,8 @@ def serialize_to_schema(self) -> bp_schema.BundledProgram:
if self.executorch_program:
program = self._extract_program(self.executorch_program)
else:
with open(none_throws(self._pte_file_path), "rb") as f:
assert self._pte_file_path is not None
with open(self._pte_file_path, "rb") as f:
p_bytes = f.read()
program = _deserialize_pte_binary(p_bytes)

Expand Down
15 changes: 11 additions & 4 deletions devtools/bundled_program/test/test_bundle_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@

# pyre-strict

import tempfile
import unittest
from typing import List
import tempfile

import executorch.devtools.bundled_program.schema as bp_schema

import torch
Expand Down Expand Up @@ -73,7 +74,7 @@ def test_bundled_program(self) -> None:
bundled_program.serialize_to_schema().program,
bytes(_serialize_pte_binary(executorch_program.executorch_program)),
)

def test_bundled_program_from_pte(self) -> None:
executorch_program, method_test_suites = get_common_executorch_program()

Expand All @@ -82,11 +83,17 @@ def test_bundled_program_from_pte(self) -> None:
with open(executorch_model_path, "wb") as f:
f.write(executorch_program.buffer)

bundled_program = BundledProgram(executorch_program=None, method_test_suites=method_test_suites, pte_file_path=executorch_model_path)
bundled_program = BundledProgram(
executorch_program=None,
method_test_suites=method_test_suites,
pte_file_path=executorch_model_path,
)

method_test_suites = sorted(method_test_suites, key=lambda t: t.method_name)

for plan_id in range(len(executorch_program.executorch_program.execution_plan)):
for plan_id in range(
len(executorch_program.executorch_program.execution_plan)
):
bundled_plan_test = (
bundled_program.serialize_to_schema().method_test_suites[plan_id]
)
Expand Down

0 comments on commit de74961

Please sign in to comment.