diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py
index 8c0c790c58..1dcfcbcd8d 100644
--- a/backends/cadence/aot/ops_registrations.py
+++ b/backends/cadence/aot/ops_registrations.py
@@ -146,7 +146,10 @@
     "quantized_fully_connected(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
     "Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)"
 )
-
+lib.define(
+    "quantized_fully_connected.per_tensor(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
+    "int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)"
+)
 
 # ------------------------------------ #
 # Migrated from custom_ops.ymal       #
@@ -192,6 +195,10 @@
     "quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
     "Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
 )
+lib.define(
+    "quantized_fully_connected.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
+    "int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
+)
 lib.define(
     "quantized_embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor weight_zero_points, "
     "Tensor indices, bool pruned_weights=False, *, Tensor(a!) out) -> Tensor(a!)"
@@ -595,6 +602,28 @@ def quantized_fully_connected_meta(
     bias: torch.Tensor,
     in_zero_point: int,
     weight_zero_point: torch.Tensor,
+    out_multiplier: torch.Tensor,
+    out_shift: torch.Tensor,
+    out_zero_point: int,
+    offset: Optional[torch.Tensor],
+) -> torch.Tensor:
+    # src comes in shape [leading_dims, in_dim]
+    # weight comes in shape [out_dim, in_dim]
+    # output comes in empty with shape [leading_dims, out_dim]
+    out_size = list(src.size())
+    weight_size = list(weight.size())
+    assert len(weight_size) == 2
+    out_size[-1] = weight_size[0]
+    return src.new_empty(out_size, dtype=src.dtype)
+
+
+@register_fake("cadence::quantized_fully_connected.per_tensor")
+def quantized_fully_connected_per_tensor_meta(
+    src: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    in_zero_point: int,
+    weight_zero_point: int,
     out_multiplier: int,
     out_shift: int,
     out_zero_point: int,
@@ -607,7 +636,7 @@ def quantized_fully_connected_meta(
     weight_size = list(weight.size())
     assert len(weight_size) == 2
     out_size[-1] = weight_size[0]
-    return src.new_empty(out_size, dtype=torch.uint8)
+    return src.new_empty(out_size, dtype=src.dtype)
 
 
 @register_fake("cadence::convolution")
diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py
index fd51385bcd..e42d9f2d1a 100644
--- a/backends/cadence/aot/replace_ops.py
+++ b/backends/cadence/aot/replace_ops.py
@@ -9,6 +9,8 @@
 # 3. functions that replace an ATen op with another semantically equivalent ATen op.
 # 4. functions that concretize optional args.
 
+# pyre-unsafe
+
 import math
 from operator import neg
 from typing import cast, Dict, Iterable, Sequence, Set, Tuple
@@ -1698,12 +1700,6 @@ def call_operator(self, op, args, kwargs, meta):
         if leading_dims != 1:
             return super().call_operator(op, args, kwargs, meta)
 
-        # If the op is quantized::linear, but per-channel quantized, bail.
-        if op == exir_ops.edge.cadence.quantized_linear.default:
-            weight = args[1].to_tensor() if isinstance(args[1], ProxyValue) else args[1]
-            if weight.shape != [1]:
-                return super().call_operator(op, args, kwargs, meta)
-
         # Replace the linear with fully connected op
         return super().call_operator(
             self.linear_to_fc_op[op],
@@ -1893,6 +1889,10 @@ class ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass(ExportPass):
             exir_ops.edge.cadence.quantized_conv.per_tensor,
             [8, 9, 12, 13],
         ),
+        exir_ops.edge.cadence.quantized_fully_connected: (
+            exir_ops.edge.cadence.quantized_fully_connected.per_tensor,
+            [4, 5, 6],
+        ),
         exir_ops.edge.cadence.quantized_layer_norm: (
             exir_ops.edge.cadence.quantized_layer_norm.per_tensor,
             [1, 2],
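For review context: both the reworked `quantized_fully_connected_meta` and the new `quantized_fully_connected_per_tensor_meta` fake kernels infer only the output shape and dtype (the leading dims of `src` are kept, the last dim becomes `weight`'s `out_dim`, and the previously hard-coded `torch.uint8` output dtype now follows `src.dtype`). A minimal standalone sketch of that shape rule, mirroring the patched code rather than invoking the registered op:

```python
import torch


def fc_out_shape(src: torch.Tensor, weight: torch.Tensor) -> list:
    # Mirrors the fake kernels above (illustrative sketch, not the
    # registered op): keep src's leading dims, swap in_dim for out_dim.
    out_size = list(src.size())        # [leading_dims, in_dim]
    weight_size = list(weight.size())  # [out_dim, in_dim]
    assert len(weight_size) == 2
    out_size[-1] = weight_size[0]      # [leading_dims, out_dim]
    return out_size


src = torch.zeros(1, 64, dtype=torch.uint8)
weight = torch.zeros(32, 64, dtype=torch.uint8)
assert fc_out_shape(src, weight) == [1, 32]
```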
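On the new `ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass` entry: indices `[4, 5, 6]` are the positions of `weight_zero_point`, `out_multiplier`, and `out_shift` in the schema registered above, i.e. exactly the `Tensor` arguments that become `int` scalars in the `.per_tensor` overload. A quick check of that mapping (the `QFC_ARG_NAMES` list below is a hypothetical helper read off the schema, not part of the patch):

```python
# Positional argument names of cadence::quantized_fully_connected,
# transcribed from the schema string (hypothetical helper, not in the patch).
QFC_ARG_NAMES = [
    "src",                # 0: Tensor
    "weight",             # 1: Tensor
    "bias",               # 2: Tensor
    "src_zero_point",     # 3: int
    "weight_zero_point",  # 4: Tensor -> int in .per_tensor
    "out_multiplier",     # 5: Tensor -> int in .per_tensor
    "out_shift",          # 6: Tensor -> int in .per_tensor
    "out_zero_point",     # 7: int
    "offset",             # 8: Tensor?
]

assert [QFC_ARG_NAMES[i] for i in (4, 5, 6)] == [
    "weight_zero_point",
    "out_multiplier",
    "out_shift",
]
```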