to_edge_transform_and_lower (pytorch#3483)
Summary:
Pull Request resolved: pytorch#3483

This diff introduces the to_edge_transform_and_lower API (a usage sketch follows this list). The changes introduced are:
- Added support to the Partitioner class for registering ops that it does not want decomposed
- Changed _program.py to add the implementation of to_edge_transform_and_lower()
- Added a basic test case checking that Linear, SDPA, and Linear + SDPA are not decomposed when requested, and that the corresponding backend consumes them.
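
A rough usage sketch of the new flow (not part of this commit; both imported names appear in this diff, but the partitioner keyword-argument name is an assumption, since the _program.py hunk did not render on this page):

import torch

from executorch.exir.backend.test.op_partitioner_demo import NonDecompTestPartitioner
from executorch.exir.program import _to_edge_transform_and_lower


class LinearModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # A biased linear, which the demo partitioner's filter keeps un-decomposed.
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


ep = torch.export.export(LinearModel(), (torch.randn(2, 8),))
# aten.linear is returned by NonDecompTestPartitioner.ops_to_not_decompose(), so the
# program handed to BackendWithCompilerDemo still contains it un-decomposed.
edge_manager = _to_edge_transform_and_lower(ep, partitioner=[NonDecompTestPartitioner()])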

Reviewed By: kimishpatel, mcr229

Differential Revision: D56401086

fbshipit-source-id: 04262a58fc70e8191df33b4342295e56a5baf354
tarun292 authored and facebook-github-bot committed May 28, 2024
1 parent 636c5c3 commit 2b91eba
Showing 7 changed files with 525 additions and 15 deletions.
22 changes: 21 additions & 1 deletion exir/backend/partitioner.py
@@ -7,7 +7,9 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from types import MappingProxyType
-from typing import Dict, List, Mapping, NamedTuple, Union
+from typing import Callable, Dict, List, Mapping, NamedTuple, Optional, Tuple, Union
+
+import torch
 
 from executorch.exir.backend.backend_details import enforcedmethod
 from executorch.exir.backend.compile_spec_schema import CompileSpec
@@ -91,3 +93,21 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
             PartitionResult: includes the tagged graph and the delegation spec to indicate what backend_id and compile_spec is used for each node and the tag created by the backend developers.
         """
         pass

+    def ops_to_not_decompose(
+        self,
+    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
+        """
+        Returns a list of operators that should not be decomposed. When these ops are
+        registered and `to_backend` is invoked through to_edge_transform_and_lower, the
+        program that the backend receives is guaranteed not to have any of these ops
+        decomposed.
+
+        Returns:
+            List[torch._ops.OpOverload]: a list of operator overloads that should not be decomposed.
+            Optional[Callable[[torch.fx.Node], bool]]: an optional callable that acts as a
+                filter. It is called for each node whose target is in the list above;
+                returning False lets that node be decomposed anyway, even though its op
+                appears in the list returned by ops_to_not_decompose.
+        """
+        return ([], None)
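
A minimal override, sketched for illustration (PreserveLinearPartitioner is a hypothetical name; the NonDecompTestPartitioner added later in this diff is the full worked example):

import torch
from torch.export import ExportedProgram

from executorch.exir.backend.partitioner import Partitioner, PartitionResult


class PreserveLinearPartitioner(Partitioner):
    # Ask for aten.linear to be kept whole; no per-node filter (second element None).
    def ops_to_not_decompose(self):
        return ([torch.ops.aten.linear.default], None)

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        raise NotImplementedError  # node tagging omitted in this sketch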
16 changes: 10 additions & 6 deletions exir/backend/test/backend_with_compiler_demo.py
@@ -83,15 +83,19 @@ def preprocess(
         processed_bytes = ""
         number_of_instruction = 0
         debug_handle_map = {}
+        match_ops = [
+            exir_ops.edge.aten.sin.default,
+            exir_ops.edge.aten.mm.default,
+            exir_ops.edge.aten.add.Tensor,
+            torch.ops.aten.sin.default,
+            exir_ops.edge.aten.linear.default,
+            exir_ops.edge.aten.scaled_dot_product_attention.default,
+        ]
+
         for node in edge_program.graph.nodes:
             if node.op == "call_function":
                 # TODO(gasoonjia): remove the support of torch.ops.aten.sin.default after migrate serde to edge dialect.
-                if (
-                    node.target == exir_ops.edge.aten.sin.default
-                    or node.target == exir_ops.edge.aten.mm.default
-                    or node.target == exir_ops.edge.aten.add.Tensor
-                    or node.target == torch.ops.aten.sin.default
-                ):
+                if node.target in match_ops:
                     simple_op = DemoOp(
                         node.target.__name__,
                         int(torch.prod(torch.tensor(node.meta["val"].shape), 0).item()),
74 changes: 73 additions & 1 deletion exir/backend/test/op_partitioner_demo.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Dict, final
+from typing import Callable, Dict, final, List, Optional, Tuple
 
 import torch
 from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
@@ -71,6 +71,7 @@ def _partition_graph_module(
         for _, submodule, _ in get_control_flow_submodules(graph_module):
             ret_partition_tags = self._partition_graph_module(submodule)
             partition_tags.update(ret_partition_tags)
+
         return partition_tags
 
     def partition(self, exported_program: ExportedProgram) -> PartitionResult:
@@ -121,3 +121,74 @@ def partition(self, edge_exported_program: ExportedProgram) -> PartitionResult:
         return PartitionResult(
             tagged_exported_program=edge_exported_program, partition_tags=partition_tags
         )
+
+
+ops_not_to_decompose = [
+    torch.ops.aten.linear.default,
+    torch.ops.aten.scaled_dot_product_attention.default,
+]
+
+edge_ops_non_decomposed = [
+    exir_ops.edge.aten.linear.default,
+    exir_ops.edge.aten.scaled_dot_product_attention.default,
+]
+
+
+class OpsToNotDecomposeOperatorSupport(OperatorSupportBase):
+    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
+        return node.op == "call_function" and node.target in edge_ops_non_decomposed
+
+
+@final
+class NonDecompTestPartitioner(Partitioner):
+    """
+    Partitions the linear and scaled_dot_product_attention nodes that were kept
+    un-decomposed via ops_to_not_decompose.
+    """
+
+    def __init__(self) -> None:
+        self.op_support = any_chain(OpsToNotDecomposeOperatorSupport())
+        self.delegation_spec = DelegationSpec(
+            BackendWithCompilerDemo.__name__,
+            [CompileSpec("max_value", bytes([4]))],
+        )
+
+    def ops_to_not_decompose(
+        self,
+    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
+        def filter_ops(node: torch.fx.Node) -> bool:
+            if node.op == "call_function" and node.target in ops_not_to_decompose:
+                if len(node.args) == 3:
+                    # Three args means this linear has a bias, which is the only
+                    # kind of linear this demo partitioner supports.
+                    return True
+                else:
+                    return False
+
+            return True
+
+        return (ops_not_to_decompose, filter_ops)
+
+    def _partition_graph_module(
+        self,
+        graph_module: torch.fx.GraphModule,
+    ) -> Dict[str, DelegationSpec]:
+        partition_tags: Dict[str, DelegationSpec] = {}
+        partition_list = generate_pattern_op_partitions(
+            graph_module, op_support=self.op_support
+        )
+        for partition in partition_list:
+            for node in partition.nodes:
+                delegation_tag = f"tag{partition.id}"
+                node.meta["delegation_tag"] = delegation_tag
+                partition_tags[delegation_tag] = self.delegation_spec
+
+        for _, submodule, _ in get_control_flow_submodules(graph_module):
+            ret_partition_tags = self._partition_graph_module(submodule)
+            partition_tags.update(ret_partition_tags)
+        return partition_tags
+
+    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
+        partition_tags = self._partition_graph_module(exported_program.graph_module)
+        return PartitionResult(
+            tagged_exported_program=exported_program, partition_tags=partition_tags
+        )
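
To make the filter's arity check concrete, an illustrative snippet (not from the diff): aten.linear takes an optional bias as its third argument, so a biased call yields a three-argument node and a bias-free call a two-argument one.

import torch

x, w, b = torch.randn(2, 8), torch.randn(8, 8), torch.randn(8)
# Called with 3 args: the corresponding graph node passes len(node.args) == 3,
# so filter_ops returns True and the op stays un-decomposed.
with_bias = torch.ops.aten.linear.default(x, w, b)
# Called with 2 args: its node fails the check, filter_ops returns False, and
# the op is decomposed as usual.
without_bias = torch.ops.aten.linear.default(x, w)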
2 changes: 2 additions & 0 deletions exir/program/TARGETS
@@ -21,6 +21,7 @@ python_library(
     deps = [
         "//caffe2:torch",
         "//executorch/exir:error",
+        "//executorch/exir:graph_module",
         "//executorch/exir:pass_manager",
         "//executorch/exir:print_program",
         "//executorch/exir:schema",
@@ -36,6 +37,7 @@ python_library(
         "//executorch/exir/passes:normalize_view_copy_base_pass",
         "//executorch/exir/passes:remove_graph_asserts_pass",
         "//executorch/exir/passes:remove_mixed_type_operators",
+        "//executorch/exir/passes:replace_aten_with_edge_pass",
         "//executorch/exir/passes:replace_view_copy_with_view_pass",
         "//executorch/exir/passes:spec_prop_pass",
         "//executorch/exir/verification:verifier",
2 changes: 2 additions & 0 deletions exir/program/__init__.py
@@ -9,6 +9,7 @@
 from executorch.exir.program._fake_program import get_fake_program
 from executorch.exir.program._program import (
     _to_edge,
+    _to_edge_transform_and_lower,
     edge_to_executorch_passes,
     EdgeProgramManager,
     ExecutorchProgram,
@@ -22,6 +23,7 @@
     "ExecutorchProgram",
     "_to_edge",
     "to_edge",
+    "_to_edge_transform_and_lower",
     "edge_to_executorch_passes",
     "EdgeProgramManager",
     "ExecutorchProgramManager",
(Diffs for the remaining two changed files, exir/program/_program.py and the new test case, did not load on this page.)
