
Commit

Enable aten.relu_.default in the CadenceQuantizer (pytorch#4344)
Summary:
Pull Request resolved: pytorch#4344

As titled. Some models use `torch.ops.aten.relu_.default` instead of `torch.ops.aten.relu.default`. Enable that variant in the quantizer as well.
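For context, the two overloads differ only in mutation semantics; in-place ReLU (e.g. `torch.nn.ReLU(inplace=True)` or `x.relu_()`) is what typically lowers to the underscore variant. A minimal sketch, not part of this change:

    import torch

    x = torch.randn(4)
    y = torch.ops.aten.relu.default(x)   # out-of-place: returns a new tensor
    torch.ops.aten.relu_.default(x)      # in-place: mutates x; the overload enabled by this diff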

Reviewed By: zonglinpengmeta

Differential Revision: D60071019

fbshipit-source-id: efad4818f17ca1aef7445d4f8d651bd9f1c46444
mcremon-meta authored and facebook-github-bot committed Jul 23, 2024
1 parent 3154afc commit 48da61a
Showing 3 changed files with 26 additions and 6 deletions.
8 changes: 6 additions & 2 deletions backends/cadence/aot/quantizer/fusion_pass.py
@@ -17,7 +17,8 @@
     LayerNormPattern,
     LinearPattern,
     MatmulPattern,
-    ReluPattern,
+    ReluPattern0,
+    ReluPattern1,
 )
 from executorch.backends.cadence.aot.quantizer.utils import (
     create_zero_bias_int32,
@@ -36,6 +37,9 @@
 # pyre-ignore[33]: `_ModelInputsType` cannot alias to `Any`.
 ArgsType = Any
 
+# Use this part for patterns with multiple aten ops
+ReluPatterns = (ReluPattern0, ReluPattern1)
+
 
 # Helper function to get the args and kwargs for the linear replacement op
 def get_args_and_kwargs_linear(
@@ -411,7 +415,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                         bias_inputs,
                         quant_node,
                     )
-                elif isinstance(pattern, ReluPattern):
+                elif isinstance(pattern, ReluPatterns):
                     args, kwargs = get_args_and_kwargs_relu(
                         graph_module,
                         inputs_inputs,
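The `ReluPatterns` tuple above relies on `isinstance` accepting a tuple of classes; a standalone sketch of the idiom with hypothetical pattern classes (not the real ones from this file):

    class PatternA:
        pass

    class PatternB:
        pass

    MULTI_OP_PATTERNS = (PatternA, PatternB)

    def is_relu_like(pattern: object) -> bool:
        # Matches if the object is an instance of any class in the tuple.
        return isinstance(pattern, MULTI_OP_PATTERNS)

    assert is_relu_like(PatternA()) and is_relu_like(PatternB())
    assert not is_relu_like(object())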
18 changes: 16 additions & 2 deletions backends/cadence/aot/quantizer/patterns.py
@@ -288,9 +288,11 @@ def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_matmul.default
 
 
-class ReluPattern(QuantizationPattern):
+# This is a base class for ReLU, since it can be used with two different aten ops
+class ReluBasePattern(QuantizationPattern):
+    @abstractmethod
     def partition_types(self) -> List[OpOverload]:
-        return [torch.ops.aten.relu.default]
+        pass
 
     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -308,3 +310,15 @@ def get_anchors(
 
     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_relu.default
+
+
+# Regular relu op
+class ReluPattern0(ReluBasePattern):
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.relu.default]
+
+
+# Alternate relu op
+class ReluPattern1(ReluBasePattern):
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.relu_.default]
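The refactor above is the standard abstract-base-class split: shared anchor and replacement-op logic stays in the base class, and each subclass only declares which aten op it matches. A simplified, self-contained sketch of the same structure (types reduced to strings; not the actual ExecuTorch classes):

    from abc import ABC, abstractmethod
    from typing import List

    class ReluBase(ABC):
        @abstractmethod
        def partition_types(self) -> List[str]:
            # Each concrete subclass declares the single op it matches.
            ...

    class Relu0(ReluBase):
        def partition_types(self) -> List[str]:
            return ["aten.relu.default"]

    class Relu1(ReluBase):
        def partition_types(self) -> List[str]:
            return ["aten.relu_.default"]

    # ReluBase() would raise TypeError; only the concrete subclasses can be instantiated.
    print(Relu0().partition_types(), Relu1().partition_types())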
6 changes: 4 additions & 2 deletions backends/cadence/aot/quantizer/quantizer.py
@@ -18,7 +18,8 @@
     LinearPattern,
     MatmulPattern,
     QuantizationPattern,
-    ReluPattern,
+    ReluPattern0,
+    ReluPattern1,
 )
 from executorch.backends.cadence.aot.quantizer.utils import (
     find_sequential_partitions_aten,
@@ -159,6 +160,7 @@ def __init__(self) -> None:
                 CadenceAtenQuantizer(LayerNormPattern(), static_qconfig),
                 CadenceAtenQuantizer(LinearPattern(), static_qconfig),
                 CadenceAtenQuantizer(MatmulPattern(), static_qconfig),
-                CadenceAtenQuantizer(ReluPattern(), static_qconfig),
+                CadenceAtenQuantizer(ReluPattern0(), static_qconfig),
+                CadenceAtenQuantizer(ReluPattern1(), static_qconfig),
             ]
         )

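For orientation, a hedged sketch of how a quantizer composed from these patterns is typically driven through the PT2E flow; the toy model, capture call, and import path are assumptions for illustration, not taken from this commit:

    import torch
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

    # Assumed import path, based on the file touched in this commit.
    from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer

    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(inplace=True)).eval()
    inputs = (torch.randn(1, 8),)

    # Capture an aten-level graph; depending on the capture path, the in-place ReLU
    # may surface as aten.relu_.default, which ReluPattern1 now covers.
    gm = torch.export.export_for_training(model, inputs).module()

    quantizer = CadenceQuantizer()
    prepared = prepare_pt2e(gm, quantizer)   # insert observers per matched pattern
    prepared(*inputs)                        # calibrate on sample data
    quantized = convert_pt2e(prepared)       # fold observers into quant/dequant ops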