diff --git a/backends/qualcomm/builders/node_visitor.py b/backends/qualcomm/builders/node_visitor.py index 96e3b6f97f..641e2445f2 100644 --- a/backends/qualcomm/builders/node_visitor.py +++ b/backends/qualcomm/builders/node_visitor.py @@ -11,6 +11,16 @@ import numpy as np import torch +from executorch.backends.qualcomm.utils.constants import ( + QCOM_AXIS_ORDER, + QCOM_BITWIDTH, + QCOM_ENCODING, + QCOM_QUANT_ATTRS, + QCOM_REQUANTIZE, + QCOM_SCALE_OFFSET, + QCOM_SCALES, + QCOM_ZERO_POINTS, +) from executorch.exir.dialects._ops import ops as exir_ops @@ -89,15 +99,15 @@ def _get_tensor(node, index): return node.meta["val"] tensor = _get_tensor(input_node, idx) - if len(tensor.shape) != 0 and "axis_order" in op_node.meta: - tensor = tensor.permute(dims=op_node.meta["axis_order"]).contiguous() + if len(tensor.shape) != 0 and QCOM_AXIS_ORDER in op_node.meta: + tensor = tensor.permute(dims=op_node.meta[QCOM_AXIS_ORDER]).contiguous() return tensor def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict): quant_config = copy.deepcopy(quant_attrs) - scales = quant_attrs["scales"] - zero_points = quant_attrs["zero_points"] + scales = quant_attrs[QCOM_SCALES] + zero_points = quant_attrs[QCOM_ZERO_POINTS] assert len(scales) == len( zero_points ), f"Per channel encoding of node {node}, has different size for scales {len(scales)} and zero_points {len(zero_points)}" @@ -120,13 +130,13 @@ def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict): else: quant_config["axis"] = quant_attrs["axis"] - quant_config["scale_offset"] = scale_offset + quant_config[QCOM_SCALE_OFFSET] = scale_offset # special case for 4 bits if ( quant_config["dtype"] == torch.int8 and quant_config["quant_max"] - quant_config["quant_min"] <= 15 ): - quant_config["bitwidth"] = 4 + quant_config[QCOM_BITWIDTH] = 4 return ( PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET, quant_config, @@ -145,7 +155,7 @@ def make_qnn_per_tensor_config(self, quant_attrs: Dict): quant_config["dtype"] == torch.int8 and quant_config["quant_max"] - quant_config["quant_min"] <= 15 ): - quant_config["bitwidth"] = 4 + quant_config[QCOM_BITWIDTH] = 4 return ( PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET, quant_config, @@ -158,17 +168,17 @@ def make_qnn_per_tensor_config(self, quant_attrs: Dict): def get_quant_encoding_conf( self, node: torch.fx.Node, is_input_tensor: bool = False ) -> Tuple[Any, Dict]: - if not node.meta.get("quant_attrs", None): + if not node.meta.get(QCOM_QUANT_ATTRS, None): return ( PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED, {}, ) quant_attrs = ( - node.meta["requantize"] - if "requantize" in node.meta and is_input_tensor - else node.meta["quant_attrs"] + node.meta[QCOM_REQUANTIZE] + if QCOM_REQUANTIZE in node.meta and is_input_tensor + else node.meta[QCOM_QUANT_ATTRS] ) - if quant_attrs["encoding"] in PER_CHANNEL_ENCODING: + if quant_attrs[QCOM_ENCODING] in PER_CHANNEL_ENCODING: return self.make_qnn_per_channel_config(node, quant_attrs) return self.make_qnn_per_tensor_config(quant_attrs) @@ -176,18 +186,18 @@ def get_quant_encoding_conf( def get_quant_tensor_value( self, tensor: torch.Tensor, quant_attrs: Dict, quant_configs: Dict ) -> torch.Tensor: - if quant_attrs["encoding"] in PER_TENSOR_ENCODING: + if quant_attrs[QCOM_ENCODING] in PER_TENSOR_ENCODING: scale = quant_attrs["scale"] zero_point = quant_attrs["zero_point"] else: # per channel case - scale = quant_attrs["scales"] - zero_point = 
quant_attrs["zero_points"] + scale = quant_attrs[QCOM_SCALES] + zero_point = quant_attrs[QCOM_ZERO_POINTS] dtype = quant_configs["dtype"] tensor = tensor.div(scale).add(zero_point).round().to(dtype) # Make the backends access data correctly - if quant_configs.get("bitwidth") == 4: + if quant_configs.get(QCOM_BITWIDTH) == 4: mask = torch.full(tensor.size(), 0x0F, dtype=torch.int8) tensor = torch.bitwise_and(mask, tensor) return tensor @@ -315,7 +325,7 @@ def define_tensor( if quant_configs: tensor = self.get_quant_tensor_value( tensor, - node.meta["quant_attrs"], + node.meta[QCOM_QUANT_ATTRS], quant_configs, ) tensor_wrapper = PyQnnWrapper.TensorWrapper( diff --git a/backends/qualcomm/builders/op_avg_pool2d.py b/backends/qualcomm/builders/op_avg_pool2d.py index f613d6b2d0..3e10a1918d 100644 --- a/backends/qualcomm/builders/op_avg_pool2d.py +++ b/backends/qualcomm/builders/op_avg_pool2d.py @@ -9,6 +9,7 @@ import numpy as np import torch +from executorch.backends.qualcomm.utils.constants import QCOM_DATA from .node_visitor import NodeVisitor, register_node_visitor from .qnn_constants import OpPoolAvg2d, QNN_OP_PACKAGE_NAME_QTI_AISW @@ -132,12 +133,12 @@ def define_node( avg_pool2d_op.AddScalarParam( OpPoolAvg2d.param_rounding_mode, PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, - {"data": np.uint32(mode)}, + {QCOM_DATA: np.uint32(mode)}, ) avg_pool2d_op.AddScalarParam( OpPoolAvg2d.param_count_pad_for_edges, PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, - {"data": count_include_pad}, + {QCOM_DATA: count_include_pad}, ) return avg_pool2d_op diff --git a/backends/qualcomm/builders/op_cat.py b/backends/qualcomm/builders/op_cat.py index 9f653fe032..bb68b24289 100644 --- a/backends/qualcomm/builders/op_cat.py +++ b/backends/qualcomm/builders/op_cat.py @@ -9,6 +9,7 @@ import numpy as np import torch +from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA from .node_visitor import NodeVisitor, register_node_visitor from .qnn_constants import OpConcat, QNN_OP_PACKAGE_NAME_QTI_AISW @@ -64,8 +65,8 @@ def define_node( if axis < 0: axis += node.meta["val"].dim() - if "axis_order" in node.meta: - axis = node.meta["axis_order"].index(axis) + if QCOM_AXIS_ORDER in node.meta: + axis = node.meta[QCOM_AXIS_ORDER].index(axis) concat_op = PyQnnWrapper.PyQnnOpWrapper( node.name, @@ -78,7 +79,7 @@ def define_node( concat_op.AddScalarParam( OpConcat.param_axis, PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, - {"data": np.uint32(axis)}, + {QCOM_DATA: np.uint32(axis)}, ) return concat_op diff --git a/backends/qualcomm/builders/op_clamp.py b/backends/qualcomm/builders/op_clamp.py index f06e5b3480..0c69a8d333 100644 --- a/backends/qualcomm/builders/op_clamp.py +++ b/backends/qualcomm/builders/op_clamp.py @@ -9,6 +9,7 @@ import numpy as np import torch +from executorch.backends.qualcomm.utils.constants import QCOM_DATA from .node_visitor import NodeVisitor, register_node_visitor from .qnn_constants import OpReluMinMax, QNN_OP_PACKAGE_NAME_QTI_AISW @@ -67,12 +68,12 @@ def define_node( clamp_op.AddScalarParam( OpReluMinMax.param_max_value, PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, - {"data": np.float32(output_max)}, + {QCOM_DATA: np.float32(output_max)}, ) clamp_op.AddScalarParam( OpReluMinMax.param_min_value, PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, - {"data": np.float32(output_min)}, + {QCOM_DATA: np.float32(output_min)}, ) return clamp_op diff --git a/backends/qualcomm/builders/op_conv2d.py b/backends/qualcomm/builders/op_conv2d.py index 98086ce173..4b58edbac6 
100644 --- a/backends/qualcomm/builders/op_conv2d.py +++ b/backends/qualcomm/builders/op_conv2d.py @@ -10,6 +10,7 @@ import numpy as np import torch +from executorch.backends.qualcomm.utils.constants import QCOM_DATA from .node_visitor import NodeVisitor, register_node_visitor from .qnn_constants import ( @@ -79,7 +80,7 @@ def _add_conv_op_parameter( conv_op.AddScalarParam( OP.param_group, PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, - {"data": np.uint32(groups)}, + {QCOM_DATA: np.uint32(groups)}, ) return conv_op @@ -130,7 +131,7 @@ def _define_conv1d( unsqueeze_op.AddScalarParam( OpExpandDims.param_axis, PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, - {"data": np.uint32(1)}, + {QCOM_DATA: np.uint32(1)}, ) op_wrapper_list.append(unsqueeze_op) diff --git a/backends/qualcomm/builders/op_depth_to_space.py b/backends/qualcomm/builders/op_depth_to_space.py index d02b54aa0e..e734372098 100644 --- a/backends/qualcomm/builders/op_depth_to_space.py +++ b/backends/qualcomm/builders/op_depth_to_space.py @@ -10,6 +10,7 @@ import numpy as np import torch +from executorch.backends.qualcomm.utils.constants import QCOM_DATA from .node_visitor import NodeVisitor, register_node_visitor from .qnn_constants import OpDepthToSpace, QNN_OP_PACKAGE_NAME_QTI_AISW @@ -70,7 +71,7 @@ def define_node( depth_to_space_op.AddScalarParam( OpDepthToSpace.param_mode, PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, - {"data": np.uint32(OpDepthToSpace.Mode.CRD)}, + {QCOM_DATA: np.uint32(OpDepthToSpace.Mode.CRD)}, ) return depth_to_space_op diff --git a/backends/qualcomm/builders/op_embedding.py b/backends/qualcomm/builders/op_embedding.py index a5d6aae170..8ae3b64fbf 100644 --- a/backends/qualcomm/builders/op_embedding.py +++ b/backends/qualcomm/builders/op_embedding.py @@ -9,6 +9,7 @@ import numpy as np import torch +from executorch.backends.qualcomm.utils.constants import QCOM_DATA from .node_visitor import NodeVisitor, register_node_visitor from .qnn_constants import OpGather, QNN_OP_PACKAGE_NAME_QTI_AISW @@ -71,7 +72,7 @@ def define_node( gather_op.AddScalarParam( OpGather.param_axis, PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32, - {"data": np.int32(0)}, + {QCOM_DATA: np.int32(0)}, ) return gather_op diff --git a/backends/qualcomm/builders/op_hardsigmoid.py b/backends/qualcomm/builders/op_hardsigmoid.py index b72e63a64d..196777d628 100644 --- a/backends/qualcomm/builders/op_hardsigmoid.py +++ b/backends/qualcomm/builders/op_hardsigmoid.py @@ -10,6 +10,7 @@ import numpy as np import torch +from executorch.backends.qualcomm.utils.constants import QCOM_DATA from .node_visitor import NodeVisitor, register_node_visitor from .qnn_constants import OpElementWiseNeuron, QNN_OP_PACKAGE_NAME_QTI_AISW @@ -58,19 +59,19 @@ def define_node( hardsigmoid_op.AddScalarParam( OpElementWiseNeuron.param_operation, PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, - {"data": np.uint32(2)}, + {QCOM_DATA: np.uint32(2)}, ) # The parameter used in Pytorch definition for hardsigmoid hardsigmoid_op.AddScalarParam( OpElementWiseNeuron.param_alpha, PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, - {"data": np.float32(1 / 6)}, + {QCOM_DATA: np.float32(1 / 6)}, ) hardsigmoid_op.AddScalarParam( OpElementWiseNeuron.param_beta, PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, - {"data": np.float32(1 / 2)}, + {QCOM_DATA: np.float32(1 / 2)}, ) return hardsigmoid_op diff --git a/backends/qualcomm/builders/op_hardtanh.py b/backends/qualcomm/builders/op_hardtanh.py index 76b13e960c..8d90385277 100644 --- 
a/backends/qualcomm/builders/op_hardtanh.py +++ b/backends/qualcomm/builders/op_hardtanh.py @@ -10,6 +10,7 @@ import numpy as np import torch +from executorch.backends.qualcomm.utils.constants import QCOM_DATA from .node_visitor import NodeVisitor, register_node_visitor from .qnn_constants import OpReluMinMax, QNN_OP_PACKAGE_NAME_QTI_AISW @@ -66,12 +67,12 @@ def define_node( hardtanh_op.AddScalarParam( OpReluMinMax.param_max_value, PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, - {"data": np.float32(output_max)}, + {QCOM_DATA: np.float32(output_max)}, ) hardtanh_op.AddScalarParam( OpReluMinMax.param_min_value, PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, - {"data": np.float32(output_min)}, + {QCOM_DATA: np.float32(output_min)}, ) return hardtanh_op diff --git a/backends/qualcomm/builders/op_layer_norm.py b/backends/qualcomm/builders/op_layer_norm.py index 06efbd6e10..18f5b76310 100644 --- a/backends/qualcomm/builders/op_layer_norm.py +++ b/backends/qualcomm/builders/op_layer_norm.py @@ -10,6 +10,7 @@ import numpy as np import torch +from executorch.backends.qualcomm.utils.constants import QCOM_DATA from .node_visitor import NodeVisitor, register_node_visitor from .qnn_constants import OpLayerNorm, QNN_OP_PACKAGE_NAME_QTI_AISW @@ -91,7 +92,7 @@ def define_node( layer_norm_op.AddScalarParam( OpLayerNorm.param_epsilon, PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, - {"data": np.float32(epsilon)}, + {QCOM_DATA: np.float32(epsilon)}, ) layer_norm_op.AddTensorParam( OpLayerNorm.param_axes, diff --git a/backends/qualcomm/builders/op_linear.py b/backends/qualcomm/builders/op_linear.py index 9a59352821..17afb21c6d 100644 --- a/backends/qualcomm/builders/op_linear.py +++ b/backends/qualcomm/builders/op_linear.py @@ -9,6 +9,11 @@ import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper import torch +from executorch.backends.qualcomm.utils.constants import ( + QCOM_QUANT_ATTRS, + QCOM_SCALES, + QCOM_ZERO_POINTS, +) from .node_visitor import NodeVisitor, register_node_visitor from .qnn_constants import OpFullyConnected, QNN_OP_PACKAGE_NAME_QTI_AISW @@ -41,12 +46,14 @@ def define_node( weight_node = node.args[1] if ( - quant_attrs := weight_node.meta.get("quant_attrs") - ) and "scales" in quant_attrs: + quant_attrs := weight_node.meta.get(QCOM_QUANT_ATTRS) + ) and QCOM_SCALES in quant_attrs: # Dimension of weight is [m, n], per channel quant params is [m] # Change to [m, 1] to fit the tensor.div(s).add(z) - quant_attrs["scales"] = quant_attrs["scales"].reshape([-1, 1]) - quant_attrs["zero_points"] = quant_attrs["zero_points"].reshape([-1, 1]) + quant_attrs[QCOM_SCALES] = quant_attrs[QCOM_SCALES].reshape([-1, 1]) + quant_attrs[QCOM_ZERO_POINTS] = quant_attrs[QCOM_ZERO_POINTS].reshape( + [-1, 1] + ) weight_tensor = get_parameter(weight_node, self.edge_program) weight_tensor_wrapper = self.define_tensor( @@ -62,7 +69,7 @@ def define_node( bias_node = node.args[2] # TODO remove this when qnn sdk support - if "scales" in bias_node.meta.get("quant_attrs", {}): + if QCOM_SCALES in bias_node.meta.get(QCOM_QUANT_ATTRS, {}): print( f"[WARNING] Fallback linear bias, {bias_node}. per channel bias quantization is not support yet." 
) diff --git a/backends/qualcomm/builders/op_log_softmax.py b/backends/qualcomm/builders/op_log_softmax.py index 002dd5bc9b..fdd298f988 100644 --- a/backends/qualcomm/builders/op_log_softmax.py +++ b/backends/qualcomm/builders/op_log_softmax.py @@ -9,6 +9,7 @@ import numpy as np import torch +from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA from .node_visitor import NodeVisitor, register_node_visitor from .qnn_constants import OpLogSoftmax, QNN_OP_PACKAGE_NAME_QTI_AISW @@ -52,8 +53,8 @@ def define_node( if dim < 0: dim = dim % len(input_tensor.shape) - if "axis_order" in node.meta: - dim = node.meta["axis_order"].index(dim) + if QCOM_AXIS_ORDER in node.meta: + dim = node.meta[QCOM_AXIS_ORDER].index(dim) # logsoftmax only supports last dimension for now, which is channel in QNN if dim != input_tensor.dim() - 1: @@ -70,6 +71,6 @@ def define_node( log_softmax_op.AddScalarParam( OpLogSoftmax.param_axis, PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, - {"data": np.uint32(dim)}, + {QCOM_DATA: np.uint32(dim)}, ) return log_softmax_op diff --git a/backends/qualcomm/builders/op_max_pool2d.py b/backends/qualcomm/builders/op_max_pool2d.py index 6c8900b377..586556621b 100644 --- a/backends/qualcomm/builders/op_max_pool2d.py +++ b/backends/qualcomm/builders/op_max_pool2d.py @@ -9,6 +9,7 @@ import numpy as np import torch +from executorch.backends.qualcomm.utils.constants import QCOM_DATA from .node_visitor import NodeVisitor, register_node_visitor from .qnn_constants import OpPoolMax2d, QNN_OP_PACKAGE_NAME_QTI_AISW @@ -134,7 +135,7 @@ def define_node( max_pool2d_op.AddScalarParam( OpPoolMax2d.param_rounding_mode, PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, - {"data": np.uint32(mode)}, + {QCOM_DATA: np.uint32(mode)}, ) return max_pool2d_op diff --git a/backends/qualcomm/builders/op_mean_dim.py b/backends/qualcomm/builders/op_mean_dim.py index 18d48fa91c..e60e3e790b 100644 --- a/backends/qualcomm/builders/op_mean_dim.py +++ b/backends/qualcomm/builders/op_mean_dim.py @@ -10,6 +10,7 @@ import numpy as np import torch +from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA from .node_visitor import NodeVisitor, register_node_visitor from .qnn_constants import OpReduceMean, QNN_OP_PACKAGE_NAME_QTI_AISW @@ -42,9 +43,9 @@ def define_node( mean_dims = [ mean_dim % len(input_node.meta["val"].shape) for mean_dim in mean_dims ] - if "axis_order" in node.meta: + if QCOM_AXIS_ORDER in node.meta: mean_dims = [ - node.meta["axis_order"].index(mean_dim) for mean_dim in mean_dims + node.meta[QCOM_AXIS_ORDER].index(mean_dim) for mean_dim in mean_dims ] mean_dims_shape = [len(mean_dims)] @@ -77,7 +78,7 @@ def define_node( reduce_mean_op.AddScalarParam( OpReduceMean.param_keep_dims, PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, - {"data": keep_dims}, + {QCOM_DATA: keep_dims}, ) return reduce_mean_op diff --git a/backends/qualcomm/builders/op_pad.py b/backends/qualcomm/builders/op_pad.py index 4da23aa870..9ca385ff85 100644 --- a/backends/qualcomm/builders/op_pad.py +++ b/backends/qualcomm/builders/op_pad.py @@ -9,6 +9,7 @@ import numpy as np import torch +from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA from .node_visitor import NodeVisitor, QNN_TENSOR_TYPE_MAP, register_node_visitor from .qnn_constants import OpPad, QNN_OP_PACKAGE_NAME_QTI_AISW @@ -58,8 +59,8 @@ def define_node( (np.array([(0, 0)] * zero_amounts), pad_amount) ).astype(np.uint32) - if "axis_order" in node.meta: - pad_amount = 
np.transpose(pad_amount, node.meta["axis_order"]) + if QCOM_AXIS_ORDER in node.meta: + pad_amount = np.transpose(pad_amount, node.meta[QCOM_AXIS_ORDER]) pad_amount_val = node.args[2] pad_op = PyQnnWrapper.PyQnnOpWrapper( @@ -74,13 +75,13 @@ def define_node( pad_op.AddScalarParam( OpPad.param_scheme, PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, - {"data": np.uint32(OpPad.Scheme.CONSTANT)}, + {QCOM_DATA: np.uint32(OpPad.Scheme.CONSTANT)}, ) pad_op.AddScalarParam( OpPad.param_pad_constant_value, QNN_TENSOR_TYPE_MAP[type(pad_amount_val)], - {"data": pad_amount_val}, + {QCOM_DATA: pad_amount_val}, ) pad_op.AddTensorParam( diff --git a/backends/qualcomm/builders/op_pow.py b/backends/qualcomm/builders/op_pow.py index f30ffbf528..b4153d458b 100644 --- a/backends/qualcomm/builders/op_pow.py +++ b/backends/qualcomm/builders/op_pow.py @@ -8,6 +8,7 @@ import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper import torch +from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS from executorch.exir.dialects._ops import ops as exir_ops from .node_visitor import NodeVisitor, register_node_visitor @@ -65,14 +66,14 @@ def define_node( {}, # kwargs ) - if pow_quant_attrs := node.meta.get("quant_attrs"): + if pow_quant_attrs := node.meta.get(QCOM_QUANT_ATTRS): quant_attrs = pow_quant_attrs.copy() quant_range = quant_attrs["quant_max"] - quant_attrs["quant_min"] quant_attrs["zero_point"] = 0 if scalar >= 0 else quant_attrs["quant_max"] quant_attrs["scale"] = ( scalar / quant_range if scalar >= 0 else -scalar / quant_range ) - scalar_node.meta["quant_attrs"] = quant_attrs + scalar_node.meta[QCOM_QUANT_ATTRS] = quant_attrs scalar_tensor_wrapper = self.define_tensor( scalar_node, diff --git a/backends/qualcomm/builders/op_prelu.py b/backends/qualcomm/builders/op_prelu.py index 8305b0c965..fc0c6b9232 100644 --- a/backends/qualcomm/builders/op_prelu.py +++ b/backends/qualcomm/builders/op_prelu.py @@ -8,6 +8,10 @@ import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper import torch +from executorch.backends.qualcomm.utils.constants import ( + QCOM_AXIS_ORDER, + QCOM_QUANT_ATTRS, +) from executorch.exir.dialects._ops import ops as exir_ops from .node_visitor import get_parameter, NodeVisitor, register_node_visitor @@ -53,8 +57,8 @@ def define_node( coeff_tensor = coeff_tensor.index_fill( 1, torch.tensor([i]), coeff[i] ) - if "axis_order" in input_node.meta: - axis_order = input_node.meta["axis_order"] + if QCOM_AXIS_ORDER in input_node.meta: + axis_order = input_node.meta[QCOM_AXIS_ORDER] coeff_tensor = coeff_tensor.permute(dims=axis_order).contiguous() # simple min-max quantization coeff = torch.max(coeff).item() @@ -71,13 +75,13 @@ def define_node( (), # args {}, # kwargs ) - if pow_quant_attrs := node.meta.get("quant_attrs"): + if pow_quant_attrs := node.meta.get(QCOM_QUANT_ATTRS): quant_attrs = pow_quant_attrs.copy() quant_range = quant_attrs["quant_max"] - quant_attrs["quant_min"] # coeff is guaranteed to be positive quant_attrs["zero_point"] = 0 quant_attrs["scale"] = coeff / quant_range - scalar_node.meta["quant_attrs"] = quant_attrs + scalar_node.meta[QCOM_QUANT_ATTRS] = quant_attrs scalar_tensor_wrapper = self.define_tensor( scalar_node, diff --git a/backends/qualcomm/builders/op_quantize.py b/backends/qualcomm/builders/op_quantize.py index 10e1e1be2f..9d53d65571 100644 --- a/backends/qualcomm/builders/op_quantize.py +++ b/backends/qualcomm/builders/op_quantize.py @@ -8,6 +8,7 @@ import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor 
as PyQnnWrapper import torch +from executorch.backends.qualcomm.utils.constants import QCOM_ENCODING, QCOM_QUANT_ATTRS from .node_visitor import NodeVisitor, register_node_visitor from .qnn_constants import OpQuantize, QNN_OP_PACKAGE_NAME_QTI_AISW @@ -34,11 +35,11 @@ def define_node( ) quant_input_tensors.append(inp_tensor_wrapper) - node.meta["quant_attrs"] = {"encoding": node.target} + node.meta[QCOM_QUANT_ATTRS] = {QCOM_ENCODING: node.target} arg_schemas = list(node.target._schema.arguments)[1:] for i, arg_schema in enumerate(arg_schemas): name = arg_schema.name - node.meta["quant_attrs"][name] = node.args[i + 1] + node.meta[QCOM_QUANT_ATTRS][name] = node.args[i + 1] output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( diff --git a/backends/qualcomm/builders/op_select_copy.py b/backends/qualcomm/builders/op_select_copy.py index 1db9e7d38d..fdeec3845e 100644 --- a/backends/qualcomm/builders/op_select_copy.py +++ b/backends/qualcomm/builders/op_select_copy.py @@ -10,6 +10,7 @@ import numpy as np import torch +from executorch.backends.qualcomm.utils.constants import QCOM_DATA from .node_visitor import NodeVisitor, register_node_visitor from .qnn_constants import OpStridedSlice, QNN_OP_PACKAGE_NAME_QTI_AISW @@ -81,7 +82,7 @@ def define_node( stride_slice_op.AddScalarParam( OpStridedSlice.param_shrink_axes, PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, - {"data": np.uint32(math.pow(2, dim))}, + {QCOM_DATA: np.uint32(math.pow(2, dim))}, ) return stride_slice_op diff --git a/backends/qualcomm/builders/op_softmax.py b/backends/qualcomm/builders/op_softmax.py index 37459a5ff4..ae4c89bbb9 100644 --- a/backends/qualcomm/builders/op_softmax.py +++ b/backends/qualcomm/builders/op_softmax.py @@ -9,6 +9,7 @@ import numpy as np import torch +from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA from .node_visitor import NodeVisitor, register_node_visitor from .qnn_constants import OpSoftmax, QNN_OP_PACKAGE_NAME_QTI_AISW @@ -50,8 +51,8 @@ def define_node( dim = cast(int, node.args[1]) if dim < 0: dim = dim % len(input_tensor.shape) - if "axis_order" in node.meta: - dim = node.meta["axis_order"].index(dim) + if QCOM_AXIS_ORDER in node.meta: + dim = node.meta[QCOM_AXIS_ORDER].index(dim) # softmax only supports last dimension for now, which is channel in QNN if dim != input_tensor.dim() - 1: @@ -68,7 +69,7 @@ def define_node( softmax_op.AddScalarParam( OpSoftmax.param_axis, PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, - {"data": np.uint32(dim)}, + {QCOM_DATA: np.uint32(dim)}, ) return softmax_op diff --git a/backends/qualcomm/builders/op_space_to_depth.py b/backends/qualcomm/builders/op_space_to_depth.py index 4e14590074..a9b61c520e 100644 --- a/backends/qualcomm/builders/op_space_to_depth.py +++ b/backends/qualcomm/builders/op_space_to_depth.py @@ -10,6 +10,7 @@ import numpy as np import torch +from executorch.backends.qualcomm.utils.constants import QCOM_DATA from .node_visitor import NodeVisitor, register_node_visitor from .qnn_constants import OpSpaceToDepth, QNN_OP_PACKAGE_NAME_QTI_AISW @@ -70,7 +71,7 @@ def define_node( space_to_depth_op.AddScalarParam( OpSpaceToDepth.param_mode, PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, - {"data": np.uint32(OpSpaceToDepth.Mode.CRD)}, + {QCOM_DATA: np.uint32(OpSpaceToDepth.Mode.CRD)}, ) return space_to_depth_op diff --git a/backends/qualcomm/builders/op_split_with_sizes.py b/backends/qualcomm/builders/op_split_with_sizes.py index 03d19b1a5a..58503ff3f8 100644 --- 
a/backends/qualcomm/builders/op_split_with_sizes.py +++ b/backends/qualcomm/builders/op_split_with_sizes.py @@ -9,6 +9,7 @@ import numpy as np import torch +from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA from .node_visitor import NodeVisitor, register_node_visitor from .qnn_constants import OpSplit, QNN_OP_PACKAGE_NAME_QTI_AISW @@ -67,8 +68,8 @@ def define_node( if dim < 0: dim = dim % len(input_tensor.shape) - if "axis_order" in node.meta: - dim = node.meta["axis_order"].index(dim) + if QCOM_AXIS_ORDER in node.meta: + dim = node.meta[QCOM_AXIS_ORDER].index(dim) split_op = PyQnnWrapper.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, @@ -88,6 +89,6 @@ def define_node( split_op.AddScalarParam( OpSplit.param_axis, PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, - {"data": np.uint32(dim)}, + {QCOM_DATA: np.uint32(dim)}, ) return split_op diff --git a/backends/qualcomm/builders/op_sum_int_list.py b/backends/qualcomm/builders/op_sum_int_list.py index 26cc262462..abe35c2244 100644 --- a/backends/qualcomm/builders/op_sum_int_list.py +++ b/backends/qualcomm/builders/op_sum_int_list.py @@ -9,6 +9,7 @@ import numpy as np import torch +from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA from .node_visitor import NodeVisitor, register_node_visitor from .qnn_constants import OpReduceSum, QNN_OP_PACKAGE_NAME_QTI_AISW @@ -41,8 +42,10 @@ def define_node( # sum dims sum_dims = cast(List[int], node.args[1]) sum_dims = [sum_dim % len(input_node.meta["val"].shape) for sum_dim in sum_dims] - if "axis_order" in node.meta: - sum_dims = [node.meta["axis_order"].index(sum_dim) for sum_dim in sum_dims] + if QCOM_AXIS_ORDER in node.meta: + sum_dims = [ + node.meta[QCOM_AXIS_ORDER].index(sum_dim) for sum_dim in sum_dims + ] sum_dims_shape = [len(sum_dims)] output_tensor = self.get_tensor(node, node) @@ -75,6 +78,6 @@ def define_node( sum_op.AddScalarParam( OpReduceSum.param_keep_dims, PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, - {"data": keep_dims}, + {QCOM_DATA: keep_dims}, ) return sum_op diff --git a/backends/qualcomm/builders/op_to.py b/backends/qualcomm/builders/op_to.py index 8f3c0276cc..e17ee2790b 100644 --- a/backends/qualcomm/builders/op_to.py +++ b/backends/qualcomm/builders/op_to.py @@ -8,6 +8,7 @@ import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper import torch +from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS from .node_visitor import NodeVisitor, register_node_visitor from .qnn_constants import OpCast, OpConvert, QNN_OP_PACKAGE_NAME_QTI_AISW @@ -35,7 +36,12 @@ def is_cast_node(self, node): input_node = node.args[0] # Not a case which has two quant node, no need to consider the convert op - if not all([input_node.meta.get("quant_attrs"), node.meta.get("quant_attrs")]): + if not all( + [ + input_node.meta.get(QCOM_QUANT_ATTRS), + node.meta.get(QCOM_QUANT_ATTRS), + ] + ): return True input_tensor = self.get_tensor(input_node, node) diff --git a/backends/qualcomm/builders/op_transpose.py b/backends/qualcomm/builders/op_transpose.py index ebe9f2448f..20e30da335 100644 --- a/backends/qualcomm/builders/op_transpose.py +++ b/backends/qualcomm/builders/op_transpose.py @@ -10,6 +10,7 @@ import numpy as np import torch +from executorch.backends.qualcomm.utils.constants import QCOM_INSERTED_PERMUTE from .node_visitor import NodeVisitor, register_node_visitor from .qnn_constants import OpTranspose, QNN_OP_PACKAGE_NAME_QTI_AISW @@ -28,7 +29,7 @@ def define_node( nodes_to_wrappers: 
Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: input_node = node.args[0] - permute_node = input_node if "qnn_permute" in node.meta else node + permute_node = input_node if QCOM_INSERTED_PERMUTE in node.meta else node input_tensor = self.get_tensor(input_node, permute_node) input_tensor_wrapper = self.define_tensor( input_node, diff --git a/backends/qualcomm/builders/op_upsample_bilinear2d.py b/backends/qualcomm/builders/op_upsample_bilinear2d.py index 4eb4649073..53291786a8 100644 --- a/backends/qualcomm/builders/op_upsample_bilinear2d.py +++ b/backends/qualcomm/builders/op_upsample_bilinear2d.py @@ -8,6 +8,7 @@ import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper import torch +from executorch.backends.qualcomm.utils.constants import QCOM_DATA from .node_visitor import NodeVisitor, register_node_visitor from .qnn_constants import OpResizeBilinear, QNN_OP_PACKAGE_NAME_QTI_AISW @@ -55,12 +56,12 @@ def define_node( reisze_bilinear_op.AddScalarParam( OpResizeBilinear.param_align_corners, PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, - {"data": node.args[2]}, + {QCOM_DATA: node.args[2]}, ) reisze_bilinear_op.AddScalarParam( OpResizeBilinear.param_half_pixel_centers, PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, - {"data": not node.args[2]}, + {QCOM_DATA: not node.args[2]}, ) return reisze_bilinear_op diff --git a/backends/qualcomm/builders/op_upsample_nearest2d.py b/backends/qualcomm/builders/op_upsample_nearest2d.py index d3d83e7698..75e61d77e5 100644 --- a/backends/qualcomm/builders/op_upsample_nearest2d.py +++ b/backends/qualcomm/builders/op_upsample_nearest2d.py @@ -8,6 +8,7 @@ import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper import torch +from executorch.backends.qualcomm.utils.constants import QCOM_DATA from .node_visitor import NodeVisitor, register_node_visitor from .qnn_constants import OpResizeNearestNeighbor, QNN_OP_PACKAGE_NAME_QTI_AISW @@ -55,12 +56,12 @@ def define_node( reisze_nearest_op.AddScalarParam( OpResizeNearestNeighbor.param_align_corners, PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, - {"data": False}, + {QCOM_DATA: False}, ) reisze_nearest_op.AddScalarParam( OpResizeNearestNeighbor.param_half_pixel_centers, PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, - {"data": True}, + {QCOM_DATA: True}, ) return reisze_nearest_op diff --git a/backends/qualcomm/partition/qnn_partitioner.py b/backends/qualcomm/partition/qnn_partitioner.py index 0c5b25284e..c3afc23dae 100644 --- a/backends/qualcomm/partition/qnn_partitioner.py +++ b/backends/qualcomm/partition/qnn_partitioner.py @@ -11,6 +11,7 @@ import torch from executorch.backends.qualcomm.builders import node_visitor from executorch.backends.qualcomm.qnn_preprocess import QnnBackend +from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER from executorch.backends.qualcomm.utils.utils import generate_qnn_executorch_option from executorch.exir.backend.backend_details import CompileSpec @@ -147,7 +148,7 @@ def partition(self, edge_program: torch.export.ExportedProgram) -> PartitionResu if hasattr(node, "meta"): # pop certain keys in meta for not affecting the passes in compilation # TODO: need to put property name in common definitions - node.meta.pop("axis_order", "") + node.meta.pop(QCOM_AXIS_ORDER, "") del self.op_support_checker return PartitionResult( tagged_exported_program=edge_program, partition_tags=self.partition_tags diff --git a/backends/qualcomm/passes/annotate_and_quant_scalar.py 
b/backends/qualcomm/passes/annotate_and_quant_scalar.py index 52ab47a9c2..5f111ee9c8 100644 --- a/backends/qualcomm/passes/annotate_and_quant_scalar.py +++ b/backends/qualcomm/passes/annotate_and_quant_scalar.py @@ -9,6 +9,7 @@ import torch from executorch.backends.qualcomm.builders.utils import get_parameter +from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS from executorch.exir.pass_base import ExportPass, PassResult from executorch.exir.passes import dead_code_elimination_pass from torch.fx.passes.utils.source_matcher_utils import get_source_partitions @@ -40,7 +41,6 @@ class AnnotateAndQuantScalar(ExportPass): "mul", "truediv", ] - quant_attrs_key = "quant_attrs" def __init__(self, edge_program: torch.export.ExportedProgram): super(AnnotateAndQuantScalar, self).__init__() @@ -81,7 +81,7 @@ def _annotate_scalar_node( ]: return - be_annotated_node.meta[self.quant_attrs_key] = quant_attrs + be_annotated_node.meta[QCOM_QUANT_ATTRS] = quant_attrs def _traverse_binary_node(self, graph_module: torch.fx.GraphModule): src_partitions = get_source_partitions( @@ -91,7 +91,7 @@ def _traverse_binary_node(self, graph_module: torch.fx.GraphModule): for src_partition in src_partitions: output = src_partition.output_nodes[0] if ( - output.meta.get(self.quant_attrs_key) + output.meta.get(QCOM_QUANT_ATTRS) and len(src_partition.input_nodes) == 1 ): dq_node = src_partition.input_nodes[0] diff --git a/backends/qualcomm/passes/annotate_decomposed.py b/backends/qualcomm/passes/annotate_decomposed.py index c82fee82a6..a8a757ce9b 100644 --- a/backends/qualcomm/passes/annotate_decomposed.py +++ b/backends/qualcomm/passes/annotate_decomposed.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
import torch +from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS from executorch.exir.pass_base import ExportPass, PassResult from torch.fx.passes.utils.source_matcher_utils import get_source_partitions @@ -28,7 +29,7 @@ def _annotate_unbind(self, graph_module: torch.fx.GraphModule): q_node = src_partition.input_nodes[0].args[0] quant_attrs = get_quant_attrs(self.edge_program, q_node) for n in src_partition.nodes: - n.meta["quant_attrs"] = quant_attrs.copy() + n.meta[QCOM_QUANT_ATTRS] = quant_attrs.copy() def _annotate_stack(self, graph_module: torch.fx.GraphModule): partitions = get_source_partitions(graph_module.graph, [torch.stack]) @@ -40,7 +41,7 @@ def _annotate_stack(self, graph_module: torch.fx.GraphModule): self.edge_program, list(output.users)[0] ) for n in src_partition.nodes: - n.meta["quant_attrs"] = quant_attrs.copy() + n.meta[QCOM_QUANT_ATTRS] = quant_attrs.copy() def call(self, graph_module: torch.fx.GraphModule): self._annotate_unbind(graph_module) diff --git a/backends/qualcomm/passes/annotate_quant_attrs.py b/backends/qualcomm/passes/annotate_quant_attrs.py index 64b1a663d5..199d26b026 100644 --- a/backends/qualcomm/passes/annotate_quant_attrs.py +++ b/backends/qualcomm/passes/annotate_quant_attrs.py @@ -8,6 +8,13 @@ import torch from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter +from executorch.backends.qualcomm.utils.constants import ( + QCOM_ENCODING, + QCOM_QUANT_ATTRS, + QCOM_REQUANTIZE, + QCOM_SCALES, + QCOM_ZERO_POINTS, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -30,11 +37,11 @@ def _annotate_source_nodes( if quant_node.args[0].target == operator.getitem: getitem_node = quant_node.args[0] - getitem_node.meta["quant_attrs"] = quant_attrs + getitem_node.meta[QCOM_QUANT_ATTRS] = quant_attrs source_n = getitem_node.args[0] else: source_n = quant_node.args[0] - source_n.meta["quant_attrs"] = quant_attrs + source_n.meta[QCOM_QUANT_ATTRS] = quant_attrs def _expand(self, tensor, dim, axis) -> torch.Tensor: tensor = tensor[(...,) + (None,) * (dim - 1)] @@ -62,18 +69,18 @@ def _annotate_requant(self, n): # TODO: Store multiple pairs of requantize attributes when we have an op builder # that has multiple outputs that requires quant attributes. if q_attrs["dtype"] != dq_attrs["dtype"]: - dq_attrs["encoding"] = q_attrs["encoding"] - n.args[0].meta["requantize"] = dq_attrs + dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING] + n.args[0].meta[QCOM_REQUANTIZE] = dq_attrs # Dequant all the fold_quant parameters back to fp32. # If an operation is not supported by QNN and got fallback, it will expect a fp32 param. 
def _dequant_fold_params(self, n, quant_attrs, param): - if quant_attrs["encoding"] in [ + if quant_attrs[QCOM_ENCODING] in [ exir_ops.edge.quantized_decomposed.dequantize_per_channel.default ]: dim, axis = param.dim(), quant_attrs["axis"] - scales = self._expand(quant_attrs["scales"], dim, axis) - offsets = self._expand(quant_attrs["zero_points"], dim, axis) + scales = self._expand(quant_attrs[QCOM_SCALES], dim, axis) + offsets = self._expand(quant_attrs[QCOM_ZERO_POINTS], dim, axis) param = param.sub(offsets).mul(scales).to(torch.float32).contiguous() set_parameter(param, n.args[0], self.edge_program) else: diff --git a/backends/qualcomm/passes/build_quant_io.py b/backends/qualcomm/passes/build_quant_io.py index 7a5556fcdd..b627b9b305 100644 --- a/backends/qualcomm/passes/build_quant_io.py +++ b/backends/qualcomm/passes/build_quant_io.py @@ -4,12 +4,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import torch +from executorch.backends.qualcomm.utils.constants import QCOM_QUANTIZED_IO from executorch.exir.pass_base import ExportPass, PassResult from executorch.exir.tensor import TensorSpec -from .utils import q_io_key - class BuildQuantIo(ExportPass): """ @@ -38,12 +37,12 @@ def _build(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: assert len(call_delegate) == 1 spec = [] for n in graph_module.graph.nodes: - if q_io_key in n.meta: - n.meta["val"] = n.meta["val"].to(dtype=n.meta[q_io_key]) + if QCOM_QUANTIZED_IO in n.meta: + n.meta["val"] = n.meta["val"].to(dtype=n.meta[QCOM_QUANTIZED_IO]) if n.op == "call_function" and "getitem" in n.name: fake_tensor = n.meta["val"] - if q_io_key in n.meta: - fake_tensor = fake_tensor.to(dtype=n.meta[q_io_key]) + if QCOM_QUANTIZED_IO in n.meta: + fake_tensor = fake_tensor.to(dtype=n.meta[QCOM_QUANTIZED_IO]) spec.append(self._make_spec(fake_tensor)) call_delegate[0].meta["spec"] = tuple(spec) diff --git a/backends/qualcomm/passes/convert_to_linear.py b/backends/qualcomm/passes/convert_to_linear.py index 890ac697ef..8de89f8f40 100644 --- a/backends/qualcomm/passes/convert_to_linear.py +++ b/backends/qualcomm/passes/convert_to_linear.py @@ -7,6 +7,7 @@ from typing import Callable, List import torch +from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS from executorch.backends.transforms.addmm_mm_to_linear import ( apply_addmm_mm_to_linear_transform, ) @@ -86,7 +87,7 @@ def _convert_to_linear( # qnn htp does not support keepdim, the view_copy(reshape) should exist for now if self._get_original_input(inputs, input_node).target in dq_ops: - input_node.meta["quant_attrs"] = get_quant_attrs( + input_node.meta[QCOM_QUANT_ATTRS] = get_quant_attrs( gm, self._get_original_input(inputs, input_node).args[0] ) args = [input_node, weight_node] @@ -100,7 +101,7 @@ def _convert_to_linear( ) linear_node.meta = fn_node.meta if list(output.users)[0].target in q_ops: - linear_node.meta["quant_attrs"] = get_quant_attrs( + linear_node.meta[QCOM_QUANT_ATTRS] = get_quant_attrs( gm, list(output.users)[0] ) for user in fn_node.users.copy(): @@ -144,12 +145,12 @@ def _convert_to_linear( unsqueeze_view_copy_node.meta = output.args[0].meta for user in output_users: user.replace_input_with(linear_node, unsqueeze_view_copy_node) - if "quant_attrs" in linear_node.meta: - squeeze_view_copy_node.meta["quant_attrs"] = linear_node.meta[ - "quant_attrs" + if QCOM_QUANT_ATTRS in linear_node.meta: + squeeze_view_copy_node.meta[QCOM_QUANT_ATTRS] = linear_node.meta[ + 
QCOM_QUANT_ATTRS ] - unsqueeze_view_copy_node.meta["quant_attrs"] = linear_node.meta[ - "quant_attrs" + unsqueeze_view_copy_node.meta[QCOM_QUANT_ATTRS] = linear_node.meta[ + QCOM_QUANT_ATTRS ] def _extract_mm_ops(self, partitioned_nodes: List[edge_op]) -> List[torch.fx.Node]: diff --git a/backends/qualcomm/passes/fuse_consecutive_transpose.py b/backends/qualcomm/passes/fuse_consecutive_transpose.py index b2351fe9e8..c81818e00e 100644 --- a/backends/qualcomm/passes/fuse_consecutive_transpose.py +++ b/backends/qualcomm/passes/fuse_consecutive_transpose.py @@ -6,7 +6,7 @@ import torch -from executorch.backends.qualcomm.passes.layout_transform import LayoutTransform +from executorch.backends.qualcomm.utils.constants import QCOM_INSERTED_PERMUTE from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -67,12 +67,9 @@ def _fuse(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: # copy metadata permute_node.meta = output_node.meta - # Without inserted_permute_tag, we might obtain wrong input shape - if [ - pn.meta.get(LayoutTransform.inserted_permute_tag) - for pn in self.nodes - ]: - permute_node.meta[LayoutTransform.inserted_permute_tag] = True + # Without "qnn_permute", we might obtain wrong input shape + if [pn.meta.get(QCOM_INSERTED_PERMUTE) for pn in self.nodes]: + permute_node.meta[QCOM_INSERTED_PERMUTE] = True # clear current stack self.nodes = [] diff --git a/backends/qualcomm/passes/insert_io_qdq.py b/backends/qualcomm/passes/insert_io_qdq.py index 0bec89088d..668f76cd69 100644 --- a/backends/qualcomm/passes/insert_io_qdq.py +++ b/backends/qualcomm/passes/insert_io_qdq.py @@ -8,10 +8,15 @@ import torch from executorch.backends.qualcomm.builders.utils import is_parameter +from executorch.backends.qualcomm.utils.constants import ( + QCOM_ENCODING, + QCOM_QUANT_ATTRS, + QCOM_QUANTIZED_IO, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult -from .utils import q_io_key, q_ops +from .utils import q_ops class InsertIOQDQ(ExportPass): @@ -64,7 +69,7 @@ def _create_node( # check if there has a specified quant_attrs # if not, use the existent info. 
from current node if quant_attrs is None: - quant_attrs = node.meta.get("quant_attrs") + quant_attrs = node.meta.get(QCOM_QUANT_ATTRS) inserted_node = graph_module.graph.create_node( "call_function", @@ -73,7 +78,7 @@ def _create_node( ) meta_val = node.meta["val"] if target in self.q_dq_map: - inserted_node.meta["quant_attrs"] = node.meta.pop("quant_attrs") + inserted_node.meta[QCOM_QUANT_ATTRS] = node.meta.pop(QCOM_QUANT_ATTRS) meta_val = meta_val.to(quant_attrs["dtype"]) inserted_node.meta["val"] = meta_val @@ -112,26 +117,28 @@ def _insert_dequant_node( def _insert(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: for n in graph_module.graph.nodes: # do nothing when a node is expected to output a quant tensor - if n.meta.get(q_io_key): + if n.meta.get(QCOM_QUANTIZED_IO): continue # insert q after input or fold mix_quantization dq if applicable if ( n.op == "placeholder" - and n.meta.get("quant_attrs") + and n.meta.get(QCOM_QUANT_ATTRS) and not is_parameter(n, self.edge_program) ): self._insert_quant_node( - graph_module, n, n.meta["quant_attrs"]["encoding"] + graph_module, n, n.meta[QCOM_QUANT_ATTRS][QCOM_ENCODING] ) # insert dq before output or fold mix_quantization q if applicable users = list(n.users.keys()) - if n.meta.get("quant_attrs") and any(user.op == "output" for user in users): + if n.meta.get(QCOM_QUANT_ATTRS) and any( + user.op == "output" for user in users + ): self._insert_dequant_node( graph_module, n, - self.q_dq_map[n.meta["quant_attrs"]["encoding"]], + self.q_dq_map[n.meta[QCOM_QUANT_ATTRS][QCOM_ENCODING]], ) def call(self, graph_module: torch.fx.GraphModule): diff --git a/backends/qualcomm/passes/insert_requantize.py b/backends/qualcomm/passes/insert_requantize.py index 4e79a4bda6..417d3b85b0 100644 --- a/backends/qualcomm/passes/insert_requantize.py +++ b/backends/qualcomm/passes/insert_requantize.py @@ -6,11 +6,15 @@ import torch +from executorch.backends.qualcomm.utils.constants import ( + QCOM_QUANT_ATTRS, + QCOM_QUANTIZED_IO, + QCOM_REQUANTIZE, +) + from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult -from .utils import q_io_key - class InsertRequantize(ExportPass): """ @@ -50,16 +54,16 @@ def _single_output_annotation( ) inserted_n.meta["val"] = n.meta["val"] - inserted_n.meta["quant_attrs"] = n.meta.pop("requantize") - if n.meta.get(q_io_key): - inserted_n.meta[q_io_key] = n.meta[q_io_key] + inserted_n.meta[QCOM_QUANT_ATTRS] = n.meta.pop(QCOM_REQUANTIZE) + if n.meta.get(QCOM_QUANTIZED_IO): + inserted_n.meta[QCOM_QUANTIZED_IO] = n.meta[QCOM_QUANTIZED_IO] for user in users: user.replace_input_with(n, inserted_n) def _insert(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: for n in graph_module.graph.nodes: - if "requantize" in n.meta: + if QCOM_REQUANTIZE in n.meta: ( self._single_output_annotation(graph_module, n) if isinstance( diff --git a/backends/qualcomm/passes/layout_transform.py b/backends/qualcomm/passes/layout_transform.py index f7c9142079..bdee2c8196 100644 --- a/backends/qualcomm/passes/layout_transform.py +++ b/backends/qualcomm/passes/layout_transform.py @@ -9,6 +9,12 @@ import torch from executorch.backends.qualcomm.builders.utils import is_parameter +from executorch.backends.qualcomm.utils.constants import ( + QCOM_AXIS_ORDER, + QCOM_INSERTED_PERMUTE, + QCOM_QUANT_ATTRS, + QCOM_REQUANTIZE, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult from executorch.exir.sym_util import 
eval_shape @@ -66,9 +72,6 @@ class LayoutTransform(ExportPass): _operator.getitem, } - layout_transformed_tag = "axis_order" - inserted_permute_tag = "qnn_permute" - layout_type = { 1: ("N", "N"), 2: ("NC", "NC"), @@ -101,18 +104,18 @@ def mark_as_transformed(self, node: torch.fx.Node) -> None: f"got {getitem_node.target.__name__}" ) index = getitem_node.args[1] - node.meta[self.layout_transformed_tag] = self.get_axis_order( + node.meta[QCOM_AXIS_ORDER] = self.get_axis_order( eval_shape(node.meta["val"][index].shape) ) else: - node.meta[self.layout_transformed_tag] = self.get_axis_order( + node.meta[QCOM_AXIS_ORDER] = self.get_axis_order( eval_shape(node.meta["val"].shape) ) def is_transformed_node(self, node: torch.fx.Node) -> bool: if not hasattr(node, "meta"): return False - return self.layout_transformed_tag in node.meta + return QCOM_AXIS_ORDER in node.meta def is_layout_sensitive(self, node: torch.fx.Node) -> bool: return node.target in self.layout_sensitive_ops @@ -126,7 +129,7 @@ def is_layout_agnostic(self, node: torch.fx.Node) -> bool: if len(node.args) < 3 or not node.args[2]: return False if node.target in self.qdq_opset: - return "requantize" in node.meta + return QCOM_REQUANTIZE in node.meta return node.target in self.layout_agnostic_ops def is_edge_condition(self, node): @@ -139,7 +142,7 @@ def is_edge_condition(self, node): node.op == "get_attr", ( node.target == exir_ops.edge.aten.permute_copy.default - and node.meta.get(self.inserted_permute_tag, False) + and node.meta.get(QCOM_INSERTED_PERMUTE, False) ), ( node.op != "output" @@ -178,9 +181,9 @@ def insert_node(self, graph_module, node, revert_layout: bool) -> None: ), ) permute.meta["val"] = tensor - permute.meta["quant_attrs"] = node.meta.get("quant_attrs") + permute.meta[QCOM_QUANT_ATTRS] = node.meta.get(QCOM_QUANT_ATTRS) # we need this to check the annotation boundary - permute.meta[self.inserted_permute_tag] = True + permute.meta[QCOM_INSERTED_PERMUTE] = True for user in users: user.replace_input_with(node, permute) diff --git a/backends/qualcomm/passes/recompose_pixel_unshuffle.py b/backends/qualcomm/passes/recompose_pixel_unshuffle.py index 57ef9bd077..cadc310bbb 100644 --- a/backends/qualcomm/passes/recompose_pixel_unshuffle.py +++ b/backends/qualcomm/passes/recompose_pixel_unshuffle.py @@ -93,7 +93,9 @@ def call(self, graph_module: torch.fx.GraphModule): op = self.op pixel_unshuffle_node = graph.create_node( - "call_function", op, (input_node, int(downscale_factor)) + "call_function", + op, + (input_node, int(downscale_factor)), ) users = output_node.users.copy() for user in users: diff --git a/backends/qualcomm/passes/utils.py b/backends/qualcomm/passes/utils.py index c97f2b8f53..ac6525ae76 100755 --- a/backends/qualcomm/passes/utils.py +++ b/backends/qualcomm/passes/utils.py @@ -6,12 +6,10 @@ import torch from executorch.backends.qualcomm.builders.utils import get_parameter +from executorch.backends.qualcomm.utils.constants import QCOM_ENCODING from executorch.exir.dialects._ops import ops as exir_ops -# TODO, Move all Qualcomm specific keys to here, like "quant_attrs" -q_io_key = "q_tensor_io" - q_ops = { exir_ops.edge.quantized_decomposed.quantize_per_channel.default, exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, @@ -43,5 +41,5 @@ def get_quant_attrs( value = get_parameter(attr_n, edge_program) quant_attrs[quant_attr_keys[i - 1]] = value - quant_attrs["encoding"] = quant_node.target + quant_attrs[QCOM_ENCODING] = quant_node.target return quant_attrs diff --git 
a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 35b4ff03d0..508a027da6 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -19,6 +19,12 @@ TestQNN, to_backend, ) +from executorch.backends.qualcomm.utils.constants import ( + QCOM_ANNOTATION, + QCOM_MODULE, + QCOM_QUANT_DTYPE, + QCOM_SAMPLE_INPUTS, +) from executorch.backends.qualcomm.utils.utils import ( canonicalize_program, @@ -114,22 +120,22 @@ def test_qnn_backend_conv2d(self): def test_qnn_backend_element_wise_add(self): test_comb = [ { - "module": [Add()], # noqa: F405 - "sample_inputs": [ + QCOM_MODULE: [Add()], # noqa: F405 + QCOM_SAMPLE_INPUTS: [ (torch.randn(2, 5, 1, 3), torch.randn(2, 5, 1, 3)), (torch.randn([2, 5, 1, 3]), torch.randn([4, 1])), ], }, { - "module": [AddConstantFloat()], # noqa: F405 - "sample_inputs": [(torch.randn(2, 5, 1, 3),)], + QCOM_MODULE: [AddConstantFloat()], # noqa: F405 + QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)], }, ] index = 0 for comb in test_comb: - for module in comb["module"]: - for sample_input in comb["sample_inputs"]: + for module in comb[QCOM_MODULE]: + for sample_input in comb[QCOM_SAMPLE_INPUTS]: with self.subTest(i=index): self.lower_module_and_test_output(module, sample_input) index += 1 @@ -143,22 +149,22 @@ def test_qnn_backend_element_wise_div(self): eps = 1e-03 test_comb = [ { - "module": [Div()], # noqa: F405 - "sample_inputs": [ + QCOM_MODULE: [Div()], # noqa: F405 + QCOM_SAMPLE_INPUTS: [ (torch.randn(2, 5, 1, 3), eps + torch.randn(2, 5, 1, 3)), (torch.randn([2, 5, 1, 3]), eps + torch.randn([4, 1])), ], }, { - "module": [DivConstantFloat()], # noqa: F405 - "sample_inputs": [(torch.randn(2, 5, 1, 3),)], + QCOM_MODULE: [DivConstantFloat()], # noqa: F405 + QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)], }, ] index = 0 for comb in test_comb: - for module in comb["module"]: - for sample_input in comb["sample_inputs"]: + for module in comb[QCOM_MODULE]: + for sample_input in comb[QCOM_SAMPLE_INPUTS]: with self.subTest(i=index): self.lower_module_and_test_output(module, sample_input) index += 1 @@ -166,26 +172,26 @@ def test_qnn_backend_element_wise_div(self): def test_qnn_backend_element_wise_mul(self): test_comb = [ { - "module": [Mul()], # noqa: F405 - "sample_inputs": [ + QCOM_MODULE: [Mul()], # noqa: F405 + QCOM_SAMPLE_INPUTS: [ (torch.randn(2, 5, 1, 3), torch.randn(2, 5, 1, 3)), (torch.randn([2, 5, 1, 3]), torch.randn([4, 1])), ], }, { - "module": [MulConstantFloat()], # noqa: F405 - "sample_inputs": [(torch.randn(2, 5, 1, 3),)], + QCOM_MODULE: [MulConstantFloat()], # noqa: F405 + QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)], }, { - "module": [MulScalar()], # noqa: F405 - "sample_inputs": [(torch.randn(2, 5, 1, 3),)], + QCOM_MODULE: [MulScalar()], # noqa: F405 + QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)], }, ] index = 0 for comb in test_comb: - for module in comb["module"]: - for sample_input in comb["sample_inputs"]: + for module in comb[QCOM_MODULE]: + for sample_input in comb[QCOM_SAMPLE_INPUTS]: with self.subTest(i=index): self.lower_module_and_test_output(module, sample_input) index += 1 @@ -200,22 +206,22 @@ def test_qnn_backend_element_wise_sqrt(self): def test_qnn_backend_element_wise_sub(self): test_comb = [ { - "module": [Sub()], # noqa: F405 - "sample_inputs": [ + QCOM_MODULE: [Sub()], # noqa: F405 + QCOM_SAMPLE_INPUTS: [ (torch.randn(2, 5, 1, 3), torch.randn(2, 5, 1, 3)), (torch.randn([2, 5, 1, 3]), torch.randn([4, 1])), ], }, { - "module": 
[SubConstantFloat()],  # noqa: F405
-                "sample_inputs": [(torch.randn(2, 5, 1, 3),)],
+                QCOM_MODULE: [SubConstantFloat()],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)],
             },
         ]
         index = 0
         for comb in test_comb:
-            for module in comb["module"]:
-                for sample_input in comb["sample_inputs"]:
+            for module in comb[QCOM_MODULE]:
+                for sample_input in comb[QCOM_SAMPLE_INPUTS]:
                     with self.subTest(i=index):
                         self.lower_module_and_test_output(module, sample_input)
                         index += 1
@@ -269,19 +275,19 @@ def test_qnn_backend_layer_norm(self):

     def test_qnn_backend_leaky_relu(self):
         test_comb = [
             {
-                "module": [LeakyReLUDefault()],  # noqa: F405
-                "sample_inputs": [(torch.randn(2, 5, 1, 3),)],
+                QCOM_MODULE: [LeakyReLUDefault()],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)],
             },
             {
-                "module": [LeakyReLUCustom(0.05)],  # noqa: F405
-                "sample_inputs": [(torch.randn(2, 5, 1, 3),)],
+                QCOM_MODULE: [LeakyReLUCustom(0.05)],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)],
             },
         ]
         index = 0
         for comb in test_comb:
-            for module in comb["module"]:
-                for sample_input in comb["sample_inputs"]:
+            for module in comb[QCOM_MODULE]:
+                for sample_input in comb[QCOM_SAMPLE_INPUTS]:
                     with self.subTest(i=index):
                         self.lower_module_and_test_output(module, sample_input)
                         index += 1
@@ -340,19 +346,19 @@ def test_qnn_backend_pow_tensor_scalar(self):

     def test_qnn_backend_prelu(self):
         test_comb = [
             {
-                "module": [PReLUDefault()],  # noqa: F405
-                "sample_inputs": [(torch.randn(2, 5, 1, 3),)],
+                QCOM_MODULE: [PReLUDefault()],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)],
             },
             {
-                "module": [PReLUPerChannel(5)],  # noqa: F405
-                "sample_inputs": [(torch.randn(2, 5, 1, 3),)],
+                QCOM_MODULE: [PReLUPerChannel(5)],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)],
             },
         ]
         index = 0
         for comb in test_comb:
-            for module in comb["module"]:
-                for sample_input in comb["sample_inputs"]:
+            for module in comb[QCOM_MODULE]:
+                for sample_input in comb[QCOM_SAMPLE_INPUTS]:
                     with self.subTest(i=index):
                         self.lower_module_and_test_output(module, sample_input)
                         index += 1
@@ -673,22 +679,22 @@ def test_qnn_backend_conv2d(self):

     def test_qnn_backend_element_wise_add(self):
         test_comb = [
             {
-                "module": [Add()],  # noqa: F405
-                "sample_inputs": [
+                QCOM_MODULE: [Add()],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [
                     (torch.randn(2, 5, 1, 3), torch.randn(2, 5, 1, 3)),
                     (torch.randn([2, 5, 1, 3]), torch.randn([4, 1])),
                 ],
             },
             {
-                "module": [AddConstantFloat(), AddConstantLong()],  # noqa: F405
-                "sample_inputs": [(torch.randn(2, 5, 1, 3),)],
+                QCOM_MODULE: [AddConstantFloat(), AddConstantLong()],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)],
             },
         ]
         index = 0
         for comb in test_comb:
-            for module in comb["module"]:
-                for sample_input in comb["sample_inputs"]:
+            for module in comb[QCOM_MODULE]:
+                for sample_input in comb[QCOM_SAMPLE_INPUTS]:
                     with self.subTest(i=index):
                         module = self.get_qdq_module(module, sample_input)
                         self.lower_module_and_test_output(module, sample_input)
@@ -704,22 +710,22 @@ def test_qnn_backend_element_wise_div(self):
         eps = 1e-03
         test_comb = [
             {
-                "module": [Div()],  # noqa: F405
-                "sample_inputs": [
+                QCOM_MODULE: [Div()],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [
                     (torch.randn(2, 5, 1, 3), eps + torch.randn(2, 5, 1, 3)),
                     (torch.randn([2, 5, 1, 3]), eps + torch.randn([4, 1])),
                 ],
             },
             {
-                "module": [DivConstantFloat(), DivConstantLong()],  # noqa: F405
-                "sample_inputs": [(torch.randn(2, 5, 1, 3),)],
+                QCOM_MODULE: [DivConstantFloat(), DivConstantLong()],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)],
             },
         ]
         index = 0
         for comb in test_comb:
-            for module in comb["module"]:
-                for sample_input in comb["sample_inputs"]:
+            for module in comb[QCOM_MODULE]:
+                for sample_input in comb[QCOM_SAMPLE_INPUTS]:
                     with self.subTest(i=index):
                         module = self.get_qdq_module(module, sample_input)
                         self.lower_module_and_test_output(module, sample_input)
@@ -728,26 +734,26 @@ def test_qnn_backend_element_wise_div(self):

     def test_qnn_backend_element_wise_mul(self):
         test_comb = [
             {
-                "module": [Mul()],  # noqa: F405
-                "sample_inputs": [
+                QCOM_MODULE: [Mul()],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [
                     (torch.randn(2, 5, 1, 3), torch.randn(2, 5, 1, 3)),
                     (torch.randn([2, 5, 1, 3]), torch.randn([4, 1])),
                 ],
             },
             {
-                "module": [MulConstantFloat(), MulConstantLong()],  # noqa: F405
-                "sample_inputs": [(torch.randn(2, 5, 1, 3),)],
+                QCOM_MODULE: [MulConstantFloat(), MulConstantLong()],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)],
             },
             {
-                "module": [MulScalar()],  # noqa: F405
-                "sample_inputs": [(torch.randn(2, 5, 1, 3),)],
+                QCOM_MODULE: [MulScalar()],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)],
             },
         ]
         index = 0
         for comb in test_comb:
-            for module in comb["module"]:
-                for sample_input in comb["sample_inputs"]:
+            for module in comb[QCOM_MODULE]:
+                for sample_input in comb[QCOM_SAMPLE_INPUTS]:
                     with self.subTest(i=index):
                         module = self.get_qdq_module(module, sample_input)
                         self.lower_module_and_test_output(module, sample_input)
@@ -764,22 +770,22 @@ def test_qnn_backend_element_wise_sqrt(self):

     def test_qnn_backend_element_wise_sub(self):
         test_comb = [
             {
-                "module": [Sub()],  # noqa: F405
-                "sample_inputs": [
+                QCOM_MODULE: [Sub()],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [
                     (torch.randn(2, 5, 1, 3), torch.randn(2, 5, 1, 3)),
                     (torch.randn([2, 5, 1, 3]), torch.randn([4, 1])),
                 ],
             },
             {
-                "module": [SubConstantFloat(), SubConstantLong()],  # noqa: F405
-                "sample_inputs": [(torch.randn(2, 5, 1, 3),)],
+                QCOM_MODULE: [SubConstantFloat(), SubConstantLong()],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)],
             },
         ]
         index = 0
         for comb in test_comb:
-            for module in comb["module"]:
-                for sample_input in comb["sample_inputs"]:
+            for module in comb[QCOM_MODULE]:
+                for sample_input in comb[QCOM_SAMPLE_INPUTS]:
                     with self.subTest(i=index):
                         module = self.get_qdq_module(module, sample_input)
                         self.lower_module_and_test_output(module, sample_input)
@@ -842,19 +848,19 @@ def test_qnn_backend_layer_norm(self):

     def test_qnn_backend_leaky_relu(self):
         test_comb = [
             {
-                "module": [LeakyReLUDefault()],  # noqa: F405
-                "sample_inputs": [(torch.randn(2, 5, 1, 3),)],
+                QCOM_MODULE: [LeakyReLUDefault()],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)],
             },
             {
-                "module": [LeakyReLUCustom(0.05)],  # noqa: F405
-                "sample_inputs": [(torch.randn(2, 5, 1, 3),)],
+                QCOM_MODULE: [LeakyReLUCustom(0.05)],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)],
             },
         ]
         index = 0
         for comb in test_comb:
-            for module in comb["module"]:
-                for sample_input in comb["sample_inputs"]:
+            for module in comb[QCOM_MODULE]:
+                for sample_input in comb[QCOM_SAMPLE_INPUTS]:
                     with self.subTest(i=index):
                         module = self.get_qdq_module(module, sample_input)
                         self.lower_module_and_test_output(module, sample_input)
@@ -919,19 +925,19 @@ def test_qnn_backend_pow_tensor_scalar(self):

     def test_qnn_backend_prelu(self):
         test_comb = [
             {
-                "module": [PReLUDefault()],  # noqa: F405
-                "sample_inputs": [(torch.randn(2, 5, 1, 3),)],
+                QCOM_MODULE: [PReLUDefault()],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)],
             },
             {
-                "module": [PReLUPerChannel(5)],  # noqa: F405
-                "sample_inputs": [(torch.randn(2, 5, 1, 3),)],
+                QCOM_MODULE: [PReLUPerChannel(5)],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)],
             },
         ]
         index = 0
         for comb in test_comb:
-            for module in comb["module"]:
-                for sample_input in comb["sample_inputs"]:
+            for module in comb[QCOM_MODULE]:
+                for sample_input in comb[QCOM_SAMPLE_INPUTS]:
                     with self.subTest(i=index):
                         module = self.get_qdq_module(module, sample_input)
                         self.lower_module_and_test_output(module, sample_input)
@@ -1128,48 +1134,48 @@ def test_qnn_backend_view_permute_matmul(self):
     def test_qnn_backend_example_models(self):
         instances = [
             {
-                "module": DeepLabV3ResNet101Model(),
-                "annotation": (),
-                "quant_dtype": QuantDtype.use_8a8w,
+                QCOM_MODULE: DeepLabV3ResNet101Model(),
+                QCOM_ANNOTATION: (),
+                QCOM_QUANT_DTYPE: QuantDtype.use_8a8w,
             },
             {
-                "module": EdsrModel(),
-                "annotation": (),
-                "quant_dtype": QuantDtype.use_8a8w,
+                QCOM_MODULE: EdsrModel(),
+                QCOM_ANNOTATION: (),
+                QCOM_QUANT_DTYPE: QuantDtype.use_8a8w,
             },
             {
-                "module": InceptionV3Model(),
-                "annotation": (),
-                "quant_dtype": QuantDtype.use_8a8w,
+                QCOM_MODULE: InceptionV3Model(),
+                QCOM_ANNOTATION: (),
+                QCOM_QUANT_DTYPE: QuantDtype.use_8a8w,
             },
             {
-                "module": InceptionV4Model(),
-                "annotation": (),
-                "quant_dtype": QuantDtype.use_8a8w,
+                QCOM_MODULE: InceptionV4Model(),
+                QCOM_ANNOTATION: (),
+                QCOM_QUANT_DTYPE: QuantDtype.use_8a8w,
             },
             # The module of llama is changing frequently. Reopen it when it's stable
-            # {"module": Llama2Model(), "annotation": (), "quant_dtype": QuantDtype.use_8a8w},
+            # {QCOM_MODULE: Llama2Model(), QCOM_ANNOTATION: (), QCOM_QUANT_DTYPE: QuantDtype.use_8a8w},
             {
-                "module": MV2Model(),
-                "annotation": (),
-                "quant_dtype": QuantDtype.use_8a8w,
+                QCOM_MODULE: MV2Model(),
+                QCOM_ANNOTATION: (),
+                QCOM_QUANT_DTYPE: QuantDtype.use_8a8w,
             },
             {
-                "module": MV3Model(),
-                "annotation": (),
-                "quant_dtype": QuantDtype.use_8a8w,
+                QCOM_MODULE: MV3Model(),
+                QCOM_ANNOTATION: (),
+                QCOM_QUANT_DTYPE: QuantDtype.use_8a8w,
            },
             # only works on QNN 2.12 so far
-            # { 'module': MobileBertModelExample(), 'annotation': (), "quant_dtype": QuantDtype.use_8a8w },
+            # { 'module': MobileBertModelExample(), 'annotation': (), QCOM_QUANT_DTYPE: QuantDtype.use_8a8w },
             {
-                "module": TorchVisionViTModel(),
-                "annotation": (),
-                "quant_dtype": QuantDtype.use_8a8w,
+                QCOM_MODULE: TorchVisionViTModel(),
+                QCOM_ANNOTATION: (),
+                QCOM_QUANT_DTYPE: QuantDtype.use_8a8w,
             },
             {
-                "module": Wav2LetterModel(),
-                "annotation": (),
-                "quant_dtype": QuantDtype.use_8a8w,
+                QCOM_MODULE: Wav2LetterModel(),
+                QCOM_ANNOTATION: (),
+                QCOM_QUANT_DTYPE: QuantDtype.use_8a8w,
             },
         ]
         expected_partitions = [
@@ -1189,13 +1195,13 @@ def test_qnn_backend_example_models(self):
         disable_validation()
         for i, instance in enumerate(instances):
             with self.subTest(i=i):
-                module = instance["module"].get_eager_model().eval()
-                sample_input = instance["module"].get_example_inputs()
+                module = instance[QCOM_MODULE].get_eager_model().eval()
+                sample_input = instance[QCOM_MODULE].get_example_inputs()
                 module = self.get_qdq_module(
                     module,
                     sample_input,
-                    custom_quant_annotations=instance["annotation"],
-                    quant_dtype=instance["quant_dtype"],
+                    custom_quant_annotations=instance[QCOM_ANNOTATION],
+                    quant_dtype=instance[QCOM_QUANT_DTYPE],
                 )
                 self.lower_module_and_test_output(
                     module,
diff --git a/backends/qualcomm/utils/constants.py b/backends/qualcomm/utils/constants.py
new file mode 100644
index 0000000000..58538eb91e
--- /dev/null
+++ b/backends/qualcomm/utils/constants.py
@@ -0,0 +1,29 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Qualcomm specific key
+
+# constants in backends/qualcomm/passes & backends/qualcomm/builders
+QCOM_AXIS_ORDER = "axis_order"
+QCOM_BITWIDTH = "bitwidth"
+QCOM_DATA = "data"
+QCOM_ENCODING = "encoding"
+QCOM_INSERTED_PERMUTE = "qnn_permute"
+QCOM_QUANTIZED_IO = "q_tensor_io"
+QCOM_QUANT_ATTRS = "quant_attrs"
+QCOM_REQUANTIZE = "requantize"
+QCOM_SCALES = "scales"
+QCOM_SCALE_OFFSET = "scale_offset"
+QCOM_ZERO_POINTS = "zero_points"
+
+# constants in backends/qualcomm/tests
+QCOM_ANNOTATION = "annotation"
+QCOM_MODULE = "module"
+QCOM_QUANT_DTYPE = "quant_dtype"
+QCOM_SAMPLE_INPUTS = "sample_inputs"
+
+# constants in backends/qualcomm/utils
+QCOM_QNN_COMPILE_SPEC = "qnn_compile_spec"
diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py
index cf656b106f..85b965a146 100644
--- a/backends/qualcomm/utils/utils.py
+++ b/backends/qualcomm/utils/utils.py
@@ -55,6 +55,7 @@
     convert_to_flatbuffer,
     convert_to_option,
 )
+from executorch.backends.qualcomm.utils.constants import QCOM_QNN_COMPILE_SPEC
 from executorch.exir import ExirExportedProgram
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.lowered_backend_module import LoweredBackendModule
@@ -65,9 +66,6 @@
 from torch.library import Library


-QNN_COMPILE_SPEC = "qnn_compile_spec"
-
-
 def qnn_capture_config():
     return exir.CaptureConfig(enable_aot=True)

@@ -154,7 +152,7 @@ def process_lowered_module(module):

     def update_program(max_sf_buf_size, module_map):
         def set_spec(module, options):
-            spec = CompileSpec(QNN_COMPILE_SPEC, convert_to_flatbuffer(options))
+            spec = CompileSpec(QCOM_QNN_COMPILE_SPEC, convert_to_flatbuffer(options))
             if isinstance(module, ExportedProgram):
                 module.compile_specs[0] = spec
             else:
@@ -321,7 +319,7 @@ def generate_qnn_executorch_option(
     compiler_specs: List[CompileSpec],
 ) -> bytes:
     for compiler_spec in compiler_specs:
-        if compiler_spec.key == QNN_COMPILE_SPEC:
+        if compiler_spec.key == QCOM_QNN_COMPILE_SPEC:
             qnn_compile_spec_buffer = compiler_spec.value
         else:
             raise ValueError(f"unknown compiler spec key value: {compiler_spec.key}")
@@ -454,5 +452,7 @@ def generate_qnn_executorch_compiler_spec(
     qnn_executorch_options.is_from_context_binary = is_from_context_binary

     return [
-        CompileSpec(QNN_COMPILE_SPEC, convert_to_flatbuffer(qnn_executorch_options))
+        CompileSpec(
+            QCOM_QNN_COMPILE_SPEC, convert_to_flatbuffer(qnn_executorch_options)
+        )
     ]
diff --git a/examples/qualcomm/llama2/llama.py b/examples/qualcomm/llama2/llama.py
index c5214ea272..79cf5606d6 100644
--- a/examples/qualcomm/llama2/llama.py
+++ b/examples/qualcomm/llama2/llama.py
@@ -15,13 +15,13 @@

 from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner
 from executorch.backends.qualcomm.passes.build_quant_io import BuildQuantIo
-from executorch.backends.qualcomm.passes.utils import q_io_key
 from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype
 from executorch.backends.qualcomm.quantizer.utils import get_16a4w_qnn_ptq_config
 from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import (
     QcomChipset,
 )
+from executorch.backends.qualcomm.utils.constants import QCOM_QUANTIZED_IO
 from executorch.backends.qualcomm.utils.utils import (
     capture_program,
     convert_linear_to_conv2d,
@@ -260,14 +260,14 @@ def _tag_kv_ios(self, gm: torch.fx.GraphModule, kv_type):
                 and len(users := list(n.users)) == 1
                 and users[0].meta["val"].size()[-2:] in input_cache_shape
             ):
-                n.meta[q_io_key] = kv_type
+                n.meta[QCOM_QUANTIZED_IO] = kv_type
             elif n.op == "output":
                 for a in n.args[0]:
                     if (
                         a.meta["val"].flatten().size()[0]
                         == self.llama_meta["get_head_dim"]
                     ):
-                        a.meta[q_io_key] = kv_type
+                        a.meta[QCOM_QUANTIZED_IO] = kv_type

     def quantize(self, quant_dtype, custom_annotations=()):
         self.quant_dtype = quant_dtype
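
Taken as a whole, the change leaves a single authoritative definition for every `node.meta` key shared across the Qualcomm passes, builders, tests, and examples, so a misspelled key now fails loudly at import time instead of silently creating a new metadata entry. A minimal sketch of the resulting pattern, for illustration only: the `tag_quantized_io` and `has_quant_attrs` helpers below are hypothetical and not part of the diff above.

    # Illustration only: hypothetical helpers built on the new constants module.
    import torch
    from executorch.backends.qualcomm.utils.constants import (
        QCOM_QUANT_ATTRS,
        QCOM_QUANTIZED_IO,
    )

    def tag_quantized_io(node: torch.fx.Node, kv_type) -> None:
        # Writer and reader agree on the key through the shared constant;
        # a typo in the constant name raises ImportError/NameError up front,
        # whereas a typo in a string literal would just create a stray key.
        node.meta[QCOM_QUANTIZED_IO] = kv_type

    def has_quant_attrs(node: torch.fx.Node) -> bool:
        # Presence check against the same shared key used by the builders.
        return QCOM_QUANT_ATTRS in node.meta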