diff --git a/backends/cadence/aot/quantizer/fusion_pass.py b/backends/cadence/aot/quantizer/fusion_pass.py
index 5f0d847e84..4c43172a92 100644
--- a/backends/cadence/aot/quantizer/fusion_pass.py
+++ b/backends/cadence/aot/quantizer/fusion_pass.py
@@ -17,7 +17,8 @@
     LayerNormPattern,
     LinearPattern,
     MatmulPattern,
-    ReluPattern,
+    ReluPattern0,
+    ReluPattern1,
 )
 from executorch.backends.cadence.aot.quantizer.utils import (
     create_zero_bias_int32,
@@ -36,6 +37,9 @@
 # pyre-ignore[33]: `_ModelInputsType` cannot alias to `Any`.
 ArgsType = Any
 
+# Use this tuple for patterns that can match multiple aten ops
+ReluPatterns = (ReluPattern0, ReluPattern1)
+
 
 # Helper function to get the args and kwargs for the linear replacement op
 def get_args_and_kwargs_linear(
@@ -411,7 +415,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                         bias_inputs,
                         quant_node,
                     )
-                elif isinstance(pattern, ReluPattern):
+                elif isinstance(pattern, ReluPatterns):
                     args, kwargs = get_args_and_kwargs_relu(
                         graph_module,
                         inputs_inputs,
diff --git a/backends/cadence/aot/quantizer/patterns.py b/backends/cadence/aot/quantizer/patterns.py
index 943b9e473a..7043bae571 100644
--- a/backends/cadence/aot/quantizer/patterns.py
+++ b/backends/cadence/aot/quantizer/patterns.py
@@ -288,9 +288,11 @@ def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_matmul.default
 
 
-class ReluPattern(QuantizationPattern):
+# Base class for the ReLU patterns, since ReLU can come from two different aten ops
+class ReluBasePattern(QuantizationPattern):
+    @abstractmethod
     def partition_types(self) -> List[OpOverload]:
-        return [torch.ops.aten.relu.default]
+        pass
 
     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -308,3 +310,15 @@ def get_anchors(
 
     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_relu.default
+
+
+# Out-of-place relu op
+class ReluPattern0(ReluBasePattern):
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.relu.default]
+
+
+# In-place relu op
+class ReluPattern1(ReluBasePattern):
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.relu_.default]
diff --git a/backends/cadence/aot/quantizer/quantizer.py b/backends/cadence/aot/quantizer/quantizer.py
index 5a2c101512..4cd3c6bfb4 100644
--- a/backends/cadence/aot/quantizer/quantizer.py
+++ b/backends/cadence/aot/quantizer/quantizer.py
@@ -18,7 +18,8 @@
     LinearPattern,
     MatmulPattern,
     QuantizationPattern,
-    ReluPattern,
+    ReluPattern0,
+    ReluPattern1,
 )
 from executorch.backends.cadence.aot.quantizer.utils import (
     find_sequential_partitions_aten,
@@ -159,6 +160,7 @@ def __init__(self) -> None:
             CadenceAtenQuantizer(LayerNormPattern(), static_qconfig),
             CadenceAtenQuantizer(LinearPattern(), static_qconfig),
             CadenceAtenQuantizer(MatmulPattern(), static_qconfig),
-            CadenceAtenQuantizer(ReluPattern(), static_qconfig),
+            CadenceAtenQuantizer(ReluPattern0(), static_qconfig),
+            CadenceAtenQuantizer(ReluPattern1(), static_qconfig),
         ]
     )
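
Note: torch.ops.aten.relu_.default is the in-place variant of torch.ops.aten.relu.default, and a captured graph can contain either form, which is why the quantizer now registers one pattern per op. Below is a minimal standalone sketch of the isinstance-over-tuple dispatch that fusion_pass.py relies on; the class stubs are illustrative stand-ins, not the real pattern classes from this patch:

    # isinstance() accepts a tuple of classes, so the single check in
    # fusion_pass.py covers every ReLU variant listed in ReluPatterns.
    class ReluBasePattern: ...                 # stand-in for the real base class
    class ReluPattern0(ReluBasePattern): ...   # would match aten.relu.default
    class ReluPattern1(ReluBasePattern): ...   # would match aten.relu_.default

    ReluPatterns = (ReluPattern0, ReluPattern1)

    for pattern in (ReluPattern0(), ReluPattern1()):
        assert isinstance(pattern, ReluPatterns)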