diff --git a/backends/arm/_passes/annotate_channels_last_dim_order_pass.py b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py
index 222c0a7cb3..9e9d48e7f4 100644
--- a/backends/arm/_passes/annotate_channels_last_dim_order_pass.py
+++ b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py
@@ -9,19 +9,47 @@
 from typing import cast

 import torch
+from executorch.backends.arm._passes.arm_pass_utils import create_node
 from executorch.backends.arm.tosa_quant_utils import dq_op
 from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
+from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
+from torch.library import impl, Library
+
+# Define lib with passthrough operators. The operators have no real meaning in edge IR
+# except for argument validation and a passthrough output. The operators will be used
+# when lowering to TOSA, e.g. a passthrough_to_tosa._transpose will not affect
+# the edge IR graph but will be lowered to a TOSA TRANSPOSE.
+lib = Library("passthrough_to_tosa", "DEF")
+# For operators that change the rank of the input, such as unsqueeze and squeeze, we may need
+# to switch dim_order before the operation. Changing tosa_dim_order is not sufficient
+# as we also need to transpose the data into the correct data format.
+# By utilizing an edge IR passthrough operator we can keep the edge program in
+# channels-first/contiguous and get the desired behavior in the TOSA lowering.
+lib.define("_transpose(Tensor self, int[] dim_order) -> Tensor")
+
+
+@impl(lib, "_transpose")
+def _transpose_impl(*args, **kwargs):
+    # Validate length of dim_order array
+    dim = args[1]
+    assert len(dim) <= 4
+    # Pass-through in edge IR
+    return args[0]


 class AnnotateChannelsLastDimOrder(ExportPass):
     """
     Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order
-    that in most cases will be (0, 2, 3, 1) for nodes with 4D-shapes.
-    The annotated tosa_dim_order is used to permute the node's shape such that it
-    gives a TOSA-compliant shape.
+    that in most cases will be (0, 2, 3, 1) for nodes with 4D-shapes. The pass also inserts a
+    passthrough_to_tosa._transpose when a transition between 3D and 4D tensors happens.
+    The annotated tosa_dim_order is used to permute the node's shape such that it gives a TOSA-compliant shape.
     """
+
+    NHWC_order = (0, 2, 3, 1)
+    NHWC_inverse_order = (0, 3, 1, 2)
+    HWCM_order = (2, 3, 0, 1)
+
     def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
         """
         returns True for dq and w in the following sequences;
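Note: since the operator above registers a real (if trivial) kernel, its passthrough behavior can be sanity-checked in eager mode. A minimal sketch, assuming this module has been imported so the passthrough_to_tosa library is registered (exact dispatch behavior may vary by torch version):

    import torch

    x = torch.randn(1, 5, 1, 5)
    y = torch.ops.passthrough_to_tosa._transpose(x, [0, 3, 1, 2])
    # Identity in edge IR; only the dim_order length is validated here.
    assert torch.equal(x, y)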
""" + NHWC_order = (0, 2, 3, 1) + NHWC_inverse_order = (0, 3, 1, 2) + HWCM_order = (2, 3, 0, 1) + def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node): """ returns True for dq and w in the following sequences; @@ -48,9 +76,39 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node): return False + def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule): + for node in graph_module.graph.nodes: + if node.op != "call_function": + continue + if node.target == exir_ops.edge.aten.squeeze_copy.dims: + input_node = node.args[0] + if input_node.meta["val"].dim() == 4: + with graph_module.graph.inserting_before(node): + permute_node = create_node( + graph_module.graph, + torch.ops.passthrough_to_tosa._transpose, + args=(input_node, list(self.NHWC_inverse_order)), + ) + permute_node.meta["tosa_dim_order"] = tuple( + range(len(input_node.meta["val"].size())) + ) + node.replace_input_with(input_node, permute_node) + + if node.target == exir_ops.edge.aten.unsqueeze_copy.default: + if node.meta["val"].dim() == 4: + with graph_module.graph.inserting_after(node): + permute_node = create_node( + graph_module.graph, + torch.ops.passthrough_to_tosa._transpose, + args=(node, list(self.NHWC_order)), + ) + permute_node.meta["tosa_dim_order"] = self.NHWC_order + node.meta["tosa_dim_order"] = (0, 1, 2, 3) + users = [user for user in node.users if user != permute_node] + for user in users: + user.replace_input_with(node, permute_node) + def call(self, graph_module: torch.fx.GraphModule): - NHWC_Order = (0, 2, 3, 1) - HWCM_Order = (2, 3, 0, 1) for node in graph_module.graph.nodes: if isinstance( node.meta["val"], (tuple, torch.fx.immutable_collections.immutable_list) @@ -59,14 +117,20 @@ def call(self, graph_module: torch.fx.GraphModule): else: node_data = node.meta["val"].data - if len(node_data.shape) == 4: - dim_order = NHWC_Order + if node_data.dim() == 4: + dim_order = self.NHWC_order if self.is_weight_node_for_depthwise_conv2d(node): # The weights of TOSA DEPTHWISE_CONV2D have shape (H, W, C, M) which corresponds to # dim_order = (2, 3, 0, 1) (https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d). - dim_order = HWCM_Order + dim_order = self.HWCM_order else: dim_order = tuple(range(node_data.dim())) node.meta["tosa_dim_order"] = dim_order + # Take care of cases when: + # 4D (NHWC) -> >4D (NCH) + # 3D (NCH) -> 4D (NHWC) + self.insert_tosa_transposes(graph_module) graph_module.recompile() + graph_module = super().call(graph_module).graph_module + return PassResult(graph_module, True) diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index 855487cf7f..b79a0b645c 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -35,6 +35,7 @@ op_squeeze, op_sub, op_sum, + op_transpose, op_unsqueeze, op_view, ) diff --git a/backends/arm/operators/op_transpose.py b/backends/arm/operators/op_transpose.py new file mode 100644 index 0000000000..b427b3f5bf --- /dev/null +++ b/backends/arm/operators/op_transpose.py @@ -0,0 +1,42 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
diff --git a/backends/arm/operators/op_transpose.py b/backends/arm/operators/op_transpose.py
new file mode 100644
index 0000000000..b427b3f5bf
--- /dev/null
+++ b/backends/arm/operators/op_transpose.py
@@ -0,0 +1,42 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List
+
+import serializer.tosa_serializer as ts
+import torch
+from executorch.backends.arm.operators.node_visitor import (
+    NodeVisitor,
+    register_node_visitor,
+)
+from executorch.backends.arm.tosa_mapping import TosaArg
+from serializer.tosa_serializer import TosaOp
+
+
+@register_node_visitor
+class TransposeVisitor(NodeVisitor):
+    """
+    This node visitor targets the _transpose op defined in the
+    passthrough_to_tosa library. Used when switching between tosa_dim_orders.
+    Inserts a TOSA TRANSPOSE.
+    """
+
+    target = "_transpose"
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        tosa_graph: ts.TosaSerializer,
+        inputs: List[TosaArg],
+        output: TosaArg,
+        is_quant_node: bool,
+    ) -> None:
+        output_rank = len(output.shape)
+        perms = [dim % output_rank for dim in inputs[1].special]
+        attr = ts.TosaSerializerAttribute()
+        attr.TransposeAttribute(perms)
+        tosa_graph.addOperator(
+            TosaOp.Op().TRANSPOSE, [inputs[0].name], [output.name], attr
+        )
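One detail in define_node above: taking dim % output_rank maps any negative dims in the incoming dim_order to the non-negative indices the TOSA TRANSPOSE attribute expects. For example:

    output_rank = 4
    dim_order = [0, -1, 1, 2]                        # negative dim allowed in edge IR
    perms = [dim % output_rank for dim in dim_order]
    assert perms == [0, 3, 1, 2]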
diff --git a/backends/arm/test/ops/test_expand.py b/backends/arm/test/ops/test_expand.py
index e9bbea9a5e..aa13a6475c 100644
--- a/backends/arm/test/ops/test_expand.py
+++ b/backends/arm/test/ops/test_expand.py
@@ -21,6 +21,7 @@
 from executorch.backends.arm.test.tester.arm_tester import ArmTester
 from executorch.backends.xnnpack.test.tester.tester import Quantize
+from executorch.exir.backend.backend_details import CompileSpec
 from parameterized import parameterized

@@ -77,14 +78,14 @@ def _test_expand_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tupl
         )

     def _test_expand_ethosu_BI_pipeline(
-        self, module: torch.nn.Module, test_data: Tuple
+        self, compile_spec: CompileSpec, module: torch.nn.Module, test_data: Tuple
     ):
         quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
         (
             ArmTester(
                 module,
                 example_inputs=test_data,
-                compile_spec=common.get_u55_compile_spec(),
+                compile_spec=compile_spec,
             )
             .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
             .export()
@@ -104,17 +105,14 @@ def test_expand_tosa_MI(self, test_input, multiples):
     def test_expand_tosa_BI(self, test_input, multiples):
         self._test_expand_tosa_BI_pipeline(self.Expand(), (test_input, multiples))

-    # Expected failure since tosa.TILE is unsupported by Vela.
     @parameterized.expand(Expand.test_parameters)
-    @unittest.expectedFailure  # TODO: MLBEDSW-9386
     def test_expand_u55_BI(self, test_input, multiples):
         self._test_expand_ethosu_BI_pipeline(
-            self.Expand(), common.get_u55_compile_spec(), (test_input, multiples)
+            common.get_u55_compile_spec(), self.Expand(), (test_input, multiples)
         )

     @parameterized.expand(Expand.test_parameters)
-    @unittest.expectedFailure  # TODO: MLBEDSW-9386
     def test_expand_u85_BI(self, test_input, multiples):
         self._test_expand_ethosu_BI_pipeline(
-            self.Expand(), common.get_u85_compile_spec(), (test_input, multiples)
+            common.get_u85_compile_spec(), self.Expand(), (test_input, multiples)
         )
diff --git a/backends/arm/test/ops/test_repeat.py b/backends/arm/test/ops/test_repeat.py
index 542f0d6256..1efac9f974 100644
--- a/backends/arm/test/ops/test_repeat.py
+++ b/backends/arm/test/ops/test_repeat.py
@@ -107,14 +107,12 @@ def test_repeat_tosa_BI(self, test_input, multiples):
         self._test_repeat_tosa_BI_pipeline(self.Repeat(), (test_input, multiples))

     @parameterized.expand(Repeat.test_parameters)
-    @unittest.expectedFailure  # TODO: MLBEDSW-9386
     def test_repeat_u55_BI(self, test_input, multiples):
         self._test_repeat_ethosu_pipeline(
             common.get_u55_compile_spec(), self.Repeat(), (test_input, multiples)
         )

     @parameterized.expand(Repeat.test_parameters)
-    @unittest.expectedFailure  # TODO: MLBEDSW-9386
     def test_repeat_u85_BI(self, test_input, multiples):
         self._test_repeat_ethosu_pipeline(
             common.get_u85_compile_spec(), self.Repeat(), (test_input, multiples)
diff --git a/backends/arm/test/ops/test_squeeze.py b/backends/arm/test/ops/test_squeeze.py
index 4fe420708a..790769ca71 100644
--- a/backends/arm/test/ops/test_squeeze.py
+++ b/backends/arm/test/ops/test_squeeze.py
@@ -38,6 +38,7 @@ def forward(self, x: torch.Tensor, dim: int):

 class SqueezeDims(torch.nn.Module):
     test_parameters: list[tuple[torch.Tensor, tuple[int]]] = [
+        (torch.randn(1, 1, 5), (0, 1)),
         (torch.randn(1, 5, 5, 1), (0, -1)),
         (torch.randn(1, 5, 1, 5), (0, -2)),
     ]
@@ -47,6 +48,7 @@ def forward(self, x: torch.Tensor, dims: tuple[int]):

 class Squeeze(torch.nn.Module):
     test_parameters: list[tuple[torch.Tensor]] = [
+        (torch.randn(1, 1, 5),),
         (torch.randn(1, 5, 5, 1),),
         (torch.randn(1, 5, 1, 5),),
     ]
@@ -64,7 +66,7 @@ def _test_squeeze_tosa_MI_pipeline(
             ArmTester(
                 module,
                 example_inputs=test_data,
-                compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False),
+                compile_spec=common.get_tosa_compile_spec(),
             )
             .export()
             .check_count({export_target: 1})
@@ -86,7 +88,7 @@ def _test_squeeze_tosa_BI_pipeline(
             ArmTester(
                 module,
                 example_inputs=test_data,
-                compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False),
+                compile_spec=common.get_tosa_compile_spec(),
             )
             .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
             .export()
@@ -184,7 +186,7 @@ def test_squeeze_dim_u55_BI(self, test_tensor: torch.Tensor, dim: int):
     @parameterized.expand(SqueezeDim.test_parameters)
     def test_squeeze_dim_u85_BI(self, test_tensor: torch.Tensor, dim: int):
         self._test_squeeze_ethosu_BI_pipeline(
-            common.get_u85_compile_spec(permute_memory_to_nhwc=False),
+            common.get_u85_compile_spec(permute_memory_to_nhwc=True),
             self.SqueezeDim(),
             (test_tensor, dim),
             "torch.ops.aten.squeeze.dim",
@@ -214,7 +216,7 @@ def test_squeeze_dims_u55_BI(self, test_tensor: torch.Tensor, dims: tuple[int]):
     @parameterized.expand(SqueezeDims.test_parameters)
     def test_squeeze_dims_u85_BI(self, test_tensor: torch.Tensor, dims: tuple[int]):
         self._test_squeeze_ethosu_BI_pipeline(
-            common.get_u85_compile_spec(permute_memory_to_nhwc=False),
+            common.get_u85_compile_spec(),
             self.SqueezeDims(),
             (test_tensor, dims),
             "torch.ops.aten.squeeze.dims",
diff --git a/backends/arm/test/ops/test_unsqueeze.py b/backends/arm/test/ops/test_unsqueeze.py
index 8431efa271..1cc597c066 100644
--- a/backends/arm/test/ops/test_unsqueeze.py
+++ b/backends/arm/test/ops/test_unsqueeze.py
@@ -27,7 +27,7 @@ class TestSimpleUnsqueeze(unittest.TestCase):
     class Unsqueeze(torch.nn.Module):
-        shapes: list[int | Sequence[int]] = [5, (5, 5), (5, 5), (5, 4, 3)]
+        shapes: list[int | Sequence[int]] = [5, (5, 5), (5, 4), (5, 4, 3)]
         test_parameters: list[tuple[torch.Tensor]] = [(torch.randn(n),) for n in shapes]

         def forward(self, x: torch.Tensor, dim):
@@ -40,7 +40,7 @@ def _test_unsqueeze_tosa_MI_pipeline(
             ArmTester(
                 module,
                 example_inputs=test_data,
-                compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False),
+                compile_spec=common.get_tosa_compile_spec(),
             )
             .export()
             .check_count({"torch.ops.aten.unsqueeze.default": 1})
@@ -59,7 +59,7 @@ def _test_unsqueeze_tosa_BI_pipeline(
             ArmTester(
                 module,
                 example_inputs=test_data,
-                compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False),
+                compile_spec=common.get_tosa_compile_spec(),
             )
             .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
             .export()
@@ -102,10 +102,10 @@ def test_unsqueeze_tosa_MI(self, test_tensor: torch.Tensor):
     def test_unsqueeze_tosa_BI(self, test_tensor: torch.Tensor):
         self._test_unsqueeze_tosa_BI_pipeline(self.Unsqueeze(), (test_tensor, 0))

-    @parameterized.expand(Unsqueeze.test_parameters)
+    @parameterized.expand(Unsqueeze.test_parameters[:-1])
     def test_unsqueeze_u55_BI(self, test_tensor: torch.Tensor):
         self._test_unsqueeze_ethosu_BI_pipeline(
-            common.get_u55_compile_spec(permute_memory_to_nhwc=False),
+            common.get_u55_compile_spec(),
             self.Unsqueeze(),
             (test_tensor, 0),
         )
@@ -113,7 +113,7 @@ def test_unsqueeze_u55_BI(self, test_tensor: torch.Tensor):
     @parameterized.expand(Unsqueeze.test_parameters)
     def test_unsqueeze_u85_BI(self, test_tensor: torch.Tensor):
         self._test_unsqueeze_ethosu_BI_pipeline(
-            common.get_u85_compile_spec(permute_memory_to_nhwc=False),
+            common.get_u85_compile_spec(),
             self.Unsqueeze(),
             (test_tensor, 0),
         )
diff --git a/examples/arm/setup.sh b/examples/arm/setup.sh
index ae335208cd..583237729d 100755
--- a/examples/arm/setup.sh
+++ b/examples/arm/setup.sh
@@ -261,7 +261,7 @@ function setup_vela() {
     if [[ ! -e ethos-u-vela ]]; then
         git clone https://review.mlplatform.org/ml/ethos-u/ethos-u-vela
         repo_dir="${root_dir}/ethos-u-vela"
-        base_rev=fe0eaa55c5ed319f78c01978f3b40eb11a9bcb38
+        base_rev=57ce18c89ccc6f6309333dccb24ed30dc68b571f
         patch_repo
     fi
     cd "${root_dir}/ethos-u-vela"
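The new (1, 1, 5) parameters in test_squeeze.py cover rank-3 inputs, which keep the identity dim_order (0, 1, 2) under the pass; only 4D tensors get the NHWC treatment. A quick eager-mode check of what those cases compute (assuming a torch version where squeeze accepts a tuple of dims, as the SqueezeDims module requires):

    import torch

    x = torch.randn(1, 1, 5)
    assert x.squeeze().shape == (5,)        # Squeeze test case
    assert x.squeeze((0, 1)).shape == (5,)  # SqueezeDims test case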