diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index e1c903302c..b4bb809b85 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -31,6 +31,7 @@ from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( FoldAndAnnotateQParamsPass, + QuantizeFullArgument, ) from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import ( KeepDimsFalseToSqueezePass, @@ -84,6 +85,7 @@ def transform_to_backend_pipeline( self.add_pass(Conv1dUnsqueezePass(exported_program)) self.add_pass(DecomposeSoftmaxesPass()) self.add_pass(DecomposeLinearPass()) + self.add_pass(QuantizeFullArgument()) self.add_pass( FoldAndAnnotateQParamsPass( [ @@ -92,6 +94,7 @@ def transform_to_backend_pipeline( exir_ops.edge.aten.add.Tensor, exir_ops.edge.aten.avg_pool2d.default, exir_ops.edge.aten.convolution.default, + exir_ops.edge.aten.full.default, ] ) ) diff --git a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py index 24d1a03395..6ba72eb102 100644 --- a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py +++ b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py @@ -15,6 +15,9 @@ from executorch.exir.pass_base import ExportPass, PassResult from torch.fx import GraphModule, Node +q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default +dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default + def get_input_qparams(node: Node) -> dict[int, QuantArgs]: """ @@ -77,8 +80,6 @@ def __init__(self, targeted_ops: Iterable[Callable]): self.targeted_ops = targeted_ops def call(self, graph_module: GraphModule) -> PassResult: - q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default - dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default # Loop over the graph nodes and find any node in the 'targeted_ops' list. for n in graph_module.graph.nodes: @@ -145,3 +146,36 @@ def call(self, graph_module: GraphModule) -> PassResult: graph_module.recompile() return PassResult(graph_module, True) + + +class QuantizeFullArgument(ExportPass): + """ + Make sure the fill_value for full.default is quantized. This pass needs to be run before + the folding pass above to make sure that the retraced output of the full.default op is + the right dtype. + """ + + def call(self, graph_module: GraphModule) -> PassResult: + modified = False + # Loop over the graph nodes and find any node in the 'targeted_ops' list. + for n in graph_module.graph.nodes: + n = cast(Node, n) + if n.target != exir_ops.edge.aten.full.default: + continue + + # Make sure we have a quantized operator + user = list(n.users)[0] + if user.target != q_op: + continue + + qargs = QuantArgs.from_operator(user.target, user.args) + if "dtype" not in n.kwargs.keys() or n.kwargs["dtype"] != qargs.dtype: + # replace the node arg with a quantized dito and also set dtype + # to get the right output according to the Edge IR specification: + # exir/dialects/edge/edge.yaml:3596 + quantized_full_value = qargs.quantize_value(n.args[1]).item() + n.update_arg(1, quantized_full_value) + n.update_kwarg("dtype", qargs.dtype) + modified = True + + return PassResult(graph_module, modified) diff --git a/backends/arm/operators/op_full.py b/backends/arm/operators/op_full.py index d2bc1377ce..23a13dd486 100644 --- a/backends/arm/operators/op_full.py +++ b/backends/arm/operators/op_full.py @@ -14,10 +14,6 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import ( - get_quant_arg_downstream, - quantize_value, -) from executorch.backends.arm.tosa_utils import tosa_shape from torch.fx import Node @@ -41,19 +37,14 @@ def define_node( shape = tosa_shape(inputs[0].special, output.dim_order) value = inputs[1].number - if is_quant_node: - qargs = get_quant_arg_downstream(list(node.users)[0]) - qvalue = quantize_value(value, qargs) - dtype = ts.DType.INT8 - data = np.full(shape, qvalue, dtype=np.int8) + + if output.dtype == ts.DType.INT8: + fill_dtype = np.int8 else: - assert ( - output.dtype == ts.DType.FP32 - ), "'Full' currently only supports FP32 for unquantized models." - dtype = ts.DType.FP32 - data = np.full(shape, value, dtype=np.float32) + fill_dtype = np.float32 + data = np.full(shape, value, dtype=fill_dtype) - tosa_graph.addConst(shape, dtype, data, node.name + "full-const") + tosa_graph.addConst(shape, output.dtype, data, node.name + "full-const") tosa_graph.addOperator( ts.TosaOp.Op.IDENTITY, [node.name + "full-const"], [output.name] )