From de749610895a5920d46c801dd5dea691832c03c6 Mon Sep 17 00:00:00 2001 From: Tarun Karuturi <58826100+tarun292@users.noreply.github.com> Date: Thu, 12 Dec 2024 03:58:33 -0800 Subject: [PATCH] Remove none_throws usage from bundled_program Differential Revision: D67123067 Pull Request resolved: https://github.com/pytorch/executorch/pull/7296 --- devtools/bundled_program/core.py | 32 ++++++++++++------- .../bundled_program/test/test_bundle_data.py | 15 ++++++--- 2 files changed, 31 insertions(+), 16 deletions(-) diff --git a/devtools/bundled_program/core.py b/devtools/bundled_program/core.py index c01d1c3c1b..2c930f06b7 100644 --- a/devtools/bundled_program/core.py +++ b/devtools/bundled_program/core.py @@ -9,7 +9,6 @@ from typing import Dict, List, Optional, Sequence, Type, Union import executorch.devtools.bundled_program.schema as bp_schema -from pyre_extensions import none_throws import executorch.exir.schema as core_schema @@ -44,10 +43,12 @@ class BundledProgram: def __init__( self, - executorch_program: Optional[Union[ - ExecutorchProgram, - ExecutorchProgramManager, - ]], + executorch_program: Optional[ + Union[ + ExecutorchProgram, + ExecutorchProgramManager, + ] + ], method_test_suites: Sequence[MethodTestSuite], pte_file_path: Optional[str] = None, ): @@ -59,18 +60,24 @@ def __init__( pte_file_path: The path to pte file to deserialize program if executorch_program is not provided. """ if not executorch_program and not pte_file_path: - raise RuntimeError("Either executorch_program or pte_file_path must be provided") + raise RuntimeError( + "Either executorch_program or pte_file_path must be provided" + ) if executorch_program and pte_file_path: - raise RuntimeError("Only one of executorch_program or pte_file_path can be used") + raise RuntimeError( + "Only one of executorch_program or pte_file_path can be used" + ) method_test_suites = sorted(method_test_suites, key=lambda x: x.method_name) if executorch_program: self._assert_valid_bundle(executorch_program, method_test_suites) - self.executorch_program: Optional[Union[ - ExecutorchProgram, - ExecutorchProgramManager, - ]] = executorch_program + self.executorch_program: Optional[ + Union[ + ExecutorchProgram, + ExecutorchProgramManager, + ] + ] = executorch_program self._pte_file_path: Optional[str] = pte_file_path self.method_test_suites = method_test_suites @@ -88,7 +95,8 @@ def serialize_to_schema(self) -> bp_schema.BundledProgram: if self.executorch_program: program = self._extract_program(self.executorch_program) else: - with open(none_throws(self._pte_file_path), "rb") as f: + assert self._pte_file_path is not None + with open(self._pte_file_path, "rb") as f: p_bytes = f.read() program = _deserialize_pte_binary(p_bytes) diff --git a/devtools/bundled_program/test/test_bundle_data.py b/devtools/bundled_program/test/test_bundle_data.py index 8e4a8ee651..b833903c2f 100644 --- a/devtools/bundled_program/test/test_bundle_data.py +++ b/devtools/bundled_program/test/test_bundle_data.py @@ -6,9 +6,10 @@ # pyre-strict +import tempfile import unittest from typing import List -import tempfile + import executorch.devtools.bundled_program.schema as bp_schema import torch @@ -73,7 +74,7 @@ def test_bundled_program(self) -> None: bundled_program.serialize_to_schema().program, bytes(_serialize_pte_binary(executorch_program.executorch_program)), ) - + def test_bundled_program_from_pte(self) -> None: executorch_program, method_test_suites = get_common_executorch_program() @@ -82,11 +83,17 @@ def test_bundled_program_from_pte(self) -> None: with open(executorch_model_path, "wb") as f: f.write(executorch_program.buffer) - bundled_program = BundledProgram(executorch_program=None, method_test_suites=method_test_suites, pte_file_path=executorch_model_path) + bundled_program = BundledProgram( + executorch_program=None, + method_test_suites=method_test_suites, + pte_file_path=executorch_model_path, + ) method_test_suites = sorted(method_test_suites, key=lambda t: t.method_name) - for plan_id in range(len(executorch_program.executorch_program.execution_plan)): + for plan_id in range( + len(executorch_program.executorch_program.execution_plan) + ): bundled_plan_test = ( bundled_program.serialize_to_schema().method_test_suites[plan_id] )