diff --git a/backends/cadence/hifi/operators/quantized_linear_out.cpp b/backends/cadence/hifi/operators/quantized_linear_out.cpp index fb186abbb1..8e872fd708 100644 --- a/backends/cadence/hifi/operators/quantized_linear_out.cpp +++ b/backends/cadence/hifi/operators/quantized_linear_out.cpp @@ -38,31 +38,54 @@ void quantized_linear_out( int64_t out_dim = weight.size(0); // = out_dim int64_t in_dim = weight.size(1); // = in_dim - const uint8_t* __restrict__ in_data = src.const_data_ptr(); - const uint8_t* __restrict__ weight_data = weight.const_data_ptr(); - const int32_t* __restrict__ bias_data = bias.const_data_ptr(); - uint8_t* __restrict__ out_data = out.mutable_data_ptr(); + if (src.scalar_type() == exec_aten::ScalarType::Byte) { + const uint8_t* __restrict__ in_data = src.const_data_ptr(); + const uint8_t* __restrict__ weight_data = weight.const_data_ptr(); + const int32_t* __restrict__ bias_data = bias.const_data_ptr(); + uint8_t* __restrict__ out_data = out.mutable_data_ptr(); - // The nnlib kernel to compute quantized linear via matmul. - int32_t ret = impl::HiFi::kernels::matmul_asym8uxasym8u_asym8u( - out_data, // p_out - weight_data, // p_mat1, - in_data, // p_mat2, - bias_data, // p_bias - out_dim, // rows of p_mat1 - in_dim, // cols of p_mat1 - in_dim, // row_stride of p_mat1 - leading_dims, // vec_count, i.e., rows of p_mat2 - in_dim, // vec_offset of p_mat2. - out_dim, // out_offset, i.e., offset of next output element written - 1, // out_stride, i.e., stride to go to next output row - -weight_zero_point.const_data_ptr()[0], // mat1_zero_bias - -src_zero_point, // mat2_zero_bias - out_multiplier.const_data_ptr(), // out_multiplier - out_shift.const_data_ptr(), // out_shift - out_zero_point, // out_zero_bias - false); // per channel quantization - ET_DCHECK_MSG(ret == 0, "HiFi quantized::linear failed"); + // The nnlib kernel to compute quantized linear via matmul. + xa_nn_matmul_asym8uxasym8u_asym8u( + out_data, + weight_data, + in_data, + bias_data, + out_dim, + in_dim, + in_dim, + leading_dims, + in_dim, + out_dim, + 1, + -weight_zero_point.const_data_ptr()[0], + -src_zero_point, + out_multiplier.const_data_ptr()[0], + out_shift.const_data_ptr()[0], + out_zero_point); + } else { + const int8_t* __restrict__ in_data = src.const_data_ptr(); + const int8_t* __restrict__ weight_data = weight.const_data_ptr(); + const int32_t* __restrict__ bias_data = bias.const_data_ptr(); + int8_t* __restrict__ out_data = out.mutable_data_ptr(); + + xa_nn_matmul_asym8sxasym8s_asym8s( + out_data, + weight_data, + in_data, + bias_data, + out_dim, + in_dim, + in_dim, + leading_dims, + in_dim, + out_dim, + 1, + -weight_zero_point.const_data_ptr()[0], + -src_zero_point, + out_multiplier.const_data_ptr()[0], + out_shift.const_data_ptr()[0], + out_zero_point); + } } }; // namespace native