Skip to content

Commit

Permalink
Align kernel_sum / effective_bias usage
Browse files Browse the repository at this point in the history
Use
   kernel_sum[i] = SUM_j(kernel[i][j]*input_offset + input_offset*filter_offset) + bias[i]

for all occurrences of kernel_sums in int8 operators, mirroring the behaviour of effective_bias
in the LSTM operator. This avoids one extra read/write per output.

Change-Id: If31c3122b7b6bfaa8a1632436d31d20937d6c13d
  • Loading branch information
AdrianLundell committed Sep 16, 2024
1 parent 95f293d commit f2cb41c
Show file tree
Hide file tree
Showing 11 changed files with 159 additions and 154 deletions.
6 changes: 4 additions & 2 deletions Include/arm_nnfunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
* Title: arm_nnfunctions.h
* Description: Public header file for CMSIS NN Library
*
* $Date: 19 Aug 2024
* $Revision: V.16.3.0
* $Date: 5 Sep 2024
* $Revision: V.17.0.0
*
* Target : Arm(R) M-Profile Architecture
* -------------------------------------------------------------------- */
Expand Down Expand Up @@ -1691,6 +1691,7 @@ arm_cmsis_nn_status arm_fully_connected_wrapper_s8(const cmsis_nn_context *ctx,
* @param[in] vector_rows Number of vector rows
* @param[in]      vector_data         Vector of weights data
* @param[in] lhs_offset Constant multiplied with each sum
* @param[in] rhs_offset Constant added to each vector element before sum
* @param[in] bias_data Vector of bias data, added to each sum.
* @return The function returns
* <code>ARM_CMSIS_NN_SUCCESS</code> - Successful operation
Expand All @@ -1700,6 +1701,7 @@ arm_cmsis_nn_status arm_vector_sum_s8(int32_t *vector_sum_buf,
const int32_t vector_rows,
const int8_t *vector_data,
const int32_t lhs_offset,
const int32_t rhs_offset,
const int32_t *bias_data);

/**
Expand Down
12 changes: 9 additions & 3 deletions Source/FullyConnectedFunctions/arm_batch_matmul_s8.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
* Title: arm_nn_batch_matmul_s8.c
* Description: Batch matrix multiplication. Does not perform transposes, see header file for details.
*
* $Date: 19 June 2024
* $Revision: V.1.0.0
* $Date: 5 Sep 2024
* $Revision: V.1.0.1
*
* Target : Arm(R) M-Profile Architecture
*
Expand Down Expand Up @@ -81,7 +81,13 @@ arm_cmsis_nn_status arm_batch_matmul_s8(const cmsis_nn_context *ctx,
{

#if defined(ARM_MATH_MVEI)
arm_vector_sum_s8(vector_sum_buf, rhs_cols, rhs_rows, input_rhs, 1, NULL);
arm_vector_sum_s8(vector_sum_buf,
rhs_cols,
rhs_rows,
input_rhs,
bmm_params->fc_params.input_offset,
bmm_params->fc_params.filter_offset,
NULL);
#endif
for (int i_lhs_rows = 0; i_lhs_rows < lhs_rows; i_lhs_rows++)
{
Expand Down
22 changes: 20 additions & 2 deletions Source/FullyConnectedFunctions/arm_vector_sum_s8.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
* Title: arm_vector_sum_s8
* Description: Generic function for calculating vector sums
*
* $Date: 15 February 2024
* $Revision: V.2.0.1
* $Date: 05 Sep 2024
* $Revision: V.3.0.0
*
* Target : Arm(R) M-Profile Architecture
*
Expand Down Expand Up @@ -50,6 +50,7 @@ arm_cmsis_nn_status arm_vector_sum_s8(int32_t *vector_sum_buf,
const int32_t vector_rows,
const int8_t *vector_data,
const int32_t lhs_offset,
const int32_t rhs_offset,
const int32_t *bias_data)
{

Expand Down Expand Up @@ -103,6 +104,15 @@ arm_cmsis_nn_status arm_vector_sum_s8(int32_t *vector_sum_buf,
}
vector_data += 5 * vector_cols;

if (rhs_offset)
{
vector_sum_0 += vector_cols * rhs_offset;
vector_sum_1 += vector_cols * rhs_offset;
vector_sum_2 += vector_cols * rhs_offset;
vector_sum_3 += vector_cols * rhs_offset;
vector_sum_4 += vector_cols * rhs_offset;
}

vector_sum_0 *= lhs_offset;
vector_sum_1 *= lhs_offset;
vector_sum_2 *= lhs_offset;
Expand Down Expand Up @@ -132,6 +142,10 @@ arm_cmsis_nn_status arm_vector_sum_s8(int32_t *vector_sum_buf,
vector_0 += 16;
}
vector_data += vector_cols;
if (rhs_offset)
{
vector_sum_0 += vector_cols * rhs_offset;
}
vector_sum_0 *= lhs_offset;

vector_sum_buf[i_row_loop_cnt] += vector_sum_0;
Expand All @@ -144,6 +158,10 @@ arm_cmsis_nn_status arm_vector_sum_s8(int32_t *vector_sum_buf,
{
sum += *vector_data++;
}
if (rhs_offset)
{
sum += vector_cols * rhs_offset;
}
*vector_sum_buf++ += sum * lhs_offset;
}
#endif
Expand Down
67 changes: 14 additions & 53 deletions Source/NNSupportFunctions/arm_nn_vec_mat_mult_t_per_ch_s8.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
* Title: arm_nn_vec_mat_mult_t_per_ch_s8
* Description: s8 vector by matrix (transposed) multiplication
*
* $Date: 19 Aug 2024
* $Revision: V.1.0.0
* $Date: 5 Sep 2024
* $Revision: V.1.1.0
*
* Target : Arm(R) M-Profile Architecture
*
Expand Down Expand Up @@ -74,15 +74,17 @@ arm_cmsis_nn_status arm_nn_vec_mat_mult_t_per_ch_s8(const int8_t *lhs,
if (rhs_offset)
{
#if defined(ARM_MATH_MVEI)
(void)bias;
(void)lhs_offset;
const int32_t row_loop_cnt = rhs_rows / 4;
const uint32x4_t address_offset_array = {0, address_offset, address_offset * 2, address_offset * 3};

for (int i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
{
int32_t acc_0 = 0;
int32_t acc_1 = 0;
int32_t acc_2 = 0;
int32_t acc_3 = 0;
int32_t acc_0 = *kernel_sum++;
int32_t acc_1 = *kernel_sum++;
int32_t acc_2 = *kernel_sum++;
int32_t acc_3 = *kernel_sum++;

const int32_t col_loop_cnt = (rhs_cols + 15) / 16;

Expand All @@ -94,14 +96,6 @@ arm_cmsis_nn_status arm_nn_vec_mat_mult_t_per_ch_s8(const int8_t *lhs,

int32_t lhs_sum = 0;

if (bias)
{
acc_0 = *bias++;
acc_1 = *bias++;
acc_2 = *bias++;
acc_3 = *bias++;
}

uint32_t col_cnt = (uint32_t)rhs_cols;

for (int32_t i = 0; i < col_loop_cnt; i++)
Expand Down Expand Up @@ -134,12 +128,7 @@ arm_cmsis_nn_status arm_nn_vec_mat_mult_t_per_ch_s8(const int8_t *lhs,

int32x4_t acc = {acc_0, acc_1, acc_2, acc_3};

const int32x4_t rhs_sum = {kernel_sum[0], kernel_sum[1], kernel_sum[2], kernel_sum[3]};
acc += vdupq_n_s32(lhs_offset) * rhs_sum;
kernel_sum += 4;

acc += vdupq_n_s32(rhs_offset) * vdupq_n_s32(lhs_sum);
acc += vdupq_n_s32(rhs_offset * lhs_offset) * vdupq_n_s32(rhs_cols);

acc = arm_requantize_mve(acc, *dst_multiplier++, *dst_shift++);

Expand All @@ -155,7 +144,7 @@ arm_cmsis_nn_status arm_nn_vec_mat_mult_t_per_ch_s8(const int8_t *lhs,
const int loop_cnt = rhs_rows % 4;
for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < loop_cnt; i_row_loop_cnt++)
{
int32_t acc_0 = 0;
int32_t acc_0 = *kernel_sum++;
const int32_t col_loop_cnt = (rhs_cols + 15) / 16;
const int8_t *lhs_vec = lhs;
const int8_t *rhs_ptr = rhs;
Expand All @@ -177,15 +166,7 @@ arm_cmsis_nn_status arm_nn_vec_mat_mult_t_per_ch_s8(const int8_t *lhs,
}
rhs += rhs_cols;

if (bias)
{
acc_0 += *bias;
bias++;
}
const int32_t rhs_sum = kernel_sum[i_row_loop_cnt];
acc_0 += rhs_sum * lhs_offset;
acc_0 += lhs_sum * rhs_offset;
acc_0 += rhs_cols * lhs_offset * rhs_offset;

acc_0 = arm_nn_requantize(acc_0, *dst_multiplier++, *dst_shift++);
acc_0 += dst_offset;
Expand Down Expand Up @@ -430,10 +411,10 @@ arm_cmsis_nn_status arm_nn_vec_mat_mult_t_per_ch_s8(const int8_t *lhs,

for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
{
int32_t acc_0 = 0;
int32_t acc_1 = 0;
int32_t acc_2 = 0;
int32_t acc_3 = 0;
int32_t acc_0 = *kernel_sum++;
int32_t acc_1 = *kernel_sum++;
int32_t acc_2 = *kernel_sum++;
int32_t acc_3 = *kernel_sum++;

const int32_t col_loop_cnt = (rhs_cols + 15) / 16;

Expand All @@ -443,14 +424,6 @@ arm_cmsis_nn_status arm_nn_vec_mat_mult_t_per_ch_s8(const int8_t *lhs,
const int8_t *rhs_2_ptr = rhs + 2 * rhs_cols;
const int8_t *rhs_3_ptr = rhs + 3 * rhs_cols;

if (bias)
{
acc_0 = *bias++;
acc_1 = *bias++;
acc_2 = *bias++;
acc_3 = *bias++;
}

uint32_t col_cnt = (uint32_t)rhs_cols;

for (int32_t i = 0; i < col_loop_cnt; i++)
Expand Down Expand Up @@ -482,10 +455,6 @@ arm_cmsis_nn_status arm_nn_vec_mat_mult_t_per_ch_s8(const int8_t *lhs,

int32x4_t acc = {acc_0, acc_1, acc_2, acc_3};

const int32x4_t rhs_sum = {kernel_sum[0], kernel_sum[1], kernel_sum[2], kernel_sum[3]};
acc += vdupq_n_s32(lhs_offset) * rhs_sum;
kernel_sum += 4;

acc = arm_requantize_mve_32x4(acc, vldrwq_s32(dst_multiplier), vldrwq_s32(dst_shift));
dst_multiplier += 4;
dst_shift += 4;
Expand All @@ -502,7 +471,7 @@ arm_cmsis_nn_status arm_nn_vec_mat_mult_t_per_ch_s8(const int8_t *lhs,
const int loop_cnt = rhs_rows % 4;
for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < loop_cnt; i_row_loop_cnt++)
{
int32_t acc_0 = 0;
int32_t acc_0 = *kernel_sum++;
const int32_t col_loop_cnt = (rhs_cols + 15) / 16;
const int8_t *lhs_vec = lhs;
const int8_t *rhs_ptr = rhs;
Expand All @@ -522,14 +491,6 @@ arm_cmsis_nn_status arm_nn_vec_mat_mult_t_per_ch_s8(const int8_t *lhs,
}
rhs += rhs_cols;

if (bias)
{
acc_0 += *bias;
bias++;
}
const int32_t rhs_sum = kernel_sum[i_row_loop_cnt];
const int32_t offsets = rhs_sum * lhs_offset;
acc_0 += offsets;
acc_0 = arm_nn_requantize(acc_0, *dst_multiplier++, *dst_shift++);

acc_0 += dst_offset;
Expand Down
Loading

0 comments on commit f2cb41c

Please sign in to comment.