diff --git a/backends/apple/mps/test/test_mps.py b/backends/apple/mps/test/test_mps.py index 1a99912d09..2703d71cc2 100644 --- a/backends/apple/mps/test/test_mps.py +++ b/backends/apple/mps/test/test_mps.py @@ -25,11 +25,12 @@ TestMPS, ) -from executorch.bundled_program.config import BundledConfig +from executorch.bundled_program.config import MethodTestCase, MethodTestSuite from executorch.bundled_program.core import create_bundled_program from executorch.bundled_program.serialize import ( serialize_from_bundled_program_to_flatbuffer, ) + from executorch.exir import ExirExportedProgram from executorch.exir.backend.backend_api import to_backend from executorch.exir.tests.models import ( @@ -129,21 +130,20 @@ def forward(self, *args): f" -> Number of execution plans: {len(executorch_program.program.execution_plan)}" ) - bundled_inputs = [ - [m_inputs] for _ in range(len(executorch_program.program.execution_plan)) - ] - logging.info(" -> Bundled inputs generated successfully") - - output = m(*m_inputs) - expected_outputs = [ - [[output]] for _ in range(len(executorch_program.program.execution_plan)) + method_test_suites = [ + MethodTestSuite( + method_name="forward", + test_cases=[ + MethodTestCase(inputs=m_inputs, expected_outputs=model(*m_inputs)) + ], + ) ] - logging.info(" -> Bundled outputs generated successfully") - bundled_config = BundledConfig(["forward"], bundled_inputs, expected_outputs) - logging.info(" -> Bundled config generated successfully") + logging.info(" -> Test suites generated successfully") - bundled_program = create_bundled_program(executorch_program.program, bundled_config) + bundled_program = create_bundled_program( + executorch_program.program, method_test_suites + ) logging.info(" -> Bundled program generated successfully") bundled_program_buffer = serialize_from_bundled_program_to_flatbuffer( diff --git a/backends/apple/mps/test/test_mps_utils.py b/backends/apple/mps/test/test_mps_utils.py index 3ad6cc55d4..2eafede463 100644 --- a/backends/apple/mps/test/test_mps_utils.py +++ b/backends/apple/mps/test/test_mps_utils.py @@ -12,7 +12,7 @@ import torch from executorch.backends.apple.mps.mps_preprocess import MPSBackend -from executorch.bundled_program.config import BundledConfig +from executorch.bundled_program.config import MethodTestCase, MethodTestSuite from executorch.bundled_program.core import create_bundled_program from executorch.bundled_program.serialize import ( serialize_from_bundled_program_to_flatbuffer, @@ -189,28 +189,23 @@ def forward(self, *args): logging.info( " -> Number of execution plans: {len(executorch_program.program.execution_plan)}" ) - bundled_inputs = [ - [sample_inputs] - for _ in range(len(executorch_program.program.execution_plan)) - ] - logging.info(" -> Bundled inputs generated successfully") - output = module(*sample_inputs) - expected_outputs = [ - [[output]] for _ in range(len(executorch_program.program.execution_plan)) + method_test_suites = [ + MethodTestSuite(method_name="forward", test_cases=[ + MethodTestCase(input=sample_inputs, expected_outputs=module(*sample_inputs)) + ]) ] - logging.info(" -> Bundled outputs generated successfully") - method_names = ["forward"] - bundled_config = BundledConfig(method_names, bundled_inputs, expected_outputs) + logging.info(" -> Test suites generated successfully") + bundled_program = create_bundled_program( - executorch_program.program, bundled_config + executorch_program.program, method_test_suites ) bundled_program_buffer = serialize_from_bundled_program_to_flatbuffer( bundled_program ) - filename = f"{func_name}.pte" + filename = f"{func_name}.bpte" logging.info(f"Step 5: Saving bundled program to {filename}...") with open(filename, "wb") as file: file.write(bundled_program_buffer) diff --git a/backends/xnnpack/test/test_xnnpack_utils.py b/backends/xnnpack/test/test_xnnpack_utils.py index 8b91173fff..8d5da97f92 100644 --- a/backends/xnnpack/test/test_xnnpack_utils.py +++ b/backends/xnnpack/test/test_xnnpack_utils.py @@ -6,7 +6,7 @@ import unittest from random import randint -from typing import Any, Tuple +from typing import Any, List, Tuple import torch import torch.nn.functional as F @@ -26,7 +26,7 @@ # import the xnnpack backend implementation from executorch.backends.xnnpack.xnnpack_preprocess import XnnpackBackend -from executorch.bundled_program.config import BundledConfig +from executorch.bundled_program.config import MethodTestCase, MethodTestSuite from executorch.bundled_program.core import create_bundled_program from executorch.bundled_program.serialize import ( serialize_from_bundled_program_to_flatbuffer, @@ -101,14 +101,22 @@ def save_bundled_program(representative_inputs, program, ref_output, output_path niter = 1 print("generating bundled program inputs / outputs") - inputs_list = [list(representative_inputs) for _ in range(niter)] - expected_outputs_list = [ - [[ref_output] for x in inputs_list], + + method_test_cases: List[MethodTestCase] = [] + for _ in range(niter): + method_test_cases.append( + MethodTestCase( + inputs=representative_inputs, + expected_outputs=ref_output, + ) + ) + + method_test_suites = [ + MethodTestSuite(method_name="forward", method_test_cases=method_test_cases) ] - bundled_config = BundledConfig([inputs_list], expected_outputs_list) print("creating bundled program...") - bundled_program = create_bundled_program(program, bundled_config) + bundled_program = create_bundled_program(program, method_test_suites) print("serializing bundled program...") bundled_program_buffer = serialize_from_bundled_program_to_flatbuffer( diff --git a/bundled_program/config.py b/bundled_program/config.py index 7c66baf302..ed1c5c78f7 100644 --- a/bundled_program/config.py +++ b/bundled_program/config.py @@ -7,7 +7,7 @@ # pyre-strict from dataclasses import dataclass -from typing import Any, get_args, List, Union +from typing import Any, get_args, List, Optional, Sequence, Union import torch from torch.utils._pytree import tree_flatten @@ -16,7 +16,7 @@ """ The data types currently supported for element to be bundled. It should be -consistent with the types in bundled_program.schema.BundledValue. +consistent with the types in bundled_program.schema.Value. """ ConfigValue: TypeAlias = Union[ torch.Tensor, @@ -28,15 +28,15 @@ """ The data type of the input for method single execution. """ -MethodInputType: TypeAlias = List[ConfigValue] +MethodInputType: TypeAlias = Sequence[ConfigValue] """ The data type of the output for method single execution. """ -MethodOutputType: TypeAlias = List[torch.Tensor] +MethodOutputType: TypeAlias = Sequence[torch.Tensor] """ -All supported types for input/expected output of test set. +All supported types for input/expected output of MethodTestCase. Namedtuple is also supported and listed implicity since it is a subclass of tuple. """ @@ -45,79 +45,40 @@ DataContainer: TypeAlias = Union[list, tuple, dict] -@dataclass -class ConfigIOSet: - """Type of data BundledConfig stored for each validation set.""" - - inputs: List[ConfigValue] - expected_outputs: List[ConfigValue] - - -@dataclass -class ConfigExecutionPlanTest: - """All info related to verify execution plan""" - - method_name: str - test_sets: List[ConfigIOSet] - - -class BundledConfig: - """All information needed to verify a model. - - Public Attributes: - execution_plan_tests: inputs, expected outputs, and other info for each execution plan verification. - attachments: Other info need to be attached. - """ +class MethodTestCase: + """Test case with inputs and expected outputs + The expected_outputs are optional and only required if the user wants to verify model outputs after execution.""" def __init__( self, - method_names: List[str], - inputs: List[List[MethodInputType]], - expected_outputs: List[List[MethodOutputType]], + inputs: MethodInputType, + expected_outputs: Optional[MethodOutputType] = None, ) -> None: - """Contruct the config given inputs and expected outputs + """Single test case for verifying specific method Args: - method_names: All method names need to be verified in program. - inputs: All sets of input need to be test on for all methods. Each list - of `inputs` is all sets which will be run on the method in the - program with corresponding method name. Each set of any `inputs` element should - contain all inputs required by eager_model with the same inference function - as corresponding execution plan for one-time execution. + input: All inputs required by eager_model with specific inference method for one-time execution. It is worth mentioning that, although both bundled program and ET runtime apis support setting input other than torch.tensor type, only the input in torch.tensor type will be actually updated in the method, and the rest of the inputs will just do a sanity check if they match the default value in method. - expected_outputs: Expected outputs for inputs sharing same index. The size of - expected_outputs should be the same as the size of inputs and provided method_names. + expected_output: Expected output of given input for verification. It can be None if user only wants to use the test case for profiling. Returns: self """ - BundledConfig._check_io_type(inputs) - BundledConfig._check_io_type(expected_outputs) - - for m_name in method_names: - assert isinstance(m_name, str) - - assert len(method_names) == len(inputs) == len(expected_outputs), ( - "length of method_names, inputs and expected_outputs should match," - + " but got {}, {} and {}".format( - len(method_names), len(inputs), len(expected_outputs) - ) - ) - - self.execution_plan_tests: List[ - ConfigExecutionPlanTest - ] = BundledConfig._gen_execution_plan_tests( - method_names, inputs, expected_outputs - ) - - @staticmethod - # TODO(T138930448): Give pyre-ignore commands appropriate warning type and comments. - # pyre-ignore - def _tree_flatten(unflatten_data: Any) -> List[ConfigValue]: + # TODO(gasoonjia): Update type check logic. + # pyre-ignore [6]: Misalign data type for between MethodTestCase attribute and sannity check. + self.inputs: List[ConfigValue] = self._flatten_and_sanity_check(inputs) + self.expected_outputs: List[ConfigValue] = [] + if expected_outputs is not None: + # pyre-ignore [6]: Misalign data type for between MethodTestCase attribute and sannity check. + self.expected_outputs = self._flatten_and_sanity_check(expected_outputs) + + def _flatten_and_sanity_check( + self, unflatten_data: DataContainer + ) -> List[ConfigValue]: """Flat the given data and check its legality Args: @@ -126,6 +87,7 @@ def _tree_flatten(unflatten_data: Any) -> List[ConfigValue]: Returns: flatten_data: Flatten data with legal type. """ + flatten_data, _ = tree_flatten(unflatten_data) for data in flatten_data: @@ -142,68 +104,15 @@ def _tree_flatten(unflatten_data: Any) -> List[ConfigValue]: return flatten_data - @staticmethod - # pyre-ignore - def _check_io_type(test_data_program: List[List[Any]]) -> None: - """Check the type of each set of inputs or exepcted_outputs - - Each test set of inputs or expected_outputs will be put into the config - should be one of the sub-type in DataContainer. - - Args: - test_data_program: inputs or expected_outputs to be put into the config - to verify the whole program. - """ - for test_data_execution_plan in test_data_program: - for test_set in test_data_execution_plan: - assert isinstance(test_set, get_args(DataContainer)) - - @staticmethod - def _gen_execution_plan_tests( - method_names: List[str], - inputs: List[List[MethodInputType]], - expected_outputs: List[List[MethodOutputType]], - ) -> List[ConfigExecutionPlanTest]: - """Generate execution plan test given inputs, expected outputs for verifying each execution plan""" - - execution_plan_tests: List[ConfigExecutionPlanTest] = [] - - for ( - m_name, - inputs_per_plan_test, - expect_outputs_per_plan_test, - ) in zip(method_names, inputs, expected_outputs): - test_sets: List[ConfigIOSet] = [] - - # transfer I/O sets into ConfigIOSet for each execution plan - assert len(inputs_per_plan_test) == len(expect_outputs_per_plan_test), ( - "The number of input and expected output for identical execution plan should be the same," - + " but got {} and {}".format( - len(inputs_per_plan_test), len(expect_outputs_per_plan_test) - ) - ) - for unflatten_input, unflatten_expected_output in zip( - inputs_per_plan_test, expect_outputs_per_plan_test - ): - flatten_inputs = BundledConfig._tree_flatten(unflatten_input) - flatten_expected_output = BundledConfig._tree_flatten( - unflatten_expected_output - ) - test_sets.append( - ConfigIOSet( - inputs=flatten_inputs, expected_outputs=flatten_expected_output - ) - ) - - execution_plan_tests.append( - ConfigExecutionPlanTest( - method_name=m_name, - test_sets=test_sets, - ) - ) +@dataclass +class MethodTestSuite: + """All test info related to verify method - # sort the execution plan tests by method name to in line with core program emitter. - execution_plan_tests.sort(key=lambda x: x.method_name) + Attributes: + method_name: Name of the method to be verified. + test_cases: All test cases for verifying the method. + """ - return execution_plan_tests + method_name: str + test_cases: Sequence[MethodTestCase] diff --git a/bundled_program/core.py b/bundled_program/core.py index 006442a50e..936483d3c5 100644 --- a/bundled_program/core.py +++ b/bundled_program/core.py @@ -6,18 +6,14 @@ import ctypes import typing -from typing import Dict, List, Type +from typing import Dict, List, Sequence, Type import executorch.bundled_program.schema as bp_schema import executorch.exir.schema as core_schema import torch import torch.fx -from executorch.bundled_program.config import ( - BundledConfig, - ConfigExecutionPlanTest, - ConfigValue, -) +from executorch.bundled_program.config import ConfigValue, MethodTestSuite from executorch.bundled_program.version import BUNDLED_PROGRAM_SCHEMA_VERSION from executorch.exir._serialize import _serialize_pte_binary @@ -124,56 +120,50 @@ def get_output_dtype( def assert_valid_bundle( program: core_schema.Program, - bundled_config: BundledConfig, + method_test_suites: Sequence[MethodTestSuite], ) -> None: - """Check if the program and BundledConfig matches each other. + """Check if the program and method_test_suites matches each other. Other checks not related to correspondence are done in config.py Args: program: The program to be bundled. - bundled_config: The config to be bundled. + method_test_suites: The testcases for specific methods to be bundled. """ - program_plan_id = 0 - bp_plan_id = 0 - method_name_of_program = {e.name for e in program.execution_plan} - method_name_of_bundled_config = { - t.method_name for t in bundled_config.execution_plan_tests - } + method_name_of_test_suites = {t.method_name for t in method_test_suites} - assert method_name_of_bundled_config.issubset( + assert method_name_of_test_suites.issubset( method_name_of_program ), f"All method names in bundled config should be found in program.execution_plan, \ - but {str(method_name_of_bundled_config - method_name_of_program)} does not include." + but {str(method_name_of_test_suites - method_name_of_program)} does not include." - # check if execution_plan_tests has been sorted in ascending alphabetical order of method name. - for bp_plan_id in range(1, len(bundled_config.execution_plan_tests)): + # check if method_test_suites has been sorted in ascending alphabetical order of method name. + for test_suite_id in range(1, len(method_test_suites)): assert ( - bundled_config.execution_plan_tests[bp_plan_id - 1].method_name - <= bundled_config.execution_plan_tests[bp_plan_id].method_name - ), f"The method name of BundledConfig should be sorted in ascending alphabetical \ - order of method name, but {bp_plan_id-1}-th and {bp_plan_id}-th methods aren't." + method_test_suites[test_suite_id - 1].method_name + <= method_test_suites[test_suite_id].method_name + ), f"The method name of test suite should be sorted in ascending alphabetical \ + order of method name, but {test_suite_id-1}-th and {test_suite_id}-th method_test_suite aren't." # Check if the inputs' type meet Program's requirement - while bp_plan_id < len(bundled_config.execution_plan_tests): + for method_test_suite in method_test_suites: - plan_test: ConfigExecutionPlanTest = bundled_config.execution_plan_tests[ - bp_plan_id - ] - plan: core_schema.ExecutionPlan = program.execution_plan[program_plan_id] + # Get the method with same method name as method_test_suite + program_plan_id = -1 + for plan in program.execution_plan: + if plan.name == method_test_suite.method_name: + program_plan_id = program.execution_plan.index(plan) + break - # User does not provide testcases for current plan, skip it - if plan_test.method_name > plan.name: - program_plan_id += 1 - continue - - # Check if the method name in user provided test matches the one in the original program + # Raise Assertion Error if can not find the method with same method_name as method_test_suite in program. assert ( - plan_test.method_name == plan.name - ), f"BundledConfig has testcases for method {plan_test.method_name}, but can not find it in the given program. All method names in the program are {', '.join([p.name for p in program.execution_plan])}." + program_plan_id != -1 + ), f"method_test_suites has testcases for method {method_test_suite.method_name}, but can not find it in the given program. All method names in the program are {', '.join([p.name for p in program.execution_plan])}." + + plan = program.execution_plan[program_plan_id] # Check if the type of Program's input is supported for index in range(len(plan.inputs)): @@ -190,9 +180,11 @@ def assert_valid_bundle( ), "Only supports program with output in Tensor type." # Check if the I/O sets of each execution plan test match program's requirement. - for i in range(len(plan_test.test_sets)): - cur_plan_test_inputs = plan_test.test_sets[i].inputs - cur_plan_test_expected_outputs = plan_test.test_sets[i].expected_outputs + for i in range(len(method_test_suite.test_cases)): + cur_plan_test_inputs = method_test_suite.test_cases[i].inputs + cur_plan_test_expected_outputs = method_test_suite.test_cases[ + i + ].expected_outputs assert len(plan.inputs) == len( cur_plan_test_inputs @@ -259,38 +251,40 @@ def assert_valid_bundle( cur_plan_test_expected_outputs[j].dtype, ) - program_plan_id += 1 - bp_plan_id += 1 - def create_bundled_program( program: core_schema.Program, - bundled_config: BundledConfig, + method_test_suites: Sequence[MethodTestSuite], ) -> bp_schema.BundledProgram: - """Create bp_schema.BundledProgram by bundling the given program and bundled_config together. + """Create bp_schema.BundledProgram by bundling the given program and method_test_suites together. Args: program: The program to be bundled. - bundled_config: The config to be bundled. + method_test_suites: The testcases for certain methods to be bundled. - Returns: The `BundledProgram` variable contains given ExecuTorch program and test cases. + Returns: + The `BundledProgram` variable contains given ExecuTorch program and test cases. """ - assert_valid_bundle(program, bundled_config) + method_test_suites = sorted(method_test_suites, key=lambda x: x.method_name) + + assert_valid_bundle(program, method_test_suites) - execution_plan_tests: List[bp_schema.BundledExecutionPlanTest] = [] + bundled_method_test_suites: List[bp_schema.BundledMethodTestSuite] = [] # Emit data and metadata of bundled tensor - for plan_test in bundled_config.execution_plan_tests: - test_sets: List[bp_schema.BundledIOSet] = [] + for method_test_suite in method_test_suites: + bundled_test_cases: List[bp_schema.BundledMethodTestCase] = [] - # emit I/O sets for each execution plan test - for i in range(len(plan_test.test_sets)): + # emit I/O sets for each method test case + for i in range(len(method_test_suite.test_cases)): inputs: List[bp_schema.Value] = [] expected_outputs: List[bp_schema.Value] = [] - cur_plan_test_inputs = plan_test.test_sets[i].inputs - cur_plan_test_expected_outputs = plan_test.test_sets[i].expected_outputs + cur_plan_test_inputs = method_test_suite.test_cases[i].inputs + cur_plan_test_expected_outputs = method_test_suite.test_cases[ + i + ].expected_outputs for input_val in cur_plan_test_inputs: if type(input_val) == torch.Tensor: @@ -311,14 +305,16 @@ def create_bundled_program( TensorSpec.from_tensor(expected_output_tensor, const=True), expected_outputs, ) - test_sets.append( - bp_schema.BundledIOSet(inputs=inputs, expected_outputs=expected_outputs) + bundled_test_cases.append( + bp_schema.BundledMethodTestCase( + inputs=inputs, expected_outputs=expected_outputs + ) ) # emit the whole execution plan test - execution_plan_tests.append( - bp_schema.BundledExecutionPlanTest( - method_name=plan_test.method_name, test_sets=test_sets + bundled_method_test_suites.append( + bp_schema.BundledMethodTestSuite( + method_name=method_test_suite.method_name, test_cases=bundled_test_cases ) ) @@ -326,6 +322,6 @@ def create_bundled_program( return bp_schema.BundledProgram( version=BUNDLED_PROGRAM_SCHEMA_VERSION, - execution_plan_tests=execution_plan_tests, + method_test_suites=bundled_method_test_suites, program=program_bytes, ) diff --git a/bundled_program/schema.py b/bundled_program/schema.py index d7c7ee614c..08b86c61d3 100644 --- a/bundled_program/schema.py +++ b/bundled_program/schema.py @@ -56,7 +56,7 @@ class Value: @dataclass -class BundledIOSet: +class BundledMethodTestCase: """All inputs and referenced outputs needs for single verification.""" # All inputs required by Program for execution. Its length should be @@ -70,8 +70,8 @@ class BundledIOSet: @dataclass -class BundledExecutionPlanTest: - """Context for testing and verifying an exceution plan.""" +class BundledMethodTestSuite: + """Context for testing and verifying a Method.""" # The name of the method to test; e.g., "forward" for the forward() method # of an nn.Module. This name match a method defined by the ExecuTorch @@ -79,7 +79,7 @@ class BundledExecutionPlanTest: method_name: str # Sets of input/outputs to test with. - test_sets: List[BundledIOSet] + test_cases: List[BundledMethodTestCase] @dataclass @@ -90,9 +90,10 @@ class BundledProgram: version: int # Test sets and other meta datas to verify the whole program. - # Each BundledExecutionPlanTest should be used for the execution plan of program sharing same index. - # Its length should be equal to the number of execution plans in program. - execution_plan_tests: List[BundledExecutionPlanTest] + # Each BundledMethodTestSuite contains the test cases for one of the Method's + # present inside the ExecuTorchProgram of the same BundledProgram. The method_name + # present inside the BundledMethodTestSuite is what is used to link to the appropriate Method. + method_test_suites: List[BundledMethodTestSuite] - # The binary data of a serialized ExecuTorch program. + # The binary data of a serialized ExecuTorchProgram. program: bytes diff --git a/bundled_program/tests/common.py b/bundled_program/tests/common.py index a87870fd92..b5ae374a3c 100644 --- a/bundled_program/tests/common.py +++ b/bundled_program/tests/common.py @@ -7,16 +7,18 @@ # pyre-strict import random import string -from typing import List, Tuple, Union +from typing import List, Tuple import executorch.exir as exir import torch -from executorch.bundled_program.config import BundledConfig +from executorch.bundled_program.config import ( + MethodInputType, + MethodOutputType, + MethodTestCase, + MethodTestSuite, +) from executorch.exir.schema import Program -# @manual=fbsource//third-party/pypi/typing-extensions:typing-extensions -from typing_extensions import TypeAlias - # A hacky integer to deal with a mismatch between execution plan and complier. # # Execution plan supports multiple types of inputs, like Tensor, Int, etc, @@ -37,13 +39,6 @@ DEFAULT_INT_INPUT = 2 -# Alias type for all datas model needs for single execution. -InputValues: TypeAlias = List[Union[torch.Tensor, int]] - -# Alias type for all datas model generates per execution. -OutputValues: TypeAlias = List[torch.Tensor] - - class SampleModel(torch.nn.Module): """An example model with multi-methods. Each method has multiple input and single output""" @@ -77,15 +72,16 @@ def get_rand_input_values( n_int: int, dtype: torch.dtype, n_sets_per_plan_test: int, - n_execution_plan_tests: int, -) -> List[List[InputValues]]: + n_method_test_suites: int, +) -> List[List[MethodInputType]]: + # pyre-ignore[7]: expected `List[List[List[Union[bool, float, int, Tensor]]]]` but got `List[List[List[Union[int, Tensor]]]]` return [ [ [(torch.rand(*sizes[i]) - 0.5).to(dtype) for i in range(n_tensors)] + [DEFAULT_INT_INPUT for _ in range(n_int)] for _ in range(n_sets_per_plan_test) ] - for _ in range(n_execution_plan_tests) + for _ in range(n_method_test_suites) ] @@ -94,40 +90,40 @@ def get_rand_output_values( sizes: List[List[int]], dtype: torch.dtype, n_sets_per_plan_test: int, - n_execution_plan_tests: int, -) -> List[List[OutputValues]]: + n_method_test_suites: int, +) -> List[List[MethodOutputType]]: + # pyre-ignore [7]: Expected `List[List[Sequence[Tensor]]]` but got `List[List[List[Tensor]]]`. return [ [ [(torch.rand(*sizes[i]) - 0.5).to(dtype) for i in range(n_tensors)] for _ in range(n_sets_per_plan_test) ] - for _ in range(n_execution_plan_tests) + for _ in range(n_method_test_suites) ] -def get_rand_method_names(n_execution_plan_tests: int) -> List[str]: +def get_rand_method_names(n_method_test_suites: int) -> List[str]: unique_strings = set() - while len(unique_strings) < n_execution_plan_tests: + while len(unique_strings) < n_method_test_suites: rand_str = "".join(random.choices(string.ascii_letters, k=5)) if rand_str not in unique_strings: unique_strings.add(rand_str) return list(unique_strings) -# TODO(T143955558): make n_int and metadatas as its input; -def get_random_config( +def get_random_test_suites( n_model_inputs: int, model_input_sizes: List[List[int]], n_model_outputs: int, model_output_sizes: List[List[int]], dtype: torch.dtype, n_sets_per_plan_test: int, - n_execution_plan_tests: int, + n_method_test_suites: int, ) -> Tuple[ List[str], - List[List[InputValues]], - List[List[OutputValues]], - BundledConfig, + List[List[MethodInputType]], + List[List[MethodOutputType]], + List[MethodTestSuite], ]: """Helper function to generate config filled with random inputs and expected outputs. @@ -139,65 +135,99 @@ def get_random_config( """ - rand_method_names = get_rand_method_names(n_execution_plan_tests) + rand_method_names = get_rand_method_names(n_method_test_suites) - rand_inputs = get_rand_input_values( + rand_inputs_per_program = get_rand_input_values( n_tensors=n_model_inputs, sizes=model_input_sizes, n_int=1, dtype=dtype, n_sets_per_plan_test=n_sets_per_plan_test, - n_execution_plan_tests=n_execution_plan_tests, + n_method_test_suites=n_method_test_suites, ) - rand_expected_outputs = get_rand_output_values( + rand_expected_output_per_program = get_rand_output_values( n_tensors=n_model_outputs, sizes=model_output_sizes, dtype=dtype, n_sets_per_plan_test=n_sets_per_plan_test, - n_execution_plan_tests=n_execution_plan_tests, + n_method_test_suites=n_method_test_suites, ) + rand_method_test_suites: List[MethodTestSuite] = [] + + for ( + rand_method_name, + rand_inputs_per_method, + rand_expected_output_per_method, + ) in zip( + rand_method_names, rand_inputs_per_program, rand_expected_output_per_program + ): + rand_method_test_cases: List[MethodTestCase] = [] + for rand_inputs, rand_expected_outputs in zip( + rand_inputs_per_method, rand_expected_output_per_method + ): + rand_method_test_cases.append( + MethodTestCase( + inputs=rand_inputs, expected_outputs=rand_expected_outputs + ) + ) + + rand_method_test_suites.append( + MethodTestSuite( + method_name=rand_method_name, test_cases=rand_method_test_cases + ) + ) + return ( rand_method_names, - rand_inputs, - rand_expected_outputs, - # pyre-ignore[6]: Expected Union[Tensor, int, float, bool] for each element in 2nd positional argument, but got Union[Tensor, int] - BundledConfig(rand_method_names, rand_inputs, rand_expected_outputs), + rand_inputs_per_program, + rand_expected_output_per_program, + rand_method_test_suites, ) -def get_random_config_with_eager_model( +def get_random_test_suites_with_eager_model( eager_model: torch.nn.Module, method_names: List[str], n_model_inputs: int, model_input_sizes: List[List[int]], dtype: torch.dtype, n_sets_per_plan_test: int, -) -> Tuple[List[List[InputValues]], BundledConfig]: +) -> Tuple[List[List[MethodInputType]], List[MethodTestSuite]]: """Generate config filled with random inputs for each inference method given eager model - The details of return type is the same as get_random_config_with_rand_io_lists. + The details of return type is the same as get_random_test_suites_with_rand_io_lists. """ - inputs = get_rand_input_values( + inputs_per_program = get_rand_input_values( n_tensors=n_model_inputs, sizes=model_input_sizes, n_int=1, dtype=dtype, n_sets_per_plan_test=n_sets_per_plan_test, - n_execution_plan_tests=len(method_names), + n_method_test_suites=len(method_names), ) - expected_outputs = [ - [[getattr(eager_model, m_name)(*x)] for x in inputs[i]] - for i, m_name in enumerate(method_names) - ] + method_test_suites: List[MethodTestSuite] = [] + + for method_name, inputs_per_method in zip(method_names, inputs_per_program): + method_test_cases: List[MethodTestCase] = [] + for inputs in inputs_per_method: + method_test_cases.append( + MethodTestCase( + inputs=inputs, + expected_outputs=getattr(eager_model, method_name)(*inputs), + ) + ) + + method_test_suites.append( + MethodTestSuite(method_name=method_name, test_cases=method_test_cases) + ) - # pyre-ignore[6]: Expected Union[Tensor, int, float, bool] for each element in 2nd positional argument, but got Union[Tensor, int] - return inputs, BundledConfig(method_names, inputs, expected_outputs) + return inputs_per_program, method_test_suites -def get_common_program() -> Tuple[Program, BundledConfig]: +def get_common_program() -> Tuple[Program, List[MethodTestSuite]]: """Helper function to generate a sample BundledProgram with its config.""" eager_model = SampleModel() # Trace to FX Graph. @@ -216,7 +246,7 @@ def get_common_program() -> Tuple[Program, BundledConfig]: .to_executorch() .program ) - _, bundled_config = get_random_config_with_eager_model( + _, method_test_suites = get_random_test_suites_with_eager_model( eager_model=eager_model, method_names=eager_model.method_names, n_model_inputs=2, @@ -224,4 +254,4 @@ def get_common_program() -> Tuple[Program, BundledConfig]: dtype=torch.int32, n_sets_per_plan_test=10, ) - return program, bundled_config + return program, method_test_suites diff --git a/bundled_program/tests/test_bundle_data.py b/bundled_program/tests/test_bundle_data.py index 3434e09c1a..105061dba1 100644 --- a/bundled_program/tests/test_bundle_data.py +++ b/bundled_program/tests/test_bundle_data.py @@ -41,64 +41,64 @@ def assertIOsetDataEqual( self.assertEqual(program_element.val.bool_val, config_element) def test_bundled_program(self) -> None: - program, bundled_config = get_common_program() + program, method_test_suites = get_common_program() - bundled_program = create_bundled_program(program, bundled_config) + bundled_program = create_bundled_program(program, method_test_suites) + + method_test_suites = sorted(method_test_suites, key=lambda t: t.method_name) for plan_id in range(len(program.execution_plan)): - bundled_plan_test = bundled_program.execution_plan_tests[plan_id] - config_plan_test = bundled_config.execution_plan_tests[plan_id] + bundled_plan_test = bundled_program.method_test_suites[plan_id] + method_test_suite = method_test_suites[plan_id] self.assertEqual( - len(bundled_plan_test.test_sets), len(config_plan_test.test_sets) + len(bundled_plan_test.test_cases), len(method_test_suite.test_cases) ) - for bundled_program_ioset, bundled_config_ioset in zip( - bundled_plan_test.test_sets, config_plan_test.test_sets + for bundled_program_ioset, method_test_case in zip( + bundled_plan_test.test_cases, method_test_suite.test_cases ): self.assertIOsetDataEqual( - bundled_program_ioset.inputs, bundled_config_ioset.inputs + bundled_program_ioset.inputs, method_test_case.inputs ) self.assertIOsetDataEqual( bundled_program_ioset.expected_outputs, - bundled_config_ioset.expected_outputs, + method_test_case.expected_outputs, ) self.assertEqual(bundled_program.program, _serialize_pte_binary(program)) - def test_bundle_miss_methods(self) -> None: - program, bundled_config = get_common_program() + def test_bundled_miss_methods(self) -> None: + program, method_test_suites = get_common_program() # only keep the testcases for the first method to mimic the case that user only creates testcases for the first method. - bundled_config.execution_plan_tests = bundled_config.execution_plan_tests[:1] + method_test_suites = method_test_suites[:1] - _ = create_bundled_program(program, bundled_config) + _ = create_bundled_program(program, method_test_suites) - def test_bundle_wrong_method_name(self) -> None: - program, bundled_config = get_common_program() + def test_bundled_wrong_method_name(self) -> None: + program, method_test_suites = get_common_program() - bundled_config.execution_plan_tests[-1].method_name = "wrong_method_name" + method_test_suites[-1].method_name = "wrong_method_name" self.assertRaises( - AssertionError, create_bundled_program, program, bundled_config + AssertionError, create_bundled_program, program, method_test_suites ) def test_bundle_wrong_input_type(self) -> None: - program, bundled_config = get_common_program() + program, method_test_suites = get_common_program() - # pyre-ignore[8]: Use a wrong type on purpose. Should raise an error when creating a bundled program using bundled_config. - bundled_config.execution_plan_tests[-1].test_sets[-1].inputs = [ - "WRONG INPUT TYPE" - ] + # pyre-ignore[8]: Use a wrong type on purpose. Should raise an error when creating a bundled program using method_test_suites. + method_test_suites[0].test_cases[-1].inputs = ["WRONG INPUT TYPE"] self.assertRaises( - AssertionError, create_bundled_program, program, bundled_config + AssertionError, create_bundled_program, program, method_test_suites ) def test_bundle_wrong_output_type(self) -> None: - program, bundled_config = get_common_program() + program, method_test_suites = get_common_program() - bundled_config.execution_plan_tests[-1].test_sets[-1].expected_outputs = [ + method_test_suites[0].test_cases[-1].expected_outputs = [ 0, 0.0, ] self.assertRaises( - AssertionError, create_bundled_program, program, bundled_config + AssertionError, create_bundled_program, program, method_test_suites ) diff --git a/bundled_program/tests/test_config.py b/bundled_program/tests/test_config.py index f1bc4e271a..77cc491585 100644 --- a/bundled_program/tests/test_config.py +++ b/bundled_program/tests/test_config.py @@ -13,8 +13,8 @@ from executorch.bundled_program.config import DataContainer from executorch.bundled_program.tests.common import ( - get_random_config, - get_random_config_with_eager_model, + get_random_test_suites, + get_random_test_suites_with_eager_model, SampleModel, ) from executorch.extension.pytree import tree_flatten @@ -39,59 +39,55 @@ def assertIOListEqual( else: self.assertTrue(t1 == t2) - def test_create_config(self) -> None: + def test_create_test_suites(self) -> None: n_sets_per_plan_test = 10 - n_execution_plan_tests = 5 + n_method_test_suites = 5 ( rand_method_names, rand_inputs, rand_expected_outpus, - bundled_config, - ) = get_random_config( + method_test_suites, + ) = get_random_test_suites( n_model_inputs=2, model_input_sizes=[[2, 2], [2, 2]], n_model_outputs=1, model_output_sizes=[[2, 2]], dtype=torch.int32, n_sets_per_plan_test=n_sets_per_plan_test, - n_execution_plan_tests=n_execution_plan_tests, + n_method_test_suites=n_method_test_suites, ) - self.assertEqual( - len(bundled_config.execution_plan_tests), n_execution_plan_tests - ) - - rand_method_names.sort() + self.assertEqual(len(method_test_suites), n_method_test_suites) # Compare to see if bundled execution plan test match expectations. - for plan_test_idx in range(n_execution_plan_tests): + for method_test_suite_idx in range(n_method_test_suites): self.assertEqual( - bundled_config.execution_plan_tests[plan_test_idx].method_name, - rand_method_names[plan_test_idx], + method_test_suites[method_test_suite_idx].method_name, + rand_method_names[method_test_suite_idx], ) for testset_idx in range(n_sets_per_plan_test): self.assertIOListEqual( - # pyre-ignore - rand_inputs[plan_test_idx][testset_idx], - bundled_config.execution_plan_tests[plan_test_idx] - .test_sets[testset_idx] + # pyre-ignore [6]: expected `List[Union[bool, float, int, Tensor]]` but got `Sequence[Union[bool, float, int, Tensor]] + rand_inputs[method_test_suite_idx][testset_idx], + method_test_suites[method_test_suite_idx] + .test_cases[testset_idx] .inputs, ) self.assertIOListEqual( - # pyre-ignore - rand_expected_outpus[plan_test_idx][testset_idx], - bundled_config.execution_plan_tests[plan_test_idx] - .test_sets[testset_idx] + # pyre-ignore [6]: expected `List[Union[bool, float, int, Tensor]]` but got `Sequence[Union[bool, float, int, Tensor]] + rand_expected_outpus[method_test_suite_idx][testset_idx], + method_test_suites[method_test_suite_idx] + .test_cases[testset_idx] .expected_outputs, ) - def test_create_config_from_eager_model(self) -> None: + def test_create_test_suites_from_eager_model(self) -> None: n_sets_per_plan_test = 10 eager_model = SampleModel() method_names: List[str] = eager_model.method_names - rand_inputs, bundled_config = get_random_config_with_eager_model( + rand_inputs, method_test_suites = get_random_test_suites_with_eager_model( eager_model=eager_model, method_names=method_names, n_model_inputs=2, @@ -100,28 +96,26 @@ def test_create_config_from_eager_model(self) -> None: n_sets_per_plan_test=n_sets_per_plan_test, ) - self.assertEqual(len(bundled_config.execution_plan_tests), len(method_names)) - - sorted_method_names = sorted(method_names) + self.assertEqual(len(method_test_suites), len(method_names)) # Compare to see if bundled testcases match expectations. - for plan_test_idx in range(len(method_names)): + for method_test_suite_idx in range(len(method_names)): self.assertEqual( - bundled_config.execution_plan_tests[plan_test_idx].method_name, - sorted_method_names[plan_test_idx], + method_test_suites[method_test_suite_idx].method_name, + method_names[method_test_suite_idx], ) for testset_idx in range(n_sets_per_plan_test): - ri = rand_inputs[plan_test_idx][testset_idx] + ri = rand_inputs[method_test_suite_idx][testset_idx] self.assertIOListEqual( - # pyre-ignore[6] + # pyre-ignore [6]: expected `List[Union[bool, float, int, Tensor]]` but got `Sequence[Union[bool, float, int, Tensor]] ri, - bundled_config.execution_plan_tests[plan_test_idx] - .test_sets[testset_idx] + method_test_suites[method_test_suite_idx] + .test_cases[testset_idx] .inputs, ) model_outputs = getattr( - eager_model, sorted_method_names[plan_test_idx] + eager_model, method_names[method_test_suite_idx] )(*ri) if isinstance(model_outputs, get_args(DataContainer)): # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`. @@ -133,7 +127,7 @@ def test_create_config_from_eager_model(self) -> None: self.assertIOListEqual( flatten_eager_model_outputs, - bundled_config.execution_plan_tests[plan_test_idx] - .test_sets[testset_idx] + method_test_suites[method_test_suite_idx] + .test_cases[testset_idx] .expected_outputs, ) diff --git a/bundled_program/tests/test_end2end.py b/bundled_program/tests/test_end2end.py index ffb0462525..443ff9968f 100644 --- a/bundled_program/tests/test_end2end.py +++ b/bundled_program/tests/test_end2end.py @@ -20,7 +20,6 @@ import executorch.extension.pytree as pytree import torch -from executorch.bundled_program.config import BundledConfig from executorch.bundled_program.core import create_bundled_program from executorch.bundled_program.serialize import ( @@ -62,10 +61,10 @@ class BundledProgramE2ETest(unittest.TestCase): def test_sample_model_e2e(self): - program, bundled_config = get_common_program() + program, method_test_suites = get_common_program() eager_model = SampleModel() - bundled_program = create_bundled_program(program, bundled_config) + bundled_program = create_bundled_program(program, method_test_suites) bundled_program_buffer = serialize_from_bundled_program_to_flatbuffer( bundled_program diff --git a/bundled_program/tests/test_serialize.py b/bundled_program/tests/test_serialize.py index b4882f381d..4ddf68af2b 100644 --- a/bundled_program/tests/test_serialize.py +++ b/bundled_program/tests/test_serialize.py @@ -20,9 +20,9 @@ class TestSerialize(unittest.TestCase): def test_bundled_program_serialization(self) -> None: - program, bundled_config = get_common_program() + program, method_test_suites = get_common_program() - bundled_program = create_bundled_program(program, bundled_config) + bundled_program = create_bundled_program(program, method_test_suites) pretty_print(bundled_program) flat_buffer_bundled_program = serialize_from_bundled_program_to_flatbuffer( bundled_program diff --git a/docs/source/sdk-bundled-io.md b/docs/source/sdk-bundled-io.md index 00522bbae6..c543441bab 100644 --- a/docs/source/sdk-bundled-io.md +++ b/docs/source/sdk-bundled-io.md @@ -1,7 +1,7 @@ # Bundled Program -- a Tool for ExecuTorch Model Validation ## Introduction -BundledProgram is a wrapper around the core ExecuTorch program designed to help users wrapping test cases with the model they deploy. BundledProgram is not necessarily a core part of the program and not needed for its execution, but is particularly important for various other use-cases, such as model correctness evaluation, including e2e testing during the model bring-up process. +`BundledProgram` is a wrapper around the core ExecuTorch program designed to help users wrapping test cases with the model they deploy. `BundledProgram` is not necessarily a core part of the program and not needed for its execution, but is particularly important for various other use-cases, such as model correctness evaluation, including e2e testing during the model bring-up process. Overall, the procedure can be broken into two stages, and in each stage we are supporting: @@ -11,31 +11,40 @@ Overall, the procedure can be broken into two stages, and in each stage we are s ## Emit stage This stage mainly focuses on the creation of a `BundledProgram` and dumping it out to the disk as a flatbuffer file. The main procedure is as follow: 1. Create a model and emit its ExecuTorch program. -2. Construct a `BundledConfig` to record all info that needs to be bundled. -3. Generate `BundledProgram` by using the emited model and `BundledConfig`. +2. Construct a `List[MethodTestSuite]` to record all test cases that needs to be bundled. +3. Generate `BundledProgram` by using the emited model and `List[MethodTestSuite]`. 4. Serialize the `BundledProgram` and dump it out to the disk. ### Step 1: Create a Model and Emit its ExecuTorch Program. ExecuTorch Program can be emitted from user's model by using ExecuTorch APIs. Follow the [Generate Sample ExecuTorch program](./getting-started-setup.md) or [Exporting to ExecuTorch tutorial](./tutorials/export-to-executorch-tutorial). -### Step 2: Construct `BundledConfig` +### Step 2: Construct `List[MethodTestSuite]` to hold test info +In `BundledProgram`, we create two new classes, `MethodTestCase` and `MethodTestSuite`, to hold essential info for ExecuTorch program verification. -`BundledConfig` is a class under `executorch/bundled_program/config.py` that contains all information to be bundled for model verification. Here's the constructor api to create `BundledConfig`: +:::{dropdown} `MethodTestCase` -:::{dropdown} `BundledConfig` +```{eval-rst} +.. autofunction:: bundled_program.config.MethodTestCase.__init__ + :noindex: +``` +::: + +:::{dropdown} `MethodTestSuite` ```{eval-rst} -.. autofunction:: bundled_program.config.BundledConfig.__init__ +.. autofunction:: bundled_program.config.MethodTestSuite :noindex: ``` ::: +Since each model may have multiple inference methods, we need to generate `List[MethodTestSuite]` to hold all essential infos. + ### Step 3: Generate `BundledProgram` -We provide `create_bundled_program` API under `executorch/bundled_program/core.py` to generate `BundledProgram` by bundling the emitted ExecuTorch program with the bundled_config: +We provide `create_bundled_program` API under `executorch/bundled_program/core.py` to generate `BundledProgram` by bundling the emitted ExecuTorch program with the `List[MethodTestSuite]`: :::{dropdown} `BundledProgram` @@ -46,8 +55,8 @@ We provide `create_bundled_program` API under `executorch/bundled_program/core.p ``` ::: -`create_bundled_program` will do sannity check internally to see if the given BundledConfig matches the given Program's requirements. Specifically: -1. The name of methods we create BundledConfig for should be also in program. Please notice that it is no need to set testcases for every method in the Program. +`create_bundled_program` will do sannity check internally to see if the given `List[MethodTestSuite]` matches the given Program's requirements. Specifically: +1. The method_names of each `MethodTestSuite` in `List[MethodTestSuite]` for should be also in program. Please notice that it is no need to set testcases for every method in the Program. 2. The metadata of each testcase should meet the requirement of the coresponding inference methods input. ### Step 4: Serialize `BundledProgram` to Flatbuffer. @@ -76,14 +85,15 @@ Here is a flow highlighting how to generate a `BundledProgram` given a PyTorch m ```python import torch -from torch.export import export -from executorch.bundled_program.config import BundledConfig +from executorch.bundled_program.config import MethodTestCase, MethodTestSuite from executorch.bundled_program.core import create_bundled_program -from executorch.bundled_program.serialize import serialize_from_bundled_program_to_flatbuffer +from executorch.bundled_program.serialize import ( + serialize_from_bundled_program_to_flatbuffer, +) from executorch.exir import to_edge - +from torch.export import export # Step 1: ExecuTorch Program Export @@ -95,9 +105,7 @@ class SampleModel(torch.nn.Module): self.a: torch.Tensor = 3 * torch.ones(2, 2, dtype=torch.int32) self.b: torch.Tensor = 2 * torch.ones(2, 2, dtype=torch.int32) - def encode( - self, x: torch.Tensor, q: torch.Tensor - ) -> torch.Tensor: + def encode(self, x: torch.Tensor, q: torch.Tensor) -> torch.Tensor: z = x.clone() torch.mul(self.a, x, out=z) y = x.clone() @@ -105,9 +113,7 @@ class SampleModel(torch.nn.Module): torch.add(y, q, out=y) return y - def decode( - self, x: torch.Tensor, q: torch.Tensor - ) -> torch.Tensor: + def decode(self, x: torch.Tensor, q: torch.Tensor) -> torch.Tensor: y = x * q torch.add(y, self.b, out=y) return y @@ -134,15 +140,18 @@ method_graphs = { # Emit the traced methods into ET Program. program = to_edge(method_graphs).to_executorch().executorch_program -# Step 2: Construct BundledConfig +# Step 2: Construct MethodTestSuite for Each Method + +# Prepare the Test Inputs. # number of input sets to be verified n_input = 10 # Input sets to be verified for each inference methods. -inputs = [ - # The below list is all inputs for a single inference method. - [ +# To simplify, here we create same inputs for all methods. +inputs = { + # Inference method name corresponding to its test cases. + m_name: [ # Each list below is a individual input set. # The number of inputs, dtype and size of each input follow Program's spec. [ @@ -151,27 +160,32 @@ inputs = [ ] for _ in range(n_input) ] - for _ in method_names -] + for m_name in method_names +} -# Expected outputs align with inputs. -expected_outputs = [ - [[getattr(model, m_name)(*x)] for x in inputs[i]] - for i, m_name in enumerate(method_names) +# Generate Test Suites +method_test_suites = [ + MethodTestSuite( + method_name=m_name, + test_cases=[ + MethodTestCase( + inputs=input, + expected_outputs=getattr(model, m_name)(*input), + ) + for input in inputs[m_name] + ], + ) + for m_name in method_names ] -# Create BundledConfig -bundled_config = BundledConfig( - method_names, inputs, expected_outputs -) - - # Step 3: Generate BundledProgram -bundled_program = create_bundled_program(program, bundled_config) +bundled_program = create_bundled_program(program, method_test_suites) # Step 4: Serialize BundledProgram to flatbuffer. -serialized_bundled_program = serialize_from_bundled_program_to_flatbuffer(bundled_program) +serialized_bundled_program = serialize_from_bundled_program_to_flatbuffer( + bundled_program +) save_path = "bundled_program.bpte" with open(save_path, "wb") as f: f.write(serialized_bundled_program) @@ -299,7 +313,7 @@ ET_CHECK_MSG( ## Common Errors -Errors will be raised if `BundledConfig` doesn't match the `Program`. Here're two common situations: +Errors will be raised if `List[MethodTestSuites]` doesn't match the `Program`. Here're two common situations: ### Test input doesn't match model's requirement. @@ -309,12 +323,12 @@ Here's the example of the dtype of test input not meet model's requirement: ```python import torch -from torch.export import export -from executorch.bundled_program.config import BundledConfig +from executorch.bundled_program.config import MethodTestCase, MethodTestSuite from executorch.bundled_program.core import create_bundled_program from executorch.exir import to_edge +from torch.export import export class Module(torch.nn.Module): @@ -332,96 +346,107 @@ class Module(torch.nn.Module): model = Module() -method_names = ['forward'] +method_names = ["forward"] inputs = torch.ones(2, 2, dtype=torch.float) -print(model(inputs)) # Find each method of model needs to be traced my its name, export its FX Graph. method_graphs = { - m_name: export(getattr(model, m_name), (inputs, )) - for m_name in method_names + m_name: export(getattr(model, m_name), (inputs,)) for m_name in method_names } # Emit the traced methods into ET Program. program = to_edge(method_graphs).to_executorch().executorch_program - # number of input sets to be verified n_input = 10 -# All Input sets to be verified. -inputs = [ - [ +# Input sets to be verified for each inference methods. +# To simplify, here we create same inputs for all methods. +inputs = { + # Inference method name corresponding to its test cases. + m_name: [ # NOTE: executorch program needs torch.float, but here is torch.int [ torch.randint(-5, 5, (2, 2), dtype=torch.int), ] for _ in range(n_input) ] -] + for m_name in method_names +} -# Expected outputs align with inputs. -expected_outpus = [ - [[model(*x)] for x in inputs[0]] +# Generate Test Suites +method_test_suites = [ + MethodTestSuite( + method_name=m_name, + test_cases=[ + MethodTestCase( + inputs=input, + expected_outputs=getattr(model, m_name)(*input), + ) + for input in inputs[m_name] + ], + ) + for m_name in method_names ] -bundled_config = BundledConfig(method_names, inputs, expected_outpus) +# Generate BundledProgram -bundled_program = create_bundled_program(program, bundled_config) +bundled_program = create_bundled_program(program, method_test_suites) ``` :::{dropdown} Raised Error ``` -The input tensor tensor([[ 0, 3], - [-3, -3]], dtype=torch.int32) dtype shall be torch.float32, but now is torch.int32 +The input tensor tensor([[-2, 0], + [-2, -1]], dtype=torch.int32) dtype shall be torch.float32, but now is torch.int32 --------------------------------------------------------------------------- AssertionError Traceback (most recent call last) - 57 expected_outpus = [ - 58 [[model(*x)] for x in inputs[0]] - 59 ] - 61 bundled_config = BundledConfig(method_names, inputs, expected_outpus) ----> 63 bundled_program = create_bundled_program(program, bundled_config) -File /executorch/bundled_program/core.py:270, in create_bundled_program(program, bundled_config) - 259 def create_bundled_program( - 260 program: Program, - 261 bundled_config: BundledConfig, - 262 ) -> BundledProgram: - 263 """Create BundledProgram by bundling the given program and bundled_config together. - 264 - 265 Args: - 266 program: The program to be bundled. - 267 bundled_config: The config to be bundled. - 268 """ ---> 270 assert_valid_bundle(program, bundled_config) - 272 execution_plan_tests: List[BundledExecutionPlanTest] = [] - 274 # Emit data and metadata of bundled tensor -File /executorch/bundled_program/core.py:224, in assert_valid_bundle(program, bundled_config) - 220 # type of tensor input should match execution plan - 221 if type(cur_plan_test_inputs[j]) == torch.Tensor: - 222 # pyre-fixme[16]: Undefined attribute [16]: Item `bool` of `typing.Union[bool, float, int, torch._tensor.Tensor]` - 223 # has no attribute `dtype`. ---> 224 assert cur_plan_test_inputs[j].dtype == get_input_dtype( - 225 program, program_plan_id, j - 226 ), "The input tensor {} dtype shall be {}, but now is {}".format( - 227 cur_plan_test_inputs[j], - 228 get_input_dtype(program, program_plan_id, j), - 229 cur_plan_test_inputs[j].dtype, - 230 ) - 231 elif type(cur_plan_test_inputs[j]) in ( - 232 int, - 233 bool, - 234 float, - 235 ): - 236 assert type(cur_plan_test_inputs[j]) == get_input_type( - 237 program, program_plan_id, j - 238 ), "The input primitive dtype shall be {}, but now is {}".format( - 239 get_input_type(program, program_plan_id, j), - 240 type(cur_plan_test_inputs[j]), - 241 ) -AssertionError: The input tensor tensor([[ 0, 3], - [-3, -3]], dtype=torch.int32) dtype shall be torch.float32, but now is torch.int32 +Cell In[1], line 72 + 56 method_test_suites = [ + 57 MethodTestSuite( + 58 method_name=m_name, + (...) + 67 for m_name in method_names + 68 ] + 70 # Step 3: Generate BundledProgram +---> 72 bundled_program = create_bundled_program(program, method_test_suites) +File /executorch/bundled_program/core.py:276, in create_bundled_program(program, method_test_suites) + 264 """Create bp_schema.BundledProgram by bundling the given program and method_test_suites together. + 265 + 266 Args: + (...) + 271 The `BundledProgram` variable contains given ExecuTorch program and test cases. + 272 """ + 274 method_test_suites = sorted(method_test_suites, key=lambda x: x.method_name) +--> 276 assert_valid_bundle(program, method_test_suites) + 278 bundled_method_test_suites: List[bp_schema.BundledMethodTestSuite] = [] + 280 # Emit data and metadata of bundled tensor +File /executorch/bundled_program/core.py:219, in assert_valid_bundle(program, method_test_suites) + 215 # type of tensor input should match execution plan + 216 if type(cur_plan_test_inputs[j]) == torch.Tensor: + 217 # pyre-fixme[16]: Undefined attribute [16]: Item `bool` of `typing.Union[bool, float, int, torch._tensor.Tensor]` + 218 # has no attribute `dtype`. +--> 219 assert cur_plan_test_inputs[j].dtype == get_input_dtype( + 220 program, program_plan_id, j + 221 ), "The input tensor {} dtype shall be {}, but now is {}".format( + 222 cur_plan_test_inputs[j], + 223 get_input_dtype(program, program_plan_id, j), + 224 cur_plan_test_inputs[j].dtype, + 225 ) + 226 elif type(cur_plan_test_inputs[j]) in ( + 227 int, + 228 bool, + 229 float, + 230 ): + 231 assert type(cur_plan_test_inputs[j]) == get_input_type( + 232 program, program_plan_id, j + 233 ), "The input primitive dtype shall be {}, but now is {}".format( + 234 get_input_type(program, program_plan_id, j), + 235 type(cur_plan_test_inputs[j]), + 236 ) +AssertionError: The input tensor tensor([[-2, 0], + [-2, -1]], dtype=torch.int32) dtype shall be torch.float32, but now is torch.int32 ``` @@ -429,17 +454,16 @@ AssertionError: The input tensor tensor([[ 0, 3], ### Method name in `BundleConfig` does not exist. -Another common error would be the method name in `BundledConfig` does not exist in Model. `BundledProgram` will raise error and show the non-exist method name: +Another common error would be the method name in any `MethodTestSuite` does not exist in Model. `BundledProgram` will raise error and show the non-exist method name: ```python import torch -from torch.export import export -from executorch.bundled_program.config import BundledConfig +from executorch.bundled_program.config import MethodTestCase, MethodTestSuite from executorch.bundled_program.core import create_bundled_program from executorch.exir import to_edge - +from torch.export import export class Module(torch.nn.Module): @@ -458,83 +482,87 @@ class Module(torch.nn.Module): model = Module() -method_names = ['forward'] +method_names = ["forward"] inputs = torch.ones(2, 2, dtype=torch.float) # Find each method of model needs to be traced my its name, export its FX Graph. method_graphs = { - m_name: export(getattr(model, m_name), (inputs, )) - for m_name in method_names + m_name: export(getattr(model, m_name), (inputs,)) for m_name in method_names } # Emit the traced methods into ET Program. program = to_edge(method_graphs).to_executorch().executorch_program -# Number of input sets to be verified +# number of input sets to be verified n_input = 10 -# All Input sets to be verified. -inputs = [ - [ +# Input sets to be verified for each inference methods. +# To simplify, here we create same inputs for all methods. +inputs = { + # Inference method name corresponding to its test cases. + m_name: [ [ torch.randint(-5, 5, (2, 2), dtype=torch.float), ] for _ in range(n_input) ] -] + for m_name in method_names +} -# Expected outputs align with inputs. -expected_outpus = [ - [[model(*x)] for x in inputs[0]] +# Generate Test Suites +method_test_suites = [ + MethodTestSuite( + method_name=m_name, + test_cases=[ + MethodTestCase( + inputs=input, + expected_outputs=getattr(model, m_name)(*input), + ) + for input in inputs[m_name] + ], + ) + for m_name in method_names ] - # NOTE: MISSING_METHOD_NAME is not an inference method in the above model. -wrong_method_names = ['MISSING_METHOD_NAME'] - -bundled_config = BundledConfig(wrong_method_names, inputs, expected_outpus) +method_test_suites[0].method_name = "MISSING_METHOD_NAME" -bundled_program = create_bundled_program(program, bundled_config) +# Generate BundledProgram +bundled_program = create_bundled_program(program, method_test_suites) ``` :::{dropdown} Raised Error ``` -All method names in bundled config should be found in program.execution_plan, but {'wrong_forward'} does not include. +All method names in bundled config should be found in program.execution_plan, but {'MISSING_METHOD_NAME'} does not include. --------------------------------------------------------------------------- AssertionError Traceback (most recent call last) - 58 expected_outpus = [ - 59 [[model(*x)] for x in inputs[0]] - 60 ] - 62 bundled_config = BundledConfig(method_names, inputs, expected_outpus) ----> 64 bundled_program = create_bundled_program(program, bundled_config) -File /executorch/bundled_program/core.py:270, in create_bundled_program(program, bundled_config) - 259 def create_bundled_program( - 260 program: Program, - 261 bundled_config: BundledConfig, - 262 ) -> BundledProgram: - 263 """Create BundledProgram by bundling the given program and bundled_config together. - 264 - 265 Args: - 266 program: The program to be bundled. - 267 bundled_config: The config to be bundled. - 268 """ ---> 270 assert_valid_bundle(program, bundled_config) - 272 execution_plan_tests: List[BundledExecutionPlanTest] = [] - 274 # Emit data and metadata of bundled tensor -File /executorch/bundled_program/core.py:147, in assert_valid_bundle(program, bundled_config) - 142 method_name_of_program = {e.name for e in program.execution_plan} - 143 method_name_of_bundled_config = { - 144 t.method_name for t in bundled_config.execution_plan_tests - 145 } ---> 147 assert method_name_of_bundled_config.issubset( - 148 method_name_of_program - 149 ), f"All method names in bundled config should be found in program.execution_plan, \ - 150 but {str(method_name_of_bundled_config - method_name_of_program)} does not include." - 152 # check if has been sorted in ascending alphabetical order of method name. - 153 for bp_plan_id in range(1, len(bundled_config.execution_plan_tests)): +Cell In[3], line 73 + 70 method_test_suites[0].method_name = "MISSING_METHOD_NAME" + 72 # Generate BundledProgram +---> 73 bundled_program = create_bundled_program(program, method_test_suites) +File /executorch/bundled_program/core.py:276, in create_bundled_program(program, method_test_suites) + 264 """Create bp_schema.BundledProgram by bundling the given program and method_test_suites together. + 265 + 266 Args: + (...) + 271 The `BundledProgram` variable contains given ExecuTorch program and test cases. + 272 """ + 274 method_test_suites = sorted(method_test_suites, key=lambda x: x.method_name) +--> 276 assert_valid_bundle(program, method_test_suites) + 278 bundled_method_test_suites: List[bp_schema.BundledMethodTestSuite] = [] + 280 # Emit data and metadata of bundled tensor +File /executorch/bundled_program/core.py:141, in assert_valid_bundle(program, method_test_suites) + 138 method_name_of_program = {e.name for e in program.execution_plan} + 139 method_name_of_test_suites = {t.method_name for t in method_test_suites} +--> 141 assert method_name_of_test_suites.issubset( + 142 method_name_of_program + 143 ), f"All method names in bundled config should be found in program.execution_plan, \ + 144 but {str(method_name_of_test_suites - method_name_of_program)} does not include." + 146 # check if method_tesdt_suites has been sorted in ascending alphabetical order of method name. + 147 for test_suite_id in range(1, len(method_test_suites)): AssertionError: All method names in bundled config should be found in program.execution_plan, but {'MISSING_METHOD_NAME'} does not include. ``` ::: diff --git a/docs/source/tutorials_source/sdk-integration-tutorial.py b/docs/source/tutorials_source/sdk-integration-tutorial.py index 2e952b017d..fd37b52b27 100644 --- a/docs/source/tutorials_source/sdk-integration-tutorial.py +++ b/docs/source/tutorials_source/sdk-integration-tutorial.py @@ -130,7 +130,7 @@ def forward(self, x): import torch -from executorch.bundled_program.config import BundledConfig +from executorch.bundled_program.config import MethodTestCase, MethodTestSuite from executorch.bundled_program.core import create_bundled_program from executorch.bundled_program.serialize import ( serialize_from_bundled_program_to_flatbuffer, @@ -143,14 +143,22 @@ def forward(self, x): m_name = "forward" method_graphs = {m_name: export(getattr(model, m_name), (torch.randn(1, 1, 32, 32),))} -# Step 2: Construct BundledConfig +# Step 2: Construct Method Test Suites inputs = [[torch.randn(1, 1, 32, 32)] for _ in range(2)] -expected_outputs = [[[getattr(model, m_name)(*x)] for x in inputs]] -bundled_config = BundledConfig([m_name], [inputs], expected_outputs) + +method_test_suites = [ + MethodTestSuite( + method_name=m_name, + test_cases=[ + MethodTestCase(inputs=inp, outputs=getattr(model, m_name)(*inp)) + for inp in inputs + ], + ) +] # Step 3: Generate BundledProgram program = to_edge(method_graphs).to_executorch().executorch_program -bundled_program = create_bundled_program(program, bundled_config) +bundled_program = create_bundled_program(program, method_test_suites) # Step 4: Serialize BundledProgram to flatbuffer. serialized_bundled_program = serialize_from_bundled_program_to_flatbuffer( diff --git a/examples/apple/mps/scripts/mps_example.py b/examples/apple/mps/scripts/mps_example.py index b30cb0eea1..bb0c1b68b5 100644 --- a/examples/apple/mps/scripts/mps_example.py +++ b/examples/apple/mps/scripts/mps_example.py @@ -13,14 +13,14 @@ import torch._export as export from executorch import exir from executorch.backends.apple.mps.mps_preprocess import MPSBackend -from executorch.bundled_program.config import BundledConfig + +from executorch.exir.backend.backend_api import to_backend +from executorch.bundled_program.config import MethodTestCase, MethodTestSuite from executorch.bundled_program.core import create_bundled_program from executorch.bundled_program.serialize import ( serialize_from_bundled_program_to_flatbuffer, ) -from executorch.exir.backend.backend_api import to_backend - from ....models import MODEL_NAME_TO_MODEL from ....models.model_factory import EagerModelFactory @@ -91,19 +91,19 @@ def forward(self, *input_args): model_name = f"{args.model_name}_mps" if args.bundled: - bundled_inputs = [ - [example_inputs] - for _ in range(len(executorch_program.program.execution_plan)) - ] - - output = model(*example_inputs) - expected_outputs = [ - [[output]] for _ in range(len(executorch_program.program.execution_plan)) + method_test_suites = [ + MethodTestSuite( + method_name="forward", + test_cases=[ + MethodTestCase( + inputs=example_inputs, expected_outputs=[model(*example_inputs)] + ) + ], + ) ] - bundled_config = BundledConfig(["forward"], bundled_inputs, expected_outputs) bundled_program = create_bundled_program( - executorch_program.program, bundled_config + executorch_program.program, method_test_suites ) bundled_program_buffer = serialize_from_bundled_program_to_flatbuffer( bundled_program diff --git a/examples/sdk/scripts/export_bundled_program.py b/examples/sdk/scripts/export_bundled_program.py index 50702896f0..1171a216b7 100644 --- a/examples/sdk/scripts/export_bundled_program.py +++ b/examples/sdk/scripts/export_bundled_program.py @@ -7,15 +7,14 @@ # Example script for exporting simple models to flatbuffer import argparse -import os from typing import List import torch from executorch.bundled_program.config import ( - BundledConfig, MethodInputType, - MethodOutputType, + MethodTestCase, + MethodTestSuite, ) from executorch.bundled_program.core import create_bundled_program from executorch.bundled_program.serialize import ( @@ -30,27 +29,19 @@ def save_bundled_program( program: Program, - method_names: List[str], - bundled_inputs: List[List[MethodInputType]], - bundled_expected_outputs: List[List[MethodOutputType]], + method_test_suites: List[MethodTestSuite], output_path: str, -) -> None: +): """ Generates a bundled program from the given ET program and saves it to the specified path. Args: program: The ExecuTorch program to bundle. - method_names: A list of method names in the program to bundle test cases. - bundled_inputs: Representative inputs for each method. - bundled_expected_outputs: Expected outputs of representative inputs for each method. + method_test_suites: The MethodTestSuites which contains test cases to include in the bundled program. output_path: Path to save the bundled program. """ - bundled_config = BundledConfig( - method_names, bundled_inputs, bundled_expected_outputs - ) - - bundled_program = create_bundled_program(program, bundled_config) + bundled_program = create_bundled_program(program, method_test_suites) bundled_program_buffer = serialize_from_bundled_program_to_flatbuffer( bundled_program ) @@ -84,29 +75,38 @@ def export_to_bundled_program( print("Creating bundled test cases...") method_names = [method.name for method in program.execution_plan] - # A model could have multiple entry point methods and each of them can have inputs bundled for testing. - # This example demonstrates a model which has a single entry point method ("forward") to which we want - # to bundle two input test cases (example_inputs is used two times) for the "forward" method. - bundled_inputs = [ - [example_inputs, example_inputs] for i in range(len(method_names)) - ] - - bundled_expected_outputs = [ - [[getattr(model, method_names[i])(*x)] for x in bundled_inputs[i]] - for i in range(len(program.execution_plan)) - ] - - bundled_program_name = f"{model_name}_bundled.bpte" - output_path = os.path.join(output_directory, bundled_program_name) - - print(f"Saving exported program to {output_path}") - save_bundled_program( - program=program, - method_names=method_names, - bundled_inputs=bundled_inputs, - bundled_expected_outputs=bundled_expected_outputs, - output_path=output_path, - ) + # A model could have multiple entry point methods and each of them can have multiple inputs bundled for testing. + # This example demonstrates a model which has multiple entry point methods, whose name listed in method_names, to which we want + # to bundle two input test cases (example_inputs is used two times) for each inference method. + program_inputs = { + m_name: [example_inputs, example_inputs] for m_name in method_names + } + + method_test_suites: List[MethodTestSuite] = [] + for m_name in method_names: + method_inputs = program_inputs[m_name] + + # To create a bundled program, we first create every test cases from input. We leverage eager model + # to generate expected output for each test input, and use MethodTestCase to hold the information of + # each test case. We gather all MethodTestCase for same method into one MethodTestSuite, and generate + # bundled program by all MethodTestSuites. + method_test_cases: List[MethodTestCase] = [] + for method_input in method_inputs: + method_test_cases.append( + MethodTestCase( + inputs=method_input, + expected_outputs=model(*method_input), + ) + ) + + method_test_suites.append( + MethodTestSuite( + method_name=m_name, + test_cases=method_test_cases, + ) + ) + + save_bundled_program(program, method_test_suites, f"{model_name}_bundled.bpte") if __name__ == "__main__": diff --git a/schema/bundled_program_schema.fbs b/schema/bundled_program_schema.fbs index 66f622a6a6..a80269b6d1 100644 --- a/schema/bundled_program_schema.fbs +++ b/schema/bundled_program_schema.fbs @@ -9,7 +9,7 @@ include "scalar_type.fbs"; namespace bundled_program_flatbuffer; // Identifier of a valid bundled program schema. -file_identifier "BP06"; +file_identifier "BP07"; // Extension of written files. file_extension "bpte"; @@ -44,45 +44,44 @@ union ValueUnion { Double, } -// Abstraction for BundledIOSet values +// Abstraction for BundledMethodTestCase values table Value { val: ValueUnion; } -// All inputs and referenced outputs needs for single verification. -table BundledIOSet { - // All inputs required by Program for execution. Its length should be - // equal to the length of program inputs. +// A single test for a method. The provided inputs should produce the +// expected outputs. +table BundledMethodTestCase { + // The inputs to provide to the method. The number and types of inputs must + // match the schema of the method under test. inputs: [Value]; - // The expected outputs generated while running the model in eager mode - // using the inputs provided. Its length should be equal to the length - // of program outputs. + // The expected outputs generated while running the model in eager mode using + // the inputs provided. Its length should be equal to the length of program + // outputs. expected_outputs: [Value]; } - -// Context for testing and verifying an exceution plan. -table BundledExecutionPlanTest { - +// Collection of test cases for a program method. +table BundledMethodTestSuite { // The name of the method to test; e.g., "forward" for the forward() method // of an nn.Module. This name match a method defined by the ExecuTorch // program. method_name: string; - // Sets of input/outputs to test with. - test_sets: [BundledIOSet]; + // Individual test cases for the method. + test_cases: [BundledMethodTestCase]; } + // Executorch program bunlded with data for verification. table BundledProgram { // Schema version. version:uint; - // Test sets and other meta datas to verify the whole program. - // Each BundledExecutionPlanTest should be used for the execution plan of program sharing same index. - // Its length should be equal to the number of execution plans in program. - execution_plan_tests: [BundledExecutionPlanTest]; + // Test sets to run against the program. + // Each BundledMethodTestSuite should be used for the method of program sharing same name. + method_test_suites: [BundledMethodTestSuite]; // The binary data of a serialized Executorch program. // The following `force_align` may sliently override any larger force_align @@ -91,12 +90,8 @@ table BundledProgram { // executorch program keeps the same alignment as original no matter how // the program schema changes, we need to make the force_align here the max // one around all kinds of force_align in the current and future program - // schema, so we use the 16, the largest possible alignment of flatbuffer, - // as the force_align here. - // In the future, we may need to revisit that to enforce larger alignment - // constraint. If needed, check against FLATBUFFERS_MAX_ALIGNMENT in the - // flatbuffers/base.h, which is the given alignment ceiling of flatbuffer. - program: [ubyte] (force_align: 16); + // schema, so we use the 4096 as the force_align here. + program: [ubyte] (force_align: 4096); } root_type BundledProgram; diff --git a/test/models/generate_linear_out_bundled_program.py b/test/models/generate_linear_out_bundled_program.py index 91cf853f50..c1538cf11d 100644 --- a/test/models/generate_linear_out_bundled_program.py +++ b/test/models/generate_linear_out_bundled_program.py @@ -14,11 +14,12 @@ """ import subprocess +from typing import List import executorch.exir as exir import torch -from executorch.bundled_program.config import BundledConfig +from executorch.bundled_program.config import MethodTestCase, MethodTestSuite from executorch.bundled_program.core import create_bundled_program from executorch.bundled_program.serialize import ( serialize_from_bundled_program_to_flatbuffer, @@ -55,29 +56,18 @@ def main() -> None: # Serialize to flatbuffer. program.version = 0 - bundled_inputs = [ - [ - [ - torch.rand(2, 2, dtype=torch.float), - ] - for _ in range(10) + # Create test sets + method_test_cases: List[MethodTestCase] = [] + for _ in range(10): + x = [ + torch.rand(2, 2, dtype=torch.float), ] - for _ in range(len(program.execution_plan)) + method_test_cases.append(MethodTestCase(inputs=x, expected_outputs=model(*x))) + method_test_suites = [ + MethodTestSuite(method_name="forward", test_cases=method_test_cases) ] - bundled_expected_outputs = [ - [[model(*x)] for x in bundled_inputs[i]] - for i in range(len(program.execution_plan)) - ] - - bundled_config = BundledConfig( - method_names=["forward"], - # pyre-ignore - inputs=bundled_inputs, - expected_outputs=bundled_expected_outputs, - ) - - bundled_program = create_bundled_program(program, bundled_config) + bundled_program = create_bundled_program(program, method_test_suites) pretty_print(bundled_program) bundled_program_flatbuffer = serialize_from_bundled_program_to_flatbuffer( diff --git a/util/bundled_program_verification.cpp b/util/bundled_program_verification.cpp index bc1c0991eb..5ff9713971 100644 --- a/util/bundled_program_verification.cpp +++ b/util/bundled_program_verification.cpp @@ -165,12 +165,13 @@ bool tensors_are_close( } } -Result get_method_test( +Result +get_method_test_suite( const bundled_program_flatbuffer::BundledProgram* bundled_program, const char* method_name) { - auto method_tests = bundled_program->execution_plan_tests(); - for (size_t i = 0; i < method_tests->size(); i++) { - auto m_test = method_tests->GetMutableObject(i); + auto method_test_suites = bundled_program->method_test_suites(); + for (size_t i = 0; i < method_test_suites->size(); i++) { + auto m_test = method_test_suites->GetMutableObject(i); if (std::strcmp(m_test->method_name()->c_str(), method_name) == 0) { return m_test; } @@ -194,7 +195,7 @@ __ET_NODISCARD Error LoadBundledInput( NotSupported, "The input buffer should be a bundled program."); - auto method_test = get_method_test( + auto method_test = get_method_test_suite( bundled_program_flatbuffer::GetBundledProgram(bundled_program_ptr), method_name); @@ -203,7 +204,7 @@ __ET_NODISCARD Error LoadBundledInput( } auto bundled_inputs = - method_test.get()->test_sets()->Get(testset_idx)->inputs(); + method_test.get()->test_cases()->Get(testset_idx)->inputs(); for (size_t input_idx = 0; input_idx < method.inputs_size(); input_idx++) { auto bundled_input = bundled_inputs->GetMutableObject(input_idx); @@ -289,7 +290,7 @@ __ET_NODISCARD Error VerifyResultWithBundledExpectedOutput( NotSupported, "The input buffer should be a bundled program."); - auto method_test = get_method_test( + auto method_test = get_method_test_suite( bundled_program_flatbuffer::GetBundledProgram(bundled_program_ptr), method_name); @@ -298,7 +299,7 @@ __ET_NODISCARD Error VerifyResultWithBundledExpectedOutput( } auto bundled_expected_outputs = - method_test.get()->test_sets()->Get(testset_idx)->expected_outputs(); + method_test.get()->test_cases()->Get(testset_idx)->expected_outputs(); for (size_t output_idx = 0; output_idx < method.outputs_size(); output_idx++) {