diff --git a/exir/backend/partitioner.py b/exir/backend/partitioner.py index 2482e6acc6..50cf9b8cf2 100644 --- a/exir/backend/partitioner.py +++ b/exir/backend/partitioner.py @@ -7,7 +7,9 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from types import MappingProxyType -from typing import Dict, List, Mapping, NamedTuple, Union +from typing import Callable, Dict, List, Mapping, NamedTuple, Optional, Tuple, Union + +import torch from executorch.exir.backend.backend_details import enforcedmethod from executorch.exir.backend.compile_spec_schema import CompileSpec @@ -91,3 +93,21 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: PartitionResult: includes the tagged graph and the delegation spec to indicate what backend_id and compile_spec is used for each node and the tag created by the backend developers. """ pass + + def ops_to_not_decompose( + self, + ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]: + """ + Returns a list of operator names that should not be decomposed. When these ops are + registered and the `to_backend` is invoked through to_edge_transform_and_lower it will be + guaranteed that the program that the backend receives will not have any of these ops + decomposed. + + Returns: + List[torch._ops.OpOverload]: a list of operator names that should not be decomposed. + Optional[Callable[[torch.fx.Node], bool]]]: an optional callable, acting as a filter, that users can provide + which will be called for each node in the graph that users can use as a filter for certain + nodes that should be continued to be decomposed even though the op they correspond to is + in the list returned by ops_to_not_decompose. + """ + return ([], None) diff --git a/exir/backend/test/backend_with_compiler_demo.py b/exir/backend/test/backend_with_compiler_demo.py index e07a8bbaa7..5f7d178be9 100644 --- a/exir/backend/test/backend_with_compiler_demo.py +++ b/exir/backend/test/backend_with_compiler_demo.py @@ -83,15 +83,19 @@ def preprocess( processed_bytes = "" number_of_instruction = 0 debug_handle_map = {} + match_ops = [ + exir_ops.edge.aten.sin.default, + exir_ops.edge.aten.mm.default, + exir_ops.edge.aten.add.Tensor, + torch.ops.aten.sin.default, + exir_ops.edge.aten.linear.default, + exir_ops.edge.aten.scaled_dot_product_attention.default, + ] + for node in edge_program.graph.nodes: if node.op == "call_function": # TODO(gasoonjia): remove the support of torch.ops.aten.sin.default after migrate serde to edge dialect. - if ( - node.target == exir_ops.edge.aten.sin.default - or node.target == exir_ops.edge.aten.mm.default - or node.target == exir_ops.edge.aten.add.Tensor - or node.target == torch.ops.aten.sin.default - ): + if node.target in match_ops: simple_op = DemoOp( node.target.__name__, int(torch.prod(torch.tensor(node.meta["val"].shape), 0).item()), diff --git a/exir/backend/test/op_partitioner_demo.py b/exir/backend/test/op_partitioner_demo.py index cba9a4f8c1..b92e065062 100644 --- a/exir/backend/test/op_partitioner_demo.py +++ b/exir/backend/test/op_partitioner_demo.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Dict, final +from typing import Callable, Dict, final, List, Optional, Tuple import torch from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import ( @@ -71,6 +71,7 @@ def _partition_graph_module( for _, submodule, _ in get_control_flow_submodules(graph_module): ret_partition_tags = self._partition_graph_module(submodule) partition_tags.update(ret_partition_tags) + return partition_tags def partition(self, exported_program: ExportedProgram) -> PartitionResult: @@ -121,3 +122,74 @@ def partition(self, edge_exported_program: ExportedProgram) -> PartitionResult: return PartitionResult( tagged_exported_program=edge_exported_program, partition_tags=partition_tags ) + + +ops_not_to_decompose = [ + torch.ops.aten.linear.default, + torch.ops.aten.scaled_dot_product_attention.default, +] + +edge_ops_non_decomposed = [ + exir_ops.edge.aten.linear.default, + exir_ops.edge.aten.scaled_dot_product_attention.default, +] + + +class OpsToNotDecomposeOperatorSupport(OperatorSupportBase): + def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: + return node.op == "call_function" and node.target in edge_ops_non_decomposed + + +@final +class NonDecompTestPartitioner(Partitioner): + """ + Partitions all add/mul nodes regardless of order + """ + + def __init__(self) -> None: + self.op_support = any_chain(OpsToNotDecomposeOperatorSupport()) + self.delegation_spec = DelegationSpec( + BackendWithCompilerDemo.__name__, + [CompileSpec("max_value", bytes([4]))], + ) + + def ops_to_not_decompose( + self, + ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]: + def filter_ops(node: torch.fx.Node) -> bool: + if node.op == "call_function" and node.target in ops_not_to_decompose: + if len(node.args) == 3: + # This means that linear has a bias which is the only linear we support in this + # demo partitioner. + return True + else: + return False + + return True + + return (ops_not_to_decompose, filter_ops) + + def _partition_graph_module( + self, + graph_module: torch.fx.GraphModule, + ) -> Dict[str, DelegationSpec]: + partition_tags: Dict[str, DelegationSpec] = {} + partition_list = generate_pattern_op_partitions( + graph_module, op_support=self.op_support + ) + for partition in partition_list: + for node in partition.nodes: + delegation_tag = f"tag{partition.id}" + node.meta["delegation_tag"] = delegation_tag + partition_tags[delegation_tag] = self.delegation_spec + + for _, submodule, _ in get_control_flow_submodules(graph_module): + ret_partition_tags = self._partition_graph_module(submodule) + partition_tags.update(ret_partition_tags) + return partition_tags + + def partition(self, exported_program: ExportedProgram) -> PartitionResult: + partition_tags = self._partition_graph_module(exported_program.graph_module) + return PartitionResult( + tagged_exported_program=exported_program, partition_tags=partition_tags + ) diff --git a/exir/program/TARGETS b/exir/program/TARGETS index 5ae3cf1ac5..ef4e619e1e 100644 --- a/exir/program/TARGETS +++ b/exir/program/TARGETS @@ -21,6 +21,7 @@ python_library( deps = [ "//caffe2:torch", "//executorch/exir:error", + "//executorch/exir:graph_module", "//executorch/exir:pass_manager", "//executorch/exir:print_program", "//executorch/exir:schema", @@ -36,6 +37,7 @@ python_library( "//executorch/exir/passes:normalize_view_copy_base_pass", "//executorch/exir/passes:remove_graph_asserts_pass", "//executorch/exir/passes:remove_mixed_type_operators", + "//executorch/exir/passes:replace_aten_with_edge_pass", "//executorch/exir/passes:replace_view_copy_with_view_pass", "//executorch/exir/passes:spec_prop_pass", "//executorch/exir/verification:verifier", diff --git a/exir/program/__init__.py b/exir/program/__init__.py index 4d00297685..efc3ebfbe9 100644 --- a/exir/program/__init__.py +++ b/exir/program/__init__.py @@ -9,6 +9,7 @@ from executorch.exir.program._fake_program import get_fake_program from executorch.exir.program._program import ( _to_edge, + _to_edge_transform_and_lower, edge_to_executorch_passes, EdgeProgramManager, ExecutorchProgram, @@ -22,6 +23,7 @@ "ExecutorchProgram", "_to_edge", "to_edge", + "_to_edge_transform_and_lower", "edge_to_executorch_passes", "EdgeProgramManager", "ExecutorchProgramManager", diff --git a/exir/program/_program.py b/exir/program/_program.py index c5afe01169..9fc6b36b3b 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -20,6 +20,7 @@ from executorch.exir.emit import emit_program, EmitterOutput from executorch.exir.emit._emitter import _DelegateDebugIdentifierMap from executorch.exir.error import ExportError +from executorch.exir.graph_module import get_control_flow_submodules from executorch.exir.pass_manager import PassType from executorch.exir.passes import ( base_post_op_replace_passes, @@ -37,6 +38,7 @@ ) from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass from executorch.exir.passes.remove_mixed_type_operators import RemoveMixedTypeOperators +from executorch.exir.passes.replace_aten_with_edge_pass import aten_to_edge from executorch.exir.passes.replace_view_copy_with_view_pass import ( ReplaceViewCopyWithViewPass, ) @@ -69,6 +71,17 @@ Val = Any +from torch.library import Library + +# This is the reserved namespace that is used to register ops to that will +# be prevented from being decomposed during to_edge_transform_and_lower. +edge_no_decomp_namespace = "EDGE_DO_NOT_DECOMP" +lib = Library(edge_no_decomp_namespace, "DEF") +# Map from aten ops to the transformed ops registered in the edge_no_decomp_namespace. +aten_op_to_transform_op = {} +# Map from the transformed ops registered in the edge_no_decomp_namespace to aten ops. +transform_op_to_aten_op = {} + def _get_updated_range_constraints(gm): def get_shape_env(gm): @@ -656,12 +669,15 @@ def edge_to_executorch_passes(config: ExecutorchBackendConfig) -> List[PassType] def _generate_edge_program( - name: str, config: EdgeCompileConfig, program: ExportedProgram + name: str, + config: EdgeCompileConfig, + program: ExportedProgram, + ops_set_to_not_decompose: Optional[List[torch._ops.OpOverload]] = None, ) -> ExportedProgram: if config._check_ir_validity: try: - EXIRATenDialectVerifier()(program.graph_module) + EXIRATenDialectVerifier(ops_set_to_not_decompose)(program.graph_module) except ExportError as e: logging.info(f"Input program {name} is not in ATen dialect.") raise e @@ -695,6 +711,7 @@ def _generate_edge_program( verifier=EXIREdgeDialectVerifier( edge_compile_config=config, class_only=True, + exception_list=ops_set_to_not_decompose, ), constants=program.constants, ) @@ -705,6 +722,269 @@ def _generate_edge_program( return edge_program +def _replace_aten_ops_with_transformed_ops( + name: str, + program: ExportedProgram, + partitioner, +): + + ops_to_not_decompose = set() + partitioners = partitioner.get(name) + if partitioners is None: + return + + # Iterate through the graph and replace the aten ops with the corresponding + # transformed ops. + for partitioner in partitioners: + ops_set_to_not_decompose, check_op_support = partitioner.ops_to_not_decompose() + + for op_aten in ops_set_to_not_decompose: + _register_no_decomp_op(op_aten) + + for node in program.graph.nodes: + is_op_supported = check_op_support(node) if check_op_support else True + if ( + node.op == "call_function" + and node.target in ops_set_to_not_decompose + and is_op_supported + ): + ops_to_not_decompose.add(node.target) + node.target = aten_op_to_transform_op[node.target] + + for _, submod, _ in get_control_flow_submodules(program.graph_module): + for node in submod.graph.nodes: + is_op_supported = check_op_support(node) if check_op_support else True + if ( + node.op == "call_function" + and node.target in ops_set_to_not_decompose + and is_op_supported + ): + ops_to_not_decompose.add(node.target) + node.target = aten_op_to_transform_op[node.target] + + return ops_to_not_decompose + + +def _restore_transformed_ops_to_aten_ops(program: ExportedProgram): + # Iterate through the graph and replace back the transformed ops with their + # corresponding aten ops. + for node in program.graph.nodes: + if node.op == "call_function" and str(node.target) in transform_op_to_aten_op: + node.target = transform_op_to_aten_op[str(node.target)] + for _, submod, _ in get_control_flow_submodules(program.graph_module): + for node in submod.graph.nodes: + if ( + node.op == "call_function" + and str(node.target) in transform_op_to_aten_op + ): + node.target = transform_op_to_aten_op[str(node.target)] + + +# Returns the op in edge_no_decomp_namespace namespace for the aten +# op that is passed in. +def _get_transformed_op(op_aten): + op_name = op_aten._schema.name.split("::")[1] + overload_name = op_aten._schema.overload_name + assert hasattr( + torch.ops, edge_no_decomp_namespace + ), f"Couldn't find {edge_no_decomp_namespace} in torch.ops. Please make sure the Library has been registered." + op_namespace = getattr(torch.ops, edge_no_decomp_namespace) + op = getattr(op_namespace, op_name) + return getattr(op, overload_name) + + +# Registers the op in edge_no_decomp_namespace namespace for the aten +# op that is passed in if it is not already cached in the table. +def _register_no_decomp_op(op_aten): + # Check if the op is already cached in the table. If not, then we need to + # create a new op in the edge_no_decomp_namespace namespace. + if aten_op_to_transform_op.get(op_aten) is None and isinstance( + op_aten, torch._ops.OpOverload + ): + # Extract the schema from the aten op. + op_schema = str(op_aten._schema).split("::")[1] + op_name = op_aten._schema.name.split("::")[1] + # Define an op in the edge_no_decomp_namespace namespace with the aten schema. + lib.define(op_schema) + # Define the implementation of the op in the edge_no_decomp_namespace namespace. + # Important to note that the implementation of the op is the same as the aten op. + lib.impl(op_name, op_aten, "CompositeExplicitAutograd") + # Cache the aten op and transformed op in their corresponding tables for future use. + aten_op_to_transform_op[op_aten] = _get_transformed_op(op_aten) + transform_op_to_aten_op[str(aten_op_to_transform_op[op_aten])] = op_aten + + +def _sanity_check_graph_for_non_decomp_ops( + name: str, + program: ExportedProgram, + ops_set_to_not_decompose, + check_op_support, + generate_error=False, + partitioner_name=None, +): + warning_str = f"Found {ops_set_to_not_decompose} in edge dialect program {name}." + if partitioner_name is not None: + warning_str += f" This op was registered by the partitioner {partitioner_name} to not be decomposed." + + # Check that the ops that were registered to not be decomposed are not present in the + # graph anymore as the transform passes and backends should have consumed them by now. + ops_set_to_not_decompose = { + aten_to_edge(op) for op in ops_set_to_not_decompose + }.union(ops_set_to_not_decompose) + for node in program.graph_module.graph.nodes: + is_op_supported = check_op_support(node) if check_op_support else True + if ( + node.op == "call_function" and node.target in ops_set_to_not_decompose + ) and is_op_supported: + if generate_error: + raise RuntimeError(warning_str) + else: + logging.warning(warning_str) + for _, submod, _ in get_control_flow_submodules(program.graph_module): + for node in submod.graph.nodes: + is_op_supported = check_op_support(node) if check_op_support else True + if ( + node.op == "call_function" and node.target in ops_set_to_not_decompose + ) and is_op_supported: + if generate_error: + raise RuntimeError(warning_str) + else: + logging.warning(warning_str) + + +def _get_ops_to_not_decompose(partitioners, ops_set_to_not_decompose_by_partitioner): + ops_set_to_not_decompose = set() + for partitioner in partitioners: + ops_set_to_not_decompose = ops_set_to_not_decompose.union( + ops_set_to_not_decompose_by_partitioner[partitioner][0] + ) + return ops_set_to_not_decompose + + +def _to_edge_transform_and_lower( + programs: Union[ExportedProgram, Dict[str, ExportedProgram]], + transform_passes: Optional[ + Union[Sequence[PassType], Dict[str, Sequence[PassType]]] + ] = None, + partitioner: Optional[ + Union[List[Partitioner], Dict[str, List[Partitioner]]] + ] = None, + constant_methods: Optional[Dict[str, Any]] = None, + compile_config: Optional[EdgeCompileConfig] = None, +) -> "EdgeProgramManager": + """ + :func:`to_edge_transform_and_lower` constructs an EdgeProgramManager from a set of + exported programs in ATen dialect. It differs fundamentally from to_edge in that it + combines the conversion of the ATen dialect to the edge dialect program, then running + the transformation passes and then subsequently lowering the programs to their + corresponding backends all in a single pass. + This is fundamentally useful for lowering to backends that have ops registered that they + do not want to be decomposed and thus rely on matching with these non-decomposed ops. For + these sorts of backends this is the *only* API that should be used to lower to the edge + dialect. Using a combination of to_edge(...) and to_backend(...) will result in inconsistent + or wrong behavior. + + Args: + programs: Can be a single ExportedProgram or a dictionary mapping function names + to their corresponding ExportedPrograms. If only a single ExportedProgram is + provided it will be assigned the name "forward". + + transform_passes: The passes can either be a list of passes, or a dictionary + mapping method names to lists of passes. If it is just a list of passes, all methods + in the given EdgeProgramManager will be transformed with the provided passes. If it + is a dictionary, only method names specified in the dictionary will be transformed + with their corresponding passes. + + partitioner: The partitioner can either be a Partitioner subclass instance, or a + dictionary mapping method names to Partitioner subclass instance. If it is a + Partitioner subclass, all programs in the given EdgeProgramManager will be lowered + using the given partitioner. If it is a dictionary, only method names specified in + the dictionary will be lowered with the given partitioner. + + constant_methods: An optional dictionary of method name to the constant value + returned by that method in eager mode. Often used to store config information on + Edge models. + + compile_config: An optional argument used to provide greater control over the + transformation to edge dialect process. + + Returns: + EdgeProgramManager + """ + ops_set_to_not_decompose = set() + + assert not isinstance(constant_methods, EdgeCompileConfig) + config = compile_config or EdgeCompileConfig() + if not isinstance(programs, dict): + aten_programs = {"forward": programs} + else: + aten_programs = programs + + if not isinstance(partitioner, dict) and partitioner is not None: + partitioner = {"forward": partitioner} + else: + partitioner = {} + + ops_set_to_not_decompose_by_program = {} + edge_programs: Dict[str, ExportedProgram] = {} + for name, program in aten_programs.items(): + if partitioner is not None: + ops_set_to_not_decompose_by_program[name] = ( + _replace_aten_ops_with_transformed_ops(name, program, partitioner) + ) + program = program.run_decompositions(_default_decomposition_table()) + + _restore_transformed_ops_to_aten_ops(program) + + edge_programs[name] = program + + edge_programs[name] = _generate_edge_program( + name, + config, + program, + list(ops_set_to_not_decompose_by_program.get(name, [])), + ) + + edge_manager = EdgeProgramManager( + edge_programs, + constant_methods, + config, + list(set().union(*ops_set_to_not_decompose_by_program.values())), + ) + + if transform_passes is not None: + edge_manager = edge_manager.transform(transform_passes) + + if partitioner is not None: + for name, partitioner_list in partitioner.items(): + for curr_partitioner in partitioner_list: + edge_manager = edge_manager.to_backend({name: curr_partitioner}) + curr_op_set, check_op_support = curr_partitioner.ops_to_not_decompose() + + for name, program in edge_manager._edge_programs.items(): + if config._check_ir_validity: + EXIREdgeDialectVerifier( + edge_compile_config=config, + class_only=True, + )()(program.graph_module) + + ops_set_to_not_decompose = set() + partitioners = partitioner.get(name, []) + for curr_partitioner in partitioners: + curr_op_set, check_op_support = curr_partitioner.ops_to_not_decompose() + ops_set_to_not_decompose = ops_set_to_not_decompose.union(curr_op_set) + _sanity_check_graph_for_non_decomp_ops( + name, + program, + ops_set_to_not_decompose, + check_op_support, + partitioner_name=curr_partitioner.__class__.__name__, + generate_error=True, + ) + + return edge_manager + + def to_edge( programs: Union[ExportedProgram, Dict[str, ExportedProgram]], constant_methods: Optional[Dict[str, Any]] = None, @@ -757,6 +1037,7 @@ def __init__( edge_programs: Union[ExportedProgram, Dict[str, ExportedProgram]], constant_methods: Optional[Dict[str, Any]] = None, compile_config: Optional[EdgeCompileConfig] = None, + ops_set_to_not_decompose: Optional[List[torch._ops.OpOverload]] = None, ): """ Should not be called directly by users. User should use :func:'to_edge' instead. @@ -768,9 +1049,10 @@ def __init__( edge_programs = {"forward": edge_programs} for name, program in edge_programs.items(): try: - EXIREdgeDialectVerifier(edge_compile_config=self.compile_config)( - program.graph_module - ) + EXIREdgeDialectVerifier( + edge_compile_config=self.compile_config, + exception_list=ops_set_to_not_decompose, + )(program.graph_module) except ExportError as e: logging.info(f"Input program {name} is not in aten dialect.") raise e diff --git a/exir/program/test/test_program.py b/exir/program/test/test_program.py index d2d37cb13b..b189e6b0f0 100644 --- a/exir/program/test/test_program.py +++ b/exir/program/test/test_program.py @@ -6,18 +6,23 @@ # pye-strict +import operator import unittest from typing import Any, Dict import torch -from executorch.exir import ExecutorchBackendConfig -from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo +from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig +from executorch.exir.backend.test.op_partitioner_demo import ( + AddMulPartitionerDemo, + NonDecompTestPartitioner, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.error import ExportError from executorch.exir.lowered_backend_module import get_lowered_submodules from executorch.exir.pass_base import ExportPass from executorch.exir.passes import MemoryPlanningPass from executorch.exir.program._program import ( + _to_edge_transform_and_lower, EdgeProgramManager, ExecutorchProgramManager, to_edge, @@ -28,6 +33,7 @@ _load_for_executorch_from_buffer, ) from torch.export import export, ExportedProgram +from torch.export._trace import _export from torch.library import impl, Library @@ -445,3 +451,125 @@ def _use_foo_add(a: torch.Tensor, b: torch.Tensor): # This should not raise error self._test_edge_dialect_verifier(_use_foo_add, False) + + def _test_model_with_non_decomp_partitioner(self, model: torch.nn.Module): + # This is the pre-dispatch export that we will be switching to primarily + # in the near future. The input to _to_edge_transform_and_lower needs to + # be a graph generated by this pre dispatch export. + ep = _export(model, model._get_random_inputs(), pre_dispatch=True) + edge = _to_edge_transform_and_lower( + ep, + compile_config=EdgeCompileConfig(), + partitioner=[NonDecompTestPartitioner()], + ) + for node in edge.exported_program().graph_module.graph.nodes: + # There should only be a single call_function node in the graph + # and that should be a call_delegate node. + if node.op == "call_function" and node.target != operator.getitem: + self.assertEqual( + node.target, torch.ops.higher_order.executorch_call_delegate + ) + + def test_to_edge_transform_and_lower(self): + class TestLinear(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(32, 16, bias=True) + + def forward(self, x): + return self.linear(x) + + @classmethod + def _get_random_inputs(cls): + x = torch.rand(8, 32) + return (x,) + + class TestSDPA(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, query, key, value): + return torch.ops.aten.scaled_dot_product_attention.default( + query, key, value + ) + + @classmethod + def _get_random_inputs(cls): + d_k = 64 + batch = 16 + seq_len = 10 + query = torch.rand(batch, seq_len, d_k) + key = torch.rand(batch, seq_len, d_k) + value = torch.rand(batch, seq_len, d_k) + return (query, key, value) + + class TestLinearSDPACombined(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(32, 16, bias=True) + + def forward(self, x, query, key, value): + x = self.linear(x) + return ( + x, + torch.ops.aten.scaled_dot_product_attention.default( + query, key, value + ), + ) + + @classmethod + def _get_random_inputs(cls): + return TestLinear._get_random_inputs() + TestSDPA._get_random_inputs() + + self._test_model_with_non_decomp_partitioner(TestLinear()) + + self._test_model_with_non_decomp_partitioner(TestSDPA()) + + self._test_model_with_non_decomp_partitioner(TestLinearSDPACombined()) + + def test_to_edge_transform_and_lower_with_exception(self): + class TestLinear(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(32, 16, bias=True) + self.linear_no_bias = torch.nn.Linear(32, 16, bias=False) + + def forward(self, x): + return (self.linear(x), self.linear_no_bias(x)) + + @classmethod + def _get_random_inputs(cls): + x = torch.rand(8, 32) + return (x,) + + model = TestLinear() + ep = _export(model, model._get_random_inputs(), pre_dispatch=True) + edge = _to_edge_transform_and_lower( + ep, + compile_config=EdgeCompileConfig(), + partitioner=[NonDecompTestPartitioner()], + ) + + def count_nodes(graph_module, target): + count = 0 + for node in graph_module.graph.nodes: + if node.op == "call_function" and node.target == target: + count += 1 + return count + + # There should be 1 call_delegate node and 1 node for aten.mm.default for the + # linear that doesn't have a bias which was decomposed as the partitioner + # said this node wasn't supported. + self.assertEqual( + count_nodes( + edge.exported_program().graph_module, + torch.ops.higher_order.executorch_call_delegate, + ), + 1, + ) + self.assertEqual( + count_nodes( + edge.exported_program().graph_module, exir_ops.edge.aten.mm.default + ), + 1, + )