diff --git a/backends/cadence/aot/passes.py b/backends/cadence/aot/passes.py
index afcf8e5aa9..ca8a44f00c 100644
--- a/backends/cadence/aot/passes.py
+++ b/backends/cadence/aot/passes.py
@@ -4,20 +4,34 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import Any, Dict, Tuple
+
 import torch
 from executorch.backends.cadence.aot.utils import get_edge_overload_packet
 from executorch.exir.dialects._ops import ops as exir_ops
-from executorch.exir.pass_base import ExportPass, ProxyValue
+from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
 from torch._subclasses import FakeTensor
 from torch.utils._pytree import tree_map_only
 
+# pyre-strict
+
+# Similar to what's done in executorch/exir/pass_base.py
+Argument = Any  # pyre-ignore
+
+
 class ReplacePT2QuantWithCadenceQuantPass(ExportPass):
     """
     Replace the pt2 quantization ops with custom cadence quantization ops.
     """
 
-    def call_operator(self, op, args, kwargs, meta):
+    def call_operator(
+        self,
+        op,  # pyre-ignore
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        meta: NodeMetadata,
+    ) -> ProxyValue:
         if op not in {exir_ops.edge.quantized_decomposed.quantize_per_tensor.default}:
             return super().call_operator(op, args, kwargs, meta)
 
@@ -34,7 +48,13 @@ class ReplacePT2DequantWithCadenceDequantPass(ExportPass):
     Replace the pt2 dequantization ops with custom cadence dequantization ops.
     """
 
-    def call_operator(self, op, args, kwargs, meta):
+    def call_operator(
+        self,
+        op,  # pyre-ignore
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        meta: NodeMetadata,
+    ) -> ProxyValue:
         if op not in {exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default}:
             return super().call_operator(op, args, kwargs, meta)
 
@@ -51,7 +71,13 @@ class ReplaceScalarTensorWithFullPass(ExportPass):
     aten.scalar_tensor can be replaced by aten.full with a shape of [1].
     """
 
-    def call_operator(self, op, args, kwargs, meta):
+    def call_operator(
+        self,
+        op,  # pyre-ignore
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        meta: NodeMetadata,
+    ) -> ProxyValue:
         if op not in {
             exir_ops.edge.aten.scalar_tensor.default,
             torch.ops.aten.scalar_tensor.default,
@@ -64,7 +90,7 @@ def call_operator(self, op, args, kwargs, meta):
                 [1],
                 args[0],
             ),
-            {},
+            {"dtype": torch.float32},
             meta,
         )
 
@@ -75,7 +101,13 @@ class ReplaceSqueezeAndUnsqueezeWithViewPass(ExportPass):
     view_copy op
     """
 
-    def call_operator(self, op, args, kwargs, meta):
+    def call_operator(
+        self,
+        op,  # pyre-ignore
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        meta: NodeMetadata,
+    ) -> ProxyValue:
         # Instead of testing EdgeOpOverload, test EdgeOpOverloadPacket,
         # which allows us to cover all overloads.
         if get_edge_overload_packet(op) not in {
@@ -99,7 +131,13 @@ def call_operator(self, op, args, kwargs, meta):
 
 
 class RemoveZeroSizedCatArgsPass(ExportPass):
-    def call_operator(self, op, args, kwargs, meta):
+    def call_operator(
+        self,
+        op,  # pyre-ignore
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        meta: NodeMetadata,
+    ) -> ProxyValue:
         if op != exir_ops.edge.aten.cat.default:
             return super().call_operator(op, args, kwargs, meta)
 
@@ -122,6 +160,7 @@ def call_operator(self, op, args, kwargs, meta):
                 # TODO(matthiascremon): confirm this is the best way to do this.
                 if isinstance(result, FakeTensor):
                     result.constant = result
+                # pyre-ignore[7]: Incompatible return type.
                return torch.empty_like(result)
 
         # If there was only one tensor in the new_args list,
@@ -130,7 +169,7 @@ def call_operator(self, op, args, kwargs, meta):
             return new_args[0]
 
         # Otherwise, we replace args[0] with new_args.
-        args = list(args)
-        args[0] = new_args
-        args = tuple(args)
+        init_args = list(args)
+        init_args[0] = new_args
+        args = tuple(init_args)
         return super().call_operator(op, args, kwargs, meta)
diff --git a/backends/cadence/aot/quantizer/fusion_pass.py b/backends/cadence/aot/quantizer/fusion_pass.py
index 52e1d148cf..0a1927e725 100644
--- a/backends/cadence/aot/quantizer/fusion_pass.py
+++ b/backends/cadence/aot/quantizer/fusion_pass.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 from typing import Any, Dict, List, Tuple
 
 import torch
@@ -31,6 +33,11 @@
 from torch.fx.passes.utils.fuser_utils import legalize_graph
 
 
+# Use this to avoid pyre errors
+# pyre-ignore[33]: `_ModelInputsType` cannot alias to `Any`.
+ArgsType = Any
+
+
 # Helper function to get the args and kwargs for the linear replacement op
 def get_args_and_kwargs_linear(
     graph_module: GraphModule,
@@ -40,7 +47,7 @@ def get_args_and_kwargs_linear(
     dequants_weights: List[fx.Node],
     bias_inputs: List[fx.Node],
     quant_node: fx.Node,
-) -> Tuple[Tuple[Any], Dict[str, Any]]:
+) -> Tuple[Tuple[ArgsType], Dict[str, ArgsType]]:
     """
     Returns the args and kwargs for the linear replacement op.
     """
@@ -98,7 +105,7 @@ def get_args_and_kwargs_layer_norm(
     dequants_inputs: List[fx.Node],
     other_inputs: List[fx.Node],
     quant_node: fx.Node,
-) -> Tuple[Tuple[Any], Dict[str, Any]]:
+) -> Tuple[Tuple[ArgsType], Dict[str, ArgsType]]:
     """
     Returns the args and kwargs for the layer norm replacement op.
     """
@@ -167,7 +174,7 @@ def get_args_and_kwargs_matmul(
     inputs_inputs: List[fx.Node],
     dequants_inputs: List[fx.Node],
     quant_node: fx.Node,
-) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
+) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
     requantize_scale = (
         # pyre-ignore[58]: Unsupported operand
         dequants_inputs[0].args[1]
@@ -203,7 +210,7 @@ def get_args_and_kwargs_conv(
     bias_inputs: List[fx.Node],
     quant_node: fx.Node,
     op_node: fx.Node,
-):
+) -> Tuple[Tuple[ArgsType], Dict[str, ArgsType]]:
     weight_scale = dequants_weights[0].args[1]
     weight_zero_point = dequants_weights[0].args[2]
     # pyre-fixme[58]: Unsupported operand types
@@ -277,12 +284,14 @@ def get_args_and_kwargs_relu(
     graph_module: GraphModule,
     inputs_inputs: List[fx.Node],
     dequants_inputs: List[fx.Node],
-):
+) -> Tuple[Tuple[ArgsType], Dict[str, ArgsType]]:
     # Make the args and kwargs for the replacement op
     args = tuple(inputs_inputs)
 
     X_zero_point = graph_module.graph.call_function(
-        torch.ops.aten.full.default, ([1], dequants_inputs[0].args[2])
+        torch.ops.aten.full.default,
+        ([1], dequants_inputs[0].args[2]),
+        {"dtype": torch.int32},
     )
 
     kwargs = {
@@ -292,8 +301,10 @@ def get_args_and_kwargs_relu(
 
 
 class QuantFusion(ExportPass):
-    def __init__(self, patterns):
+    # pyre-ignore[2]: Parameter `patterns` has no type specified
+    def __init__(self, patterns) -> None:
         super().__init__()
+        # pyre-ignore[4]: Parameter `patterns` of class `QuantFusion` has no type specified
         self.patterns = patterns
 
     def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
@@ -427,10 +438,12 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
         graph_module.recompile()
 
     @classmethod
+    # pyre-ignore[2]: Parameter `nodes` has no type specified
     def is_fused(cls, nodes) -> bool:
         return any(cls.__qualname__ in n.meta for n in nodes)
 
     @classmethod
+    # pyre-ignore[2]: Parameter `nodes` has no type specified
     def mark_fused(cls, nodes) -> bool:
         for n in nodes:
             # pyre-fixme[7]: Incompatible return type
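
Reviewer notes (not part of the patch):

The two new dtype kwargs matter because aten.full infers its dtype from the
fill value: full([1], zero_point) with a Python int produces an int64 tensor,
while aten.scalar_tensor defaults to float32. Pinning the dtype explicitly
keeps the replacement ops faithful to what they replace. A standalone
eager-mode check of that inference behavior, for illustration only:

    # Requires only PyTorch; mirrors the call sites touched above.
    import torch

    # scalar_tensor falls back to the global default dtype (float32)...
    print(torch.ops.aten.scalar_tensor.default(3).dtype)  # torch.float32
    # ...but full infers its dtype from the fill value: an int fill gives int64.
    print(torch.ops.aten.full.default([1], 3).dtype)  # torch.int64
    # An explicit dtype kwarg pins the result, as the patch now does.
    print(torch.ops.aten.full.default([1], 3, dtype=torch.float32).dtype)  # float32
    print(torch.ops.aten.full.default([1], 3, dtype=torch.int32).dtype)  # int32

This is why ReplaceScalarTensorWithFullPass pins float32 (matching
scalar_tensor's default) while the relu zero-point constant pins int32
(presumably what the quantized cadence kernel expects).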
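For context, a minimal sketch of driving one of these passes on its own. The
AddScalar module and the export flow are illustrative assumptions, not part of
this patch, and the exact entry point the Cadence flow uses may differ; the
sketch relies only on ExportPass being a torch.fx pass, so an instance can be
called directly on a graph module and returns a PassResult.

    # Hypothetical usage sketch for ReplaceScalarTensorWithFullPass.
    import torch

    from executorch.backends.cadence.aot.passes import (
        ReplaceScalarTensorWithFullPass,
    )


    class AddScalar(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x + torch.scalar_tensor(0.5)


    # Export to an aten-dialect graph; the pass matches
    # torch.ops.aten.scalar_tensor.default as well as the edge overload,
    # so no edge lowering is needed for this demonstration.
    gm = torch.export.export(AddScalar(), (torch.randn(4),)).module()

    # Calling the pass instance returns a PassResult whose graph_module
    # holds the rewritten graph.
    result = ReplaceScalarTensorWithFullPass()(gm)
    print(result.graph_module.code)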