Skip to content

Commit

Permalink
Qualcomm AI Engine Direct - Refactor & centralize common keywords (py…
Browse files Browse the repository at this point in the history
…torch#4357)

Summary:
- Summarize the QCOM-specific keywords into a central constants module
- Replace the hard-coded keyword strings throughout the Qualcomm code base with the centralized constants

Pull Request resolved: pytorch#4357

Reviewed By: digantdesai

Differential Revision: D60118046

Pulled By: cccclai

fbshipit-source-id: 6a0aea6c0d45536e476c9ee01a8f4ac5c9005df8
  • Loading branch information
chuntl authored and facebook-github-bot committed Jul 24, 2024
1 parent 6c69ebd commit 56120f9
Show file tree
Hide file tree
Showing 43 changed files with 359 additions and 252 deletions.
44 changes: 27 additions & 17 deletions backends/qualcomm/builders/node_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,16 @@

import numpy as np
import torch
from executorch.backends.qualcomm.utils.constants import (
QCOM_AXIS_ORDER,
QCOM_BITWIDTH,
QCOM_ENCODING,
QCOM_QUANT_ATTRS,
QCOM_REQUANTIZE,
QCOM_SCALE_OFFSET,
QCOM_SCALES,
QCOM_ZERO_POINTS,
)

from executorch.exir.dialects._ops import ops as exir_ops

Expand Down Expand Up @@ -89,15 +99,15 @@ def _get_tensor(node, index):
return node.meta["val"]

tensor = _get_tensor(input_node, idx)
if len(tensor.shape) != 0 and "axis_order" in op_node.meta:
tensor = tensor.permute(dims=op_node.meta["axis_order"]).contiguous()
if len(tensor.shape) != 0 and QCOM_AXIS_ORDER in op_node.meta:
tensor = tensor.permute(dims=op_node.meta[QCOM_AXIS_ORDER]).contiguous()
return tensor

def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict):
quant_config = copy.deepcopy(quant_attrs)

scales = quant_attrs["scales"]
zero_points = quant_attrs["zero_points"]
scales = quant_attrs[QCOM_SCALES]
zero_points = quant_attrs[QCOM_ZERO_POINTS]
assert len(scales) == len(
zero_points
), f"Per channel encoding of node {node}, has different size for scales {len(scales)} and zero_points {len(zero_points)}"
Expand All @@ -120,13 +130,13 @@ def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict):
else:
quant_config["axis"] = quant_attrs["axis"]

