Qualcomm AI Engine Direct - gMLP Enablement (pytorch#3774)
Summary:
- Enable gMLP_s16_224
- Add new OP: split_with_sizes
- Add test cases for the model and the new OP

Pull Request resolved: pytorch#3774

Reviewed By: kirklandsign

Differential Revision: D58001291

Pulled By: cccclai

fbshipit-source-id: 7f4c7f85aa80b0c6b1f1c220f26ede88b6592d60
winskuo-quic authored and facebook-github-bot committed May 31, 2024
1 parent 0412dea commit 70e3395
Showing 17 changed files with 490 additions and 40 deletions.
2 changes: 2 additions & 0 deletions backends/qualcomm/builders/__init__.py
@@ -44,6 +44,7 @@
op_slice_copy,
op_softmax,
op_space_to_depth,
op_split_with_sizes,
op_sqrt,
op_squeeze,
op_sub,
@@ -95,6 +96,7 @@
op_slice_copy,
op_softmax,
op_space_to_depth,
op_split_with_sizes,
op_squeeze,
op_sqrt,
op_sub,
35 changes: 17 additions & 18 deletions backends/qualcomm/builders/node_visitor.py
@@ -215,9 +215,8 @@ def get_data_type(
self,
tensor: torch.Tensor,
quant_config: Dict,
is_tensor: bool,
) -> PyQnnWrapper.Qnn_TensorType_t:
if quant_config and is_tensor:
if quant_config:
quant_range = quant_config["quant_max"] - quant_config["quant_min"]
unsigned = quant_config["quant_min"] >= 0
if quant_range <= torch.iinfo(torch.int8).max - torch.iinfo(torch.int8).min:
@@ -234,8 +233,8 @@
else:
quant_config["dtype"] = torch.int16
return QNN_QUANT_TYPE_MAP[quant_config["dtype"]]
else:
return QNN_TENSOR_TYPE_MAP[tensor.dtype]

return QNN_TENSOR_TYPE_MAP[tensor.dtype]
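
For readers skimming the diff: below is a minimal, self-contained sketch (not part of the commit) of the dtype selection that get_data_type now performs without the removed is_tensor flag. The map names and string values are illustrative stand-ins for the backend's QNN_QUANT_TYPE_MAP and QNN_TENSOR_TYPE_MAP.

    import torch

    # Illustrative stand-ins for QNN_QUANT_TYPE_MAP / QNN_TENSOR_TYPE_MAP.
    QUANT_TYPE = {torch.uint8: "UFIXED_POINT_8", torch.int8: "SFIXED_POINT_8", torch.int16: "SFIXED_POINT_16"}
    TENSOR_TYPE = {torch.float32: "FLOAT_32"}

    def pick_dtype(tensor, quant_config):
        # A quant config selects an 8- or 16-bit fixed-point type from its range;
        # without one, the tensor's own dtype decides the QNN tensor type.
        if quant_config:
            quant_range = quant_config["quant_max"] - quant_config["quant_min"]
            unsigned = quant_config["quant_min"] >= 0
            if quant_range <= torch.iinfo(torch.int8).max - torch.iinfo(torch.int8).min:
                quant_config["dtype"] = torch.uint8 if unsigned else torch.int8
            else:
                quant_config["dtype"] = torch.int16
            return QUANT_TYPE[quant_config["dtype"]]
        return TENSOR_TYPE[tensor.dtype]

    # An 8-bit activation config (quant_min=0, quant_max=255) resolves to an unsigned 8-bit type.
    print(pick_dtype(torch.zeros(1), {"quant_min": 0, "quant_max": 255}))  # UFIXED_POINT_8
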

def define_custom_tensor_wrapper(
self,
@@ -247,10 +246,11 @@
dims: torch.Size,
tensor: torch.Tensor,
is_fake_tensor: bool,
nodes_to_wrappers: Dict[str, PyQnnWrapper.TensorWrapper],
nodes_to_wrappers: Dict[str, Dict[int, PyQnnWrapper.TensorWrapper]],
wrapper_idx: int = 0,
) -> PyQnnWrapper.TensorWrapper:
if node_name in nodes_to_wrappers:
return nodes_to_wrappers[node_name]
if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None):
return cached
if is_fake_tensor:
tensor_wrapper = PyQnnWrapper.TensorWrapper(
node_name,
@@ -266,18 +266,19 @@
else:
# Can implement non-fake tensor when there is a need
return None
nodes_to_wrappers[node_name] = tensor_wrapper
nodes_to_wrappers[node_name][wrapper_idx] = tensor_wrapper
return tensor_wrapper

def define_tensor(
self,
node: torch.fx.Node,
tensor: torch.Tensor,
tensor_type: PyQnnWrapper.Qnn_TensorType_t,
nodes_to_wrappers: Dict[str, PyQnnWrapper.TensorWrapper],
nodes_to_wrappers: Dict[str, Dict[int, PyQnnWrapper.TensorWrapper]],
is_input_tensor: bool,
node_name: str = None,
is_tensor: bool = True,
wrapper_idx: int = 0,
) -> PyQnnWrapper.TensorWrapper:
"""
Convert torch.Tensor to TensorWrapper
@@ -293,8 +294,8 @@
if node_name is None:
node_name = node.name

if node_name in nodes_to_wrappers:
return nodes_to_wrappers[node_name]
if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None):
return cached
tensor_name = node.name
if is_graph_output(node):
tensor_name = "output_" + tensor_name
@@ -303,7 +304,7 @@
quant_encoding, quant_configs = self.get_quant_encoding_conf(
node, is_input_tensor
)
dtype = self.get_data_type(tensor, quant_configs, is_tensor)
dtype = self.get_data_type(tensor, quant_configs)
if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor):
tensor_wrapper = PyQnnWrapper.TensorWrapper(
tensor_name,
@@ -334,13 +335,13 @@ def define_tensor(
tensor.detach().numpy(),
True,
)
nodes_to_wrappers[node_name] = tensor_wrapper
nodes_to_wrappers[node_name][wrapper_idx] = tensor_wrapper
return tensor_wrapper

def define_node(
self,
node: torch.fx.Node,
nodes_to_wrappers: Dict[str, PyQnnWrapper.TensorWrapper],
nodes_to_wrappers: Dict[str, Dict[int, PyQnnWrapper.TensorWrapper]],
) -> PyQnnWrapper.PyQnnOpWrapper:
"""Convert torch.fx.Node to OpWrapper"""
raise NotImplementedError("NodeVisitor must be extended!")
@@ -372,10 +373,8 @@ def generate_node_to_external_map(
if is_graph_input(node, edge_program):
node_to_external_map[node] = len(node_to_external_map)
for node in edge_program.graph_module.graph.nodes:
if node.op == "output":
for output_nodes in node.args:
for output_node in output_nodes:
node_to_external_map[output_node] = len(node_to_external_map)
if is_graph_output(node):
node_to_external_map[node] = len(node_to_external_map)
return node_to_external_map
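
The signature changes above switch nodes_to_wrappers from Dict[str, TensorWrapper] to Dict[str, Dict[int, TensorWrapper]], so a multi-output op such as split_with_sizes can cache one wrapper per output index. Here is a rough sketch of the caching pattern (not part of the commit), using a throwaway object in place of PyQnnWrapper.TensorWrapper:

    from collections import defaultdict

    # node name -> output index -> wrapper; defaultdict(dict) avoids pre-creating per-node entries.
    nodes_to_wrappers = defaultdict(dict)

    def get_or_create_wrapper(node_name, wrapper_idx, build_wrapper):
        # Reuse the cached wrapper for this (node, output index) pair if it exists.
        if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None):
            return cached
        wrapper = build_wrapper()
        nodes_to_wrappers[node_name][wrapper_idx] = wrapper
        return wrapper

    # Single-output nodes use index 0; each output of split_with_sizes gets its own index.
    first = get_or_create_wrapper("split", 0, object)
    assert get_or_create_wrapper("split", 0, object) is first      # cache hit
    assert get_or_create_wrapper("split", 1, object) is not first  # separate output, separate wrapper
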


4 changes: 2 additions & 2 deletions backends/qualcomm/builders/op_conv2d.py
@@ -108,7 +108,7 @@ def _define_conv1d(
is_input_tensor=True,
)
unsqueeze_output_tensor = unsqueeze_input_tensor.unsqueeze(1).contiguous()
dtype = self.get_data_type(unsqueeze_output_tensor, input_quant_configs, True)
dtype = self.get_data_type(unsqueeze_output_tensor, input_quant_configs)
unsqueeze_output_tensor_wrapper = self.define_custom_tensor_wrapper(
node_name=node.name + "_unsqueeze",
tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
@@ -186,7 +186,7 @@ def _define_conv1d(
)
conv_output_tensor = self.get_tensor(node, node)
conv_output_tensor = conv_output_tensor.unsqueeze(1).contiguous()
dtype = self.get_data_type(conv_output_tensor, input_quant_configs, True)
dtype = self.get_data_type(conv_output_tensor, input_quant_configs)
conv_output_tensor_wrapper = self.define_custom_tensor_wrapper(
node_name=node.name + "_squeeze",
tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
6 changes: 5 additions & 1 deletion backends/qualcomm/builders/op_skip_ops.py
@@ -46,5 +46,9 @@ def define_node(
raise AssertionError(
f"Invalid number of index for {node.name }: {len(node.args[1])}"
)
nodes_to_wrappers[node.name] = nodes_to_wrappers.get(node.args[0].name)
idx = node.args[1]
# match the nodes_to_wrappers format: Dict[str, Dict[int, PyQnnWrapper.TensorWrapper]]
nodes_to_wrappers[node.name] = {
0: nodes_to_wrappers.get(node.args[0].name).get(idx)
}
return
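
With that layout, the getitem skip op simply forwards the producer's wrapper for the selected output index, re-keyed under index 0 so later visitors see a single-output entry. A standalone illustration of the idea (function and argument names are hypothetical, not the backend's API):

    def visit_getitem(getitem_name, producer_name, output_idx, nodes_to_wrappers):
        # getitem(producer, idx) emits no QNN op; it aliases one output of a
        # multi-output node, so reuse the producer's wrapper under index 0.
        nodes_to_wrappers[getitem_name] = {
            0: nodes_to_wrappers[producer_name][output_idx]
        }

    wrappers = {"split": {0: "wrapper_out0", 1: "wrapper_out1"}}
    visit_getitem("getitem_1", "split", 1, wrappers)
    print(wrappers["getitem_1"])  # {0: 'wrapper_out1'}
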
94 changes: 94 additions & 0 deletions backends/qualcomm/builders/op_split_with_sizes.py
@@ -0,0 +1,94 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import cast, Dict, List

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import numpy as np
import torch

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpSplit, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class SplitWithSizes(NodeVisitor):
target = ["aten.split_with_sizes.default"]

def __init__(self, *args) -> None:
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:

input_node = node.args[0]
input_tensor = self.get_tensor(input_node, node)

input_tensor_wrapper = self.define_tensor(
input_node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
is_input_tensor=True,
)
input_tensor_wrappers = [input_tensor_wrapper]

# split_with_sizes will return a tuple since it has multiple outputs
output_tensor_wrappers = []
for index in range(len(node.meta["val"])):
output_tensor = self.get_tensor(node, node, index)
output_tensor_wrapper = self.define_tensor(
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
is_input_tensor=False,
wrapper_idx=index,
)
output_tensor_wrappers.append(output_tensor_wrapper)

chunks = cast(List[int], node.args[1])
split_indices = []
split_offset = 0
# Edge represents chunks by specifying the size of each chunk;
# QNN represents chunks by specifying the indices at which to split.
for chunk_size in chunks[:-1]:
split_offset += chunk_size
split_indices.append(split_offset)

split_indices_shape = [len(split_indices)]
dim = cast(int, node.args[2])
if dim < 0:
dim = dim % len(input_tensor.shape)

if "axis_order" in node.meta:
dim = node.meta["axis_order"].index(dim)
split_op = PyQnnWrapper.PyQnnOpWrapper(
node.name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpSplit.op_name,
)
split_op.AddInputTensors(input_tensor_wrappers)
split_op.AddOutputTensors(output_tensor_wrappers)
split_op.AddTensorParam(
OpSplit.param_split_index,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
len(split_indices_shape),
split_indices_shape,
np.array(split_indices, dtype=np.uint32),
True,
)

split_op.AddScalarParam(
OpSplit.param_axis,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
{"data": np.uint32(dim)},
)
return split_op
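
Note the size-to-index conversion above: the Edge dialect describes split_with_sizes by the size of each chunk, while QNN's Split op is given the indices at which to cut, i.e. the running sum over all chunks except the last. A standalone check of that conversion (illustrative only, not part of the commit):

    import numpy as np

    def chunk_sizes_to_split_indices(chunks):
        # Cut points are the cumulative sizes of all chunks but the last one.
        return np.cumsum(chunks[:-1]).astype(np.uint32)

    # Splitting a dimension of size 9 into chunks [2, 3, 4] cuts at indices 2 and 5.
    print(chunk_sizes_to_split_indices([2, 3, 4]))  # [2 5]
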
6 changes: 6 additions & 0 deletions backends/qualcomm/builders/qnn_constants.py
@@ -290,6 +290,12 @@ class Mode(IntEnum):
CRD = 1


class OpSplit:
op_name: str = "Split"
param_axis: str = "axis"
param_split_index: str = "split_index"


@dataclass(init=False, frozen=True)
class OpSqueeze:
op_name: str = "Squeeze"
5 changes: 4 additions & 1 deletion backends/qualcomm/builders/utils.py
@@ -75,7 +75,10 @@ def is_graph_output(tensor: torch.fx.Node) -> bool:
tensor: EdgeIR Tensor that is being checked for graph output
"""
for user in tensor.users.keys():
if user.op == "output":
# getitem nodes are skipped; see op_skip_ops.py
if user.op == "output" or (
user.target.__name__ == "getitem" and is_graph_output(user)
):
return True
return False
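
The extra clause above exists because the outputs of split_with_sizes reach the graph output through getitem nodes, so is_graph_output has to follow getitem users recursively. A toy illustration of the recursion (plain dicts stand in for torch.fx nodes):

    def is_graph_output(node):
        # A node is a graph output if any user is the output node, or if a getitem
        # alias of it eventually feeds the output (split_with_sizes -> getitem -> output).
        for user in node["users"]:
            if user["op"] == "output" or (user["op"] == "getitem" and is_graph_output(user)):
                return True
        return False

    output_node = {"op": "output", "users": []}
    getitem_node = {"op": "getitem", "users": [output_node]}
    split_node = {"op": "call_function", "users": [getitem_node]}
    print(is_graph_output(split_node))  # True, reached via the getitem alias
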

3 changes: 2 additions & 1 deletion backends/qualcomm/partition/qnn_partitioner.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
from collections import defaultdict
from typing import Any, Dict, List

import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager
@@ -49,7 +50,7 @@ def __init__(
)

self.skip_node_id_set = skip_node_id_set
self.nodes_to_wrappers = {}
self.nodes_to_wrappers = defaultdict(dict)
self.qnn_manager = PyQnnManager.QnnManager(
generate_qnn_executorch_option(compiler_specs)
)
53 changes: 52 additions & 1 deletion backends/qualcomm/passes/convert_to_linear.py
@@ -14,6 +14,7 @@
from executorch.exir.dialects.edge._ops import EdgeOpOverload as edge_op
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.passes import dead_code_elimination_pass

from torch.fx.passes.utils.source_matcher_utils import (
get_source_partitions,
SourcePartition,
@@ -92,6 +93,7 @@ def _convert_to_linear(
if bias_node:
args.append(bias_node)

# We need a view copy node after linear op
with gm.graph.inserting_before(output):
linear_node = gm.graph.create_node(
"call_function", self.linear, tuple(args)
@@ -104,6 +106,52 @@
for user in fn_node.users.copy():
user.replace_input_with(fn_node, linear_node)

# Since QNN has no keep-dims option for the linear op, we need to add squeeze and unsqueeze around the linear node
# TODO: Find a more general conditional statement.
if (
fn_node.target == self.add
and linear_node.meta["val"].dim() == 3
and linear_node.meta["val"].shape[0] == 1
):
squeeze_dim = linear_node.meta["val"].shape[1:]
linear_node.meta["val"] = torch.squeeze(linear_node.meta["val"], 0)
with gm.graph.inserting_after(input_node):
input_users = list(input_node.users.keys())
squeeze_dim = linear_node.meta["val"].shape
squeeze_view_copy_node = gm.graph.create_node(
"call_function",
self.view_copy,
(
input_node,
squeeze_dim,
),
)
squeeze_view_copy_node.meta = linear_node.meta
for user in input_users:
if user == linear_node:
user.replace_input_with(input_node, squeeze_view_copy_node)
with gm.graph.inserting_after(output):
output_users = list(linear_node.users.keys())
unsqueeze_dim = output.args[0].meta["val"].shape
unsqueeze_view_copy_node = gm.graph.create_node(
"call_function",
self.view_copy,
(
linear_node,
unsqueeze_dim,
),
)
unsqueeze_view_copy_node.meta = output.args[0].meta
for user in output_users:
user.replace_input_with(linear_node, unsqueeze_view_copy_node)
if "quant_attrs" in linear_node.meta:
squeeze_view_copy_node.meta["quant_attrs"] = linear_node.meta[
"quant_attrs"
]
unsqueeze_view_copy_node.meta["quant_attrs"] = linear_node.meta[
"quant_attrs"
]
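
The added view_copy nodes reshape a rank-3 activation of shape [1, S, D] down to [S, D] around the linear and back again, since (per the comment above) QNN has no keep-dims option for the linear op. A shape-level sketch of the intended data flow in plain PyTorch (sizes are made up; this is not the pass itself):

    import torch

    x = torch.randn(1, 196, 256)                  # rank-3 activation with a leading batch of 1
    weight = torch.randn(512, 256)
    bias = torch.randn(512)

    squeezed = x.reshape(x.shape[1:])             # "squeeze" view_copy to [196, 256] before the linear
    y = torch.nn.functional.linear(squeezed, weight, bias)   # [196, 512]
    restored = y.reshape(1, *y.shape)             # "unsqueeze" view_copy back to [1, 196, 512]
    print(restored.shape)                         # torch.Size([1, 196, 512])
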

def _extract_mm_ops(self, partitioned_nodes: List[edge_op]) -> List[torch.fx.Node]:
mm_node = [n for n in partitioned_nodes if n.target == self.mm][0]
# weight -> permute -> input of mm
@@ -133,7 +181,10 @@ def _extract_bmm_ops(self, partitioned_nodes: List[edge_op]) -> List[torch.fx.Node]:
ret = [input_node, weight_node, bmm_node]
if add_node:
bias_node = add_node[0].args[1]
ret += bias_node
ret = [input_node, weight_node, add_node[0], bias_node]
else:
ret = [input_node, weight_node, bmm_node]

return ret

def _convert(self, graph_module: torch.fx.GraphModule):
3 changes: 2 additions & 1 deletion backends/qualcomm/passes/layout_transform.py
@@ -60,6 +60,7 @@ class LayoutTransform(ExportPass):
exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.sum.dim_IntList,
exir_ops.edge.aten._to_copy.default,
exir_ops.edge.aten.split_with_sizes.default,
*q_ops,
*dq_ops,
_operator.getitem,
@@ -142,7 +143,7 @@ def is_edge_condition(self, node):
),
(
node.op != "output"
and not isinstance(node.meta["val"], tuple)
and not isinstance(node.meta["val"], (tuple, list))
and len(node.meta["val"].shape) == 0
),
is_parameter(node, self.edge_program),
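
The widened isinstance check matters because multi-output ops such as split_with_sizes store a list (or tuple) of FakeTensors in node.meta["val"], so the scalar-rank test has to skip those values instead of failing on .shape. A small standalone illustration of the guarded check (values are illustrative):

    import torch

    def is_scalar_val(val):
        # Multi-output ops carry a list/tuple of FakeTensors in meta["val"];
        # only single-tensor values are tested for rank 0.
        return not isinstance(val, (tuple, list)) and len(val.shape) == 0

    print(is_scalar_val(torch.tensor(3.0)))                       # True: a rank-0 tensor
    print(is_scalar_val([torch.empty(1, 2), torch.empty(1, 3)]))  # False, without an AttributeError
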