Add full operator to fold dq/q handling
Signed-off-by: Per Åstrand <[email protected]>
Change-Id: I39d11cff0ef78df08e67f216b8e0bb86af9fac26
per authored and freddan80 committed Dec 16, 2024
1 parent e24d503 commit 47c2f2e
Showing 3 changed files with 45 additions and 17 deletions.
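
For orientation: in a quantized graph, full.default is immediately followed by a quantize_per_tensor node. This commit pre-quantizes the fill value in a new QuantizeFullArgument pass and adds full.default to the ops whose surrounding q/dq nodes FoldAndAnnotateQParamsPass folds and annotates. A rough before/after sketch of the rewrite (illustrative shapes and qparams, not taken from the commit):

# Before the passes:
#   t = exir_ops.edge.aten.full.default([2, 2], 0.5)
#   q = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default(
#           t, 1 / 128, 0, -128, 127, torch.int8)
#
# After QuantizeFullArgument (fill value quantized, dtype pinned):
#   t = exir_ops.edge.aten.full.default([2, 2], 64, dtype=torch.int8)
#
# After FoldAndAnnotateQParamsPass: q is folded away, its qparams are
# annotated on t's node.meta, and op_full later sees output.dtype == INT8.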
3 changes: 3 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
@@ -31,6 +31,7 @@
 from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
 from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
     FoldAndAnnotateQParamsPass,
+    QuantizeFullArgument,
 )
 from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
     KeepDimsFalseToSqueezePass,
@@ -84,6 +85,7 @@ def transform_to_backend_pipeline(
         self.add_pass(Conv1dUnsqueezePass(exported_program))
         self.add_pass(DecomposeSoftmaxesPass())
         self.add_pass(DecomposeLinearPass())
+        self.add_pass(QuantizeFullArgument())
         self.add_pass(
             FoldAndAnnotateQParamsPass(
                 [
@@ -92,6 +94,7 @@
                     exir_ops.edge.aten.add.Tensor,
                     exir_ops.edge.aten.avg_pool2d.default,
                     exir_ops.edge.aten.convolution.default,
+                    exir_ops.edge.aten.full.default,
                 ]
             )
         )
38 changes: 36 additions & 2 deletions backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
@@ -15,6 +15,9 @@
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch.fx import GraphModule, Node
 
+q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
+dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
+
 
 def get_input_qparams(node: Node) -> dict[int, QuantArgs]:
     """
@@ -77,8 +80,6 @@ def __init__(self, targeted_ops: Iterable[Callable]):
         self.targeted_ops = targeted_ops
 
     def call(self, graph_module: GraphModule) -> PassResult:
-        q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
-        dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
 
         # Loop over the graph nodes and find any node in the 'targeted_ops' list.
         for n in graph_module.graph.nodes:
@@ -145,3 +146,36 @@ def call(self, graph_module: GraphModule) -> PassResult:
 
         graph_module.recompile()
         return PassResult(graph_module, True)
+
+
+class QuantizeFullArgument(ExportPass):
+    """
+    Make sure the fill_value for full.default is quantized. This pass must be
+    run before the folding pass above so that the retraced output of the
+    full.default op has the right dtype.
+    """
+
+    def call(self, graph_module: GraphModule) -> PassResult:
+        modified = False
+        # Loop over the graph nodes and find all full.default nodes.
+        for n in graph_module.graph.nodes:
+            n = cast(Node, n)
+            if n.target != exir_ops.edge.aten.full.default:
+                continue
+
+            # Make sure we have a quantized operator
+            user = list(n.users)[0]
+            if user.target != q_op:
+                continue
+
+            qargs = QuantArgs.from_operator(user.target, user.args)
+            if "dtype" not in n.kwargs.keys() or n.kwargs["dtype"] != qargs.dtype:
+                # Replace the node arg with the quantized value and also set
+                # dtype to get the right output according to the Edge IR
+                # specification: exir/dialects/edge/edge.yaml:3596
+                quantized_full_value = qargs.quantize_value(n.args[1]).item()
+                n.update_arg(1, quantized_full_value)
+                n.update_kwarg("dtype", qargs.dtype)
+                modified = True
+
+        return PassResult(graph_module, modified)
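
The arithmetic behind qargs.quantize_value above is, presumably, the standard affine scheme used by quantize_per_tensor. A self-contained sketch (the helper below is a hypothetical stand-in, not the QuantArgs API):

import torch

# Hypothetical stand-in for QuantArgs.quantize_value (affine quantization);
# assumes q = clamp(round(v / scale) + zero_point, qmin, qmax).
def quantize_value(v: float, scale: float, zp: int, qmin: int, qmax: int) -> torch.Tensor:
    return torch.clip(torch.round(torch.tensor(v) / scale) + zp, qmin, qmax).to(torch.int8)

# A full.default with fill_value 0.5 feeding quantize_per_tensor with
# scale=1/128 and zero_point=0 has its args[1] rewritten from 0.5 to 64:
print(quantize_value(0.5, 1.0 / 128, 0, -128, 127))  # tensor(64, dtype=torch.int8)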
21 changes: 6 additions & 15 deletions backends/arm/operators/op_full.py
@@ -14,10 +14,6 @@
     register_node_visitor,
 )
 from executorch.backends.arm.tosa_mapping import TosaArg
-from executorch.backends.arm.tosa_quant_utils import (
-    get_quant_arg_downstream,
-    quantize_value,
-)
 from executorch.backends.arm.tosa_utils import tosa_shape
 from torch.fx import Node
 
@@ -41,19 +37,14 @@ def define_node(
         shape = tosa_shape(inputs[0].special, output.dim_order)
 
         value = inputs[1].number
-        if is_quant_node:
-            qargs = get_quant_arg_downstream(list(node.users)[0])
-            qvalue = quantize_value(value, qargs)
-            dtype = ts.DType.INT8
-            data = np.full(shape, qvalue, dtype=np.int8)
+
+        if output.dtype == ts.DType.INT8:
+            fill_dtype = np.int8
         else:
             assert (
                 output.dtype == ts.DType.FP32
             ), "'Full' currently only supports FP32 for unquantized models."
-            dtype = ts.DType.FP32
-            data = np.full(shape, value, dtype=np.float32)
+            fill_dtype = np.float32
+        data = np.full(shape, value, dtype=fill_dtype)
 
-        tosa_graph.addConst(shape, dtype, data, node.name + "full-const")
+        tosa_graph.addConst(shape, output.dtype, data, node.name + "full-const")
         tosa_graph.addOperator(
             ts.TosaOp.Op.IDENTITY, [node.name + "full-const"], [output.name]
         )
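
With the fill value pre-quantized by the new pass, op_full only has to pick the numpy dtype matching output.dtype. A minimal numpy-only sketch of the constant materialisation (TOSA serializer calls omitted; values illustrative):

import numpy as np

# 'value' is the fill value already quantized upstream (illustrative numbers);
# fill_dtype is picked from output.dtype -- np.float32 for FP32 outputs.
shape, value = (2, 2), 64
fill_dtype = np.int8
data = np.full(shape, value, dtype=fill_dtype)
print(data)  # [[64 64]
             #  [64 64]]  -> handed to tosa_graph.addConst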