quant_config["scale_offset"] = scale_offset
quant_config[QCOM_SCALE_OFFSET] = scale_offset
# special case for 4 bits
if (
quant_config["dtype"] == torch.int8
and quant_config["quant_max"] - quant_config["quant_min"] <= 15
):
quant_config["bitwidth"] = 4
quant_config[QCOM_BITWIDTH] = 4
return (
PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET,
quant_config,
Expand All @@ -145,7 +155,7 @@ def make_qnn_per_tensor_config(self, quant_attrs: Dict):
quant_config["dtype"] == torch.int8
and quant_config["quant_max"] - quant_config["quant_min"] <= 15
):
quant_config["bitwidth"] = 4
quant_config[QCOM_BITWIDTH] = 4
return (
PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET,
quant_config,
Expand All @@ -158,36 +168,36 @@ def make_qnn_per_tensor_config(self, quant_attrs: Dict):
def get_quant_encoding_conf(
self, node: torch.fx.Node, is_input_tensor: bool = False
) -> Tuple[Any, Dict]:
if not node.meta.get("quant_attrs", None):
if not node.meta.get(QCOM_QUANT_ATTRS, None):
return (
PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED,
{},
)
quant_attrs = (
node.meta["requantize"]
if "requantize" in node.meta and is_input_tensor
else node.meta["quant_attrs"]
node.meta[QCOM_REQUANTIZE]
if QCOM_REQUANTIZE in node.meta and is_input_tensor
else node.meta[QCOM_QUANT_ATTRS]
)
if quant_attrs["encoding"] in PER_CHANNEL_ENCODING:
if quant_attrs[QCOM_ENCODING] in PER_CHANNEL_ENCODING:
return self.make_qnn_per_channel_config(node, quant_attrs)

return self.make_qnn_per_tensor_config(quant_attrs)

def get_quant_tensor_value(
self, tensor: torch.Tensor, quant_attrs: Dict, quant_configs: Dict
) -> torch.Tensor:
if quant_attrs["encoding"] in PER_TENSOR_ENCODING:
if quant_attrs[QCOM_ENCODING] in PER_TENSOR_ENCODING:
scale = quant_attrs["scale"]
zero_point = quant_attrs["zero_point"]
else: # per channel case
scale = quant_attrs["scales"]
zero_point = quant_attrs["zero_points"]
scale = quant_attrs[QCOM_SCALES]
zero_point = quant_attrs[QCOM_ZERO_POINTS]

dtype = quant_configs["dtype"]

tensor = tensor.div(scale).add(zero_point).round().to(dtype)
# Make the backends access data correctly
if quant_configs.get("bitwidth") == 4:
if quant_configs.get(QCOM_BITWIDTH) == 4:
mask = torch.full(tensor.size(), 0x0F, dtype=torch.int8)
tensor = torch.bitwise_and(mask, tensor)
return tensor
Expand Down Expand Up @@ -315,7 +325,7 @@ def define_tensor(
if quant_configs:
tensor = self.get_quant_tensor_value(
tensor,
node.meta["quant_attrs"],
node.meta[QCOM_QUANT_ATTRS],
quant_configs,
)
tensor_wrapper = PyQnnWrapper.TensorWrapper(
Expand Down
5 changes: 3 additions & 2 deletions backends/qualcomm/builders/op_avg_pool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import numpy as np
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_DATA

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpPoolAvg2d, QNN_OP_PACKAGE_NAME_QTI_AISW
Expand Down Expand Up @@ -132,12 +133,12 @@ def define_node(
avg_pool2d_op.AddScalarParam(
OpPoolAvg2d.param_rounding_mode,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
{"data": np.uint32(mode)},
{QCOM_DATA: np.uint32(mode)},
)
avg_pool2d_op.AddScalarParam(
OpPoolAvg2d.param_count_pad_for_edges,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
{"data": count_include_pad},
{QCOM_DATA: count_include_pad},
)

return avg_pool2d_op
7 changes: 4 additions & 3 deletions backends/qualcomm/builders/op_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import numpy as np
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpConcat, QNN_OP_PACKAGE_NAME_QTI_AISW
Expand Down Expand Up @@ -64,8 +65,8 @@ def define_node(
if axis < 0:
axis += node.meta["val"].dim()

if "axis_order" in node.meta:
axis = node.meta["axis_order"].index(axis)
if QCOM_AXIS_ORDER in node.meta:
axis = node.meta[QCOM_AXIS_ORDER].index(axis)

concat_op = PyQnnWrapper.PyQnnOpWrapper(
node.name,
Expand All @@ -78,7 +79,7 @@ def define_node(
concat_op.AddScalarParam(
OpConcat.param_axis,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
{"data": np.uint32(axis)},
{QCOM_DATA: np.uint32(axis)},
)

return concat_op
5 changes: 3 additions & 2 deletions backends/qualcomm/builders/op_clamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import numpy as np
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_DATA

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpReluMinMax, QNN_OP_PACKAGE_NAME_QTI_AISW
Expand Down Expand Up @@ -67,12 +68,12 @@ def define_node(
clamp_op.AddScalarParam(
OpReluMinMax.param_max_value,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
{"data": np.float32(output_max)},
{QCOM_DATA: np.float32(output_max)},
)
clamp_op.AddScalarParam(
OpReluMinMax.param_min_value,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
{"data": np.float32(output_min)},
{QCOM_DATA: np.float32(output_min)},
)

return clamp_op
5 changes: 3 additions & 2 deletions backends/qualcomm/builders/op_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import numpy as np
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_DATA

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import (
Expand Down Expand Up @@ -79,7 +80,7 @@ def _add_conv_op_parameter(
conv_op.AddScalarParam(
OP.param_group,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
{"data": np.uint32(groups)},
{QCOM_DATA: np.uint32(groups)},
)

return conv_op
Expand Down Expand Up @@ -130,7 +131,7 @@ def _define_conv1d(
unsqueeze_op.AddScalarParam(
OpExpandDims.param_axis,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
{"data": np.uint32(1)},
{QCOM_DATA: np.uint32(1)},
)
op_wrapper_list.append(unsqueeze_op)

Expand Down
3 changes: 2 additions & 1 deletion backends/qualcomm/builders/op_depth_to_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import numpy as np
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_DATA

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpDepthToSpace, QNN_OP_PACKAGE_NAME_QTI_AISW
Expand Down Expand Up @@ -70,7 +71,7 @@ def define_node(
depth_to_space_op.AddScalarParam(
OpDepthToSpace.param_mode,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
{"data": np.uint32(OpDepthToSpace.Mode.CRD)},
{QCOM_DATA: np.uint32(OpDepthToSpace.Mode.CRD)},
)

return depth_to_space_op
3 changes: 2 additions & 1 deletion backends/qualcomm/builders/op_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import numpy as np
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_DATA

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpGather, QNN_OP_PACKAGE_NAME_QTI_AISW
Expand Down Expand Up @@ -71,7 +72,7 @@ def define_node(
gather_op.AddScalarParam(
OpGather.param_axis,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32,
{"data": np.int32(0)},
{QCOM_DATA: np.int32(0)},
)

return gather_op
7 changes: 4 additions & 3 deletions backends/qualcomm/builders/op_hardsigmoid.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import numpy as np

import torch
from executorch.backends.qualcomm.utils.constants import QCOM_DATA

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpElementWiseNeuron, QNN_OP_PACKAGE_NAME_QTI_AISW
Expand Down Expand Up @@ -58,19 +59,19 @@ def define_node(
hardsigmoid_op.AddScalarParam(
OpElementWiseNeuron.param_operation,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
{"data": np.uint32(2)},
{QCOM_DATA: np.uint32(2)},
)

# The parameter used in Pytorch definition for hardsigmoid
hardsigmoid_op.AddScalarParam(
OpElementWiseNeuron.param_alpha,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
{"data": np.float32(1 / 6)},
{QCOM_DATA: np.float32(1 / 6)},
)
hardsigmoid_op.AddScalarParam(
OpElementWiseNeuron.param_beta,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
{"data": np.float32(1 / 2)},
{QCOM_DATA: np.float32(1 / 2)},
)

return hardsigmoid_op
5 changes: 3 additions & 2 deletions backends/qualcomm/builders/op_hardtanh.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import numpy as np
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_DATA

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpReluMinMax, QNN_OP_PACKAGE_NAME_QTI_AISW
Expand Down Expand Up @@ -66,12 +67,12 @@ def define_node(
hardtanh_op.AddScalarParam(
OpReluMinMax.param_max_value,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
{"data": np.float32(output_max)},
{QCOM_DATA: np.float32(output_max)},
)
hardtanh_op.AddScalarParam(
OpReluMinMax.param_min_value,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
{"data": np.float32(output_min)},
{QCOM_DATA: np.float32(output_min)},
)

return hardtanh_op
3 changes: 2 additions & 1 deletion backends/qualcomm/builders/op_layer_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import numpy as np
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_DATA

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpLayerNorm, QNN_OP_PACKAGE_NAME_QTI_AISW
Expand Down Expand Up @@ -91,7 +92,7 @@ def define_node(
layer_norm_op.AddScalarParam(
OpLayerNorm.param_epsilon,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
{"data": np.float32(epsilon)},
{QCOM_DATA: np.float32(epsilon)},
)
layer_norm_op.AddTensorParam(
OpLayerNorm.param_axes,
Expand Down
17 changes: 12 additions & 5 deletions backends/qualcomm/builders/op_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import torch
from executorch.backends.qualcomm.utils.constants import (
QCOM_QUANT_ATTRS,
QCOM_SCALES,
QCOM_ZERO_POINTS,
)

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpFullyConnected, QNN_OP_PACKAGE_NAME_QTI_AISW
Expand Down Expand Up @@ -41,12 +46,14 @@ def define_node(

weight_node = node.args[1]
if (
quant_attrs := weight_node.meta.get("quant_attrs")
) and "scales" in quant_attrs:
quant_attrs := weight_node.meta.get(QCOM_QUANT_ATTRS)
) and QCOM_SCALES in quant_attrs:
# Dimension of weight is [m, n], per channel quant params is [m]
# Change to [m, 1] to fit the tensor.div(s).add(z)
quant_attrs["scales"] = quant_attrs["scales"].reshape([-1, 1])
quant_attrs["zero_points"] = quant_attrs["zero_points"].reshape([-1, 1])
quant_attrs[QCOM_SCALES] = quant_attrs[QCOM_SCALES].reshape([-1, 1])
quant_attrs[QCOM_ZERO_POINTS] = quant_attrs[QCOM_ZERO_POINTS].reshape(
[-1, 1]
)

weight_tensor = get_parameter(weight_node, self.edge_program)
weight_tensor_wrapper = self.define_tensor(
Expand All @@ -62,7 +69,7 @@ def define_node(
bias_node = node.args[2]

# TODO remove this when qnn sdk support
if "scales" in bias_node.meta.get("quant_attrs", {}):
if QCOM_SCALES in bias_node.meta.get(QCOM_QUANT_ATTRS, {}):
print(
f"[WARNING] Fallback linear bias, {bias_node}. per channel bias quantization is not support yet."
)
Expand Down
7 changes: 4 additions & 3 deletions backends/qualcomm/builders/op_log_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import numpy as np
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpLogSoftmax, QNN_OP_PACKAGE_NAME_QTI_AISW
Expand Down Expand Up @@ -52,8 +53,8 @@ def define_node(
if dim < 0:
dim = dim % len(input_tensor.shape)

if "axis_order" in node.meta:
dim = node.meta["axis_order"].index(dim)
if QCOM_AXIS_ORDER in node.meta:
dim = node.meta[QCOM_AXIS_ORDER].index(dim)

# logsoftmax only supports last dimension for now, which is channel in QNN
if dim != input_tensor.dim() - 1:
Expand All @@ -70,6 +71,6 @@ def define_node(
log_softmax_op.AddScalarParam(
OpLogSoftmax.param_axis,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
{"data": np.uint32(dim)},
{QCOM_DATA: np.uint32(dim)},
)
return log_softmax_op
Loading

0 comments on commit 56120f9

Please sign in to comment.