Commit f418036

Update

[ghstack-poisoned]

Erik-Lundell committed Oct 17, 2024
1 parent 1c298c7 commit f418036
Showing 8 changed files with 133 additions and 28 deletions.
80 changes: 72 additions & 8 deletions backends/arm/_passes/annotate_channels_last_dim_order_pass.py
@@ -9,19 +9,47 @@
from typing import cast

import torch
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.backends.arm.tosa_quant_utils import dq_op
from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch.library import impl, Library

# Define lib with passthrough operators. The operators have no real meaning in edge IR
# except for argument validation and a passthrough output. The operators will be used
# when lowering to TOSA, e.g. a passthrough_to_tosa._transpose will not affect
# the edge IR graph but will be lowered to a TOSA-TRANSPOSE.
lib = Library("passthrough_to_tosa", "DEF")
# For operators that change the rank of the input, such as unsqueeze and squeeze, we may need
# to switch dim_order before the operation. Changing tosa_dim_order is not sufficient,
# as we also need to transpose the data into the correct data format.
# By utilizing an edge IR passthrough operator we can keep the edge program in
# channels-first/contiguous and get the desired behavior in the TOSA lowering.
lib.define("_transpose(Tensor self, int[] dim_order) -> Tensor")


@impl(lib, "_transpose")
def _transpose_impl(*args, **kwargs):
    # Validate length of dim_order array
    dim = args[1]
    assert len(dim) <= 4
    # Pass-through in edge-IR
    return args[0]
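
For illustration (a sketch, not part of the diff): once this module has been imported so the library registration above has run, the op is callable in eager mode and behaves as an identity.

import torch
from executorch.backends.arm._passes import annotate_channels_last_dim_order_pass  # noqa: F401, registers the op

x = torch.randn(1, 3, 8, 8)
y = torch.ops.passthrough_to_tosa._transpose(x, [0, 3, 1, 2])
assert torch.equal(x, y)  # pass-through in edge IR; only the TOSA lowering transposes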


class AnnotateChannelsLastDimOrder(ExportPass):
"""
Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order
that in most cases will be (0, 2, 3, 1) for nodes with 4D-shapes.
The annotated tosa_dim_order is used to permute the node's shape such that it
gives a TOSA-compliant shape.
that in most cases will be (0, 2, 3, 1) for nodes with 4D-shapes. The pass also inserts passthrough_to_tosa._transpose
when a transition between 3D and 4D tensors happen.
The annotated tosa_dim_order is used to permute the node's shape such that it gives a TOSA-compliant shape.
"""

    NHWC_order = (0, 2, 3, 1)
    NHWC_inverse_order = (0, 3, 1, 2)
    HWCM_order = (2, 3, 0, 1)
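
As an aside, a small sketch (plain Python, illustrative values) of how these orders act on shapes: NHWC_order permutes a contiguous NCHW shape into the TOSA NHWC shape, and NHWC_inverse_order restores it.

shape_nchw = (1, 8, 16, 16)  # (N, C, H, W)
shape_nhwc = tuple(shape_nchw[i] for i in (0, 2, 3, 1))  # (1, 16, 16, 8)
assert tuple(shape_nhwc[i] for i in (0, 3, 1, 2)) == shape_nchw  # inverse restores NCHW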

    def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
        """
        Returns True for dq and w in the following sequences:
@@ -48,9 +76,39 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):

        return False

    def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
        for node in graph_module.graph.nodes:
            if node.op != "call_function":
                continue
            if node.target == exir_ops.edge.aten.squeeze_copy.dims:
                input_node = node.args[0]
                if input_node.meta["val"].dim() == 4:
                    with graph_module.graph.inserting_before(node):
                        permute_node = create_node(
                            graph_module.graph,
                            torch.ops.passthrough_to_tosa._transpose,
                            args=(input_node, list(self.NHWC_inverse_order)),
                        )
                        permute_node.meta["tosa_dim_order"] = tuple(
                            range(len(input_node.meta["val"].size()))
                        )
                        node.replace_input_with(input_node, permute_node)

            if node.target == exir_ops.edge.aten.unsqueeze_copy.default:
                if node.meta["val"].dim() == 4:
                    with graph_module.graph.inserting_after(node):
                        permute_node = create_node(
                            graph_module.graph,
                            torch.ops.passthrough_to_tosa._transpose,
                            args=(node, list(self.NHWC_order)),
                        )
                        permute_node.meta["tosa_dim_order"] = self.NHWC_order
                        node.meta["tosa_dim_order"] = (0, 1, 2, 3)
                        users = [user for user in node.users if user != permute_node]
                        for user in users:
                            user.replace_input_with(node, permute_node)
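
For readers unfamiliar with FX graph surgery, a self-contained sketch of the insert-and-rewire pattern used above (plain torch.fx, with torch.permute standing in for the passthrough op):

import torch
import torch.fx

def f(x):
    return torch.squeeze(x)

gm = torch.fx.symbolic_trace(f)
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is torch.squeeze:
        input_node = node.args[0]
        with gm.graph.inserting_before(node):
            permute_node = gm.graph.call_function(
                torch.permute, args=(input_node, (0, 3, 1, 2))
            )
        node.replace_input_with(input_node, permute_node)
gm.recompile()
print(gm(torch.randn(1, 4, 4, 1)).shape)  # torch.Size([4, 4])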

    def call(self, graph_module: torch.fx.GraphModule):
-        NHWC_Order = (0, 2, 3, 1)
-        HWCM_Order = (2, 3, 0, 1)
        for node in graph_module.graph.nodes:
            if isinstance(
                node.meta["val"], (tuple, torch.fx.immutable_collections.immutable_list)
@@ -59,14 +117,20 @@ def call(self, graph_module: torch.fx.GraphModule):
            else:
                node_data = node.meta["val"].data

-            if len(node_data.shape) == 4:
-                dim_order = NHWC_Order
+            if node_data.dim() == 4:
+                dim_order = self.NHWC_order
                if self.is_weight_node_for_depthwise_conv2d(node):
                    # The weights of TOSA DEPTHWISE_CONV2D have shape (H, W, C, M) which corresponds to
                    # dim_order = (2, 3, 0, 1) (https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d).
-                    dim_order = HWCM_Order
+                    dim_order = self.HWCM_order
            else:
                dim_order = tuple(range(node_data.dim()))
            node.meta["tosa_dim_order"] = dim_order
        # Take care of cases when:
        #   4D (NHWC) -> 3D (NCH)
        #   3D (NCH) -> 4D (NHWC)
        self.insert_tosa_transposes(graph_module)
        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)
1 change: 1 addition & 0 deletions backends/arm/operators/__init__.py
@@ -35,6 +35,7 @@
    op_squeeze,
    op_sub,
    op_sum,
+    op_transpose,
    op_unsqueeze,
    op_view,
)
42 changes: 42 additions & 0 deletions backends/arm/operators/op_transpose.py
@@ -0,0 +1,42 @@
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List

import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from serializer.tosa_serializer import TosaOp


@register_node_visitor
class TransposeVisitor(NodeVisitor):
"""
This node visitor targets the _transpose op defined in the
passthrough_to_tosa library. Used when switching between tosa_dim_orders.
Inserts a TOSA TRANSPOSE.
"""

target = "_transpose"

def define_node(
self,
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
inputs: List[TosaArg],
output: TosaArg,
is_quant_node: bool,
) -> None:
output_rank = len(output.shape)
perms = [dim % output_rank for dim in inputs[1].special]
attr = ts.TosaSerializerAttribute()
attr.TransposeAttribute(perms)
tosa_graph.addOperator(
TosaOp.Op().TRANSPOSE, [inputs[0].name], [output.name], attr
)
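
Here inputs[1].special carries the dim_order values from the op's schema, and the modulo maps any negative dims onto non-negative TOSA axes. A tiny sketch with hypothetical values:

output_rank = 4
dim_order = [0, -1, 1, 2]  # hypothetical dim_order containing a negative dim
perms = [dim % output_rank for dim in dim_order]
assert perms == [0, 3, 1, 2]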
12 changes: 5 additions & 7 deletions backends/arm/test/ops/test_expand.py
@@ -21,6 +21,7 @@
from executorch.backends.arm.test.tester.arm_tester import ArmTester

from executorch.backends.xnnpack.test.tester.tester import Quantize
+from executorch.exir.backend.backend_details import CompileSpec
from parameterized import parameterized


@@ -77,14 +78,14 @@ def _test_expand_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tupl
        )

    def _test_expand_ethosu_BI_pipeline(
-        self, module: torch.nn.Module, test_data: Tuple
+        self, compile_spec: CompileSpec, module: torch.nn.Module, test_data: Tuple
    ):
        quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
        (
            ArmTester(
                module,
                example_inputs=test_data,
-                compile_spec=common.get_u55_compile_spec(),
+                compile_spec=compile_spec,
            )
            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
            .export()
@@ -104,17 +105,14 @@ def test_expand_tosa_MI(self, test_input, multiples):
    def test_expand_tosa_BI(self, test_input, multiples):
        self._test_expand_tosa_BI_pipeline(self.Expand(), (test_input, multiples))

    # Expected failure since tosa.TILE is unsupported by Vela.
    @parameterized.expand(Expand.test_parameters)
    @unittest.expectedFailure  # TODO: MLBEDSW-9386
    def test_expand_u55_BI(self, test_input, multiples):
        self._test_expand_ethosu_BI_pipeline(
-            self.Expand(), common.get_u55_compile_spec(), (test_input, multiples)
+            common.get_u55_compile_spec(), self.Expand(), (test_input, multiples)
        )

    @parameterized.expand(Expand.test_parameters)
    @unittest.expectedFailure  # TODO: MLBEDSW-9386
    def test_expand_u85_BI(self, test_input, multiples):
        self._test_expand_ethosu_BI_pipeline(
-            self.Expand(), common.get_u85_compile_spec(), (test_input, multiples)
+            common.get_u85_compile_spec(), self.Expand(), (test_input, multiples)
        )
2 changes: 0 additions & 2 deletions backends/arm/test/ops/test_repeat.py
@@ -107,14 +107,12 @@ def test_repeat_tosa_BI(self, test_input, multiples):
        self._test_repeat_tosa_BI_pipeline(self.Repeat(), (test_input, multiples))

    @parameterized.expand(Repeat.test_parameters)
-    @unittest.expectedFailure  # TODO: MLBEDSW-9386
    def test_repeat_u55_BI(self, test_input, multiples):
        self._test_repeat_ethosu_pipeline(
            common.get_u55_compile_spec(), self.Repeat(), (test_input, multiples)
        )

    @parameterized.expand(Repeat.test_parameters)
-    @unittest.expectedFailure  # TODO: MLBEDSW-9386
    def test_repeat_u85_BI(self, test_input, multiples):
        self._test_repeat_ethosu_pipeline(
            common.get_u85_compile_spec(), self.Repeat(), (test_input, multiples)
10 changes: 6 additions & 4 deletions backends/arm/test/ops/test_squeeze.py
@@ -38,6 +38,7 @@ def forward(self, x: torch.Tensor, dim: int):

    class SqueezeDims(torch.nn.Module):
        test_parameters: list[tuple[torch.Tensor, tuple[int]]] = [
+            (torch.randn(1, 1, 5), (0, 1)),
            (torch.randn(1, 5, 5, 1), (0, -1)),
            (torch.randn(1, 5, 1, 5), (0, -2)),
        ]
@@ -47,6 +48,7 @@ def forward(self, x: torch.Tensor, dims: tuple[int]):

    class Squeeze(torch.nn.Module):
        test_parameters: list[tuple[torch.Tensor]] = [
+            (torch.randn(1, 1, 5),),
            (torch.randn(1, 5, 5, 1),),
            (torch.randn(1, 5, 1, 5),),
        ]
@@ -64,7 +66,7 @@ def _test_squeeze_tosa_MI_pipeline(
            ArmTester(
                module,
                example_inputs=test_data,
-                compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False),
+                compile_spec=common.get_tosa_compile_spec(),
            )
            .export()
            .check_count({export_target: 1})
@@ -86,7 +88,7 @@ def _test_squeeze_tosa_BI_pipeline(
            ArmTester(
                module,
                example_inputs=test_data,
-                compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False),
+                compile_spec=common.get_tosa_compile_spec(),
            )
            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
            .export()
@@ -184,7 +186,7 @@ def test_squeeze_dim_u55_BI(self, test_tensor: torch.Tensor, dim: int):
    @parameterized.expand(SqueezeDim.test_parameters)
    def test_squeeze_dim_u85_BI(self, test_tensor: torch.Tensor, dim: int):
        self._test_squeeze_ethosu_BI_pipeline(
-            common.get_u85_compile_spec(permute_memory_to_nhwc=False),
+            common.get_u85_compile_spec(permute_memory_to_nhwc=True),
            self.SqueezeDim(),
            (test_tensor, dim),
            "torch.ops.aten.squeeze.dim",
@@ -214,7 +216,7 @@ def test_squeeze_dims_u55_BI(self, test_tensor: torch.Tensor, dims: tuple[int]):
    @parameterized.expand(SqueezeDims.test_parameters)
    def test_squeeze_dims_u85_BI(self, test_tensor: torch.Tensor, dims: tuple[int]):
        self._test_squeeze_ethosu_BI_pipeline(
-            common.get_u85_compile_spec(permute_memory_to_nhwc=False),
+            common.get_u85_compile_spec(),
            self.SqueezeDims(),
            (test_tensor, dims),
            "torch.ops.aten.squeeze.dims",
12 changes: 6 additions & 6 deletions backends/arm/test/ops/test_unsqueeze.py
@@ -27,7 +27,7 @@

class TestSimpleUnsqueeze(unittest.TestCase):
    class Unsqueeze(torch.nn.Module):
-        shapes: list[int | Sequence[int]] = [5, (5, 5), (5, 5), (5, 4, 3)]
+        shapes: list[int | Sequence[int]] = [5, (5, 5), (5, 4), (5, 4, 3)]
        test_parameters: list[tuple[torch.Tensor]] = [(torch.randn(n),) for n in shapes]

        def forward(self, x: torch.Tensor, dim):
@@ -40,7 +40,7 @@ def _test_unsqueeze_tosa_MI_pipeline(
            ArmTester(
                module,
                example_inputs=test_data,
-                compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False),
+                compile_spec=common.get_tosa_compile_spec(),
            )
            .export()
            .check_count({"torch.ops.aten.unsqueeze.default": 1})
@@ -59,7 +59,7 @@ def _test_unsqueeze_tosa_BI_pipeline(
            ArmTester(
                module,
                example_inputs=test_data,
-                compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False),
+                compile_spec=common.get_tosa_compile_spec(),
            )
            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
            .export()
@@ -102,18 +102,18 @@ def test_unsqueeze_tosa_MI(self, test_tensor: torch.Tensor):
    def test_unsqueeze_tosa_BI(self, test_tensor: torch.Tensor):
        self._test_unsqueeze_tosa_BI_pipeline(self.Unsqueeze(), (test_tensor, 0))

-    @parameterized.expand(Unsqueeze.test_parameters)
+    @parameterized.expand(Unsqueeze.test_parameters[:-1])
    def test_unsqueeze_u55_BI(self, test_tensor: torch.Tensor):
        self._test_unsqueeze_ethosu_BI_pipeline(
-            common.get_u55_compile_spec(permute_memory_to_nhwc=False),
+            common.get_u55_compile_spec(),
            self.Unsqueeze(),
            (test_tensor, 0),
        )

    @parameterized.expand(Unsqueeze.test_parameters)
    def test_unsqueeze_u85_BI(self, test_tensor: torch.Tensor):
        self._test_unsqueeze_ethosu_BI_pipeline(
-            common.get_u85_compile_spec(permute_memory_to_nhwc=False),
+            common.get_u85_compile_spec(),
            self.Unsqueeze(),
            (test_tensor, 0),
        )
2 changes: 1 addition & 1 deletion examples/arm/setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ function setup_vela() {
    if [[ ! -e ethos-u-vela ]]; then
        git clone https://review.mlplatform.org/ml/ethos-u/ethos-u-vela
        repo_dir="${root_dir}/ethos-u-vela"
-        base_rev=fe0eaa55c5ed319f78c01978f3b40eb11a9bcb38
+        base_rev=57ce18c89ccc6f6309333dccb24ed30dc68b571f
        patch_repo
    fi
    cd "${root_dir}/ethos-u-vela"
