From 9eacdff4893cc4f5f772bef9c1476a9c2785ab45 Mon Sep 17 00:00:00 2001 From: Ryan OShea <86965113+ArmRyan@users.noreply.github.com> Date: Mon, 12 Feb 2024 07:58:23 +0000 Subject: [PATCH] Add dsp and mve support to transpose conv int8 (#103) * Adds new support functions to read and pad 2 int8s * Adds new support functions to allow addition to read and pad * Adds dsp optimizations for arm_nn_mat_mult_nt_t_s8_s32 * Adds mve optimizations for arm_nn_mat_mult_nt_t_s8_s32 * Adds mve requantization to arm_transpose_conv_s8 * Adds new unit test Signed-off-by: Ryan O'Shea --- ARM.CMSIS-NN.pdsc | 1 + Include/arm_nnsupportfunctions.h | 42 +- README.md | 4 +- .../arm_transpose_conv_s8.c | 34 +- .../arm_nn_mat_mult_nt_t_s8_s32.c | 369 ++++++++++++++++-- .../TestData/transpose_conv_4/biases_data.h | 6 + .../TestData/transpose_conv_4/config_data.h | 26 ++ .../TestData/transpose_conv_4/input_data.h | 42 ++ .../transpose_conv_4/output_mult_data.h | 6 + .../transpose_conv_4/output_ref_data.h | 24 ++ .../transpose_conv_4/output_shift_data.h | 6 + .../TestData/transpose_conv_4/test_data.h | 9 + .../TestData/transpose_conv_4/weights_data.h | 82 ++++ .../Unity/unity_test_arm_transpose_conv_s8.c | 3 +- .../test_arm_transpose_conv_s8.c | 88 ++++- Tests/UnitTest/generate_test_data.py | 19 + 16 files changed, 719 insertions(+), 42 deletions(-) create mode 100644 Tests/UnitTest/TestCases/TestData/transpose_conv_4/biases_data.h create mode 100644 Tests/UnitTest/TestCases/TestData/transpose_conv_4/config_data.h create mode 100644 Tests/UnitTest/TestCases/TestData/transpose_conv_4/input_data.h create mode 100644 Tests/UnitTest/TestCases/TestData/transpose_conv_4/output_mult_data.h create mode 100644 Tests/UnitTest/TestCases/TestData/transpose_conv_4/output_ref_data.h create mode 100644 Tests/UnitTest/TestCases/TestData/transpose_conv_4/output_shift_data.h create mode 100644 Tests/UnitTest/TestCases/TestData/transpose_conv_4/test_data.h create mode 100644 
Tests/UnitTest/TestCases/TestData/transpose_conv_4/weights_data.h diff --git a/ARM.CMSIS-NN.pdsc b/ARM.CMSIS-NN.pdsc index 03d3783d..d5240368 100644 --- a/ARM.CMSIS-NN.pdsc +++ b/ARM.CMSIS-NN.pdsc @@ -67,6 +67,7 @@ + diff --git a/Include/arm_nnsupportfunctions.h b/Include/arm_nnsupportfunctions.h index f2393db0..5da1cb67 100644 --- a/Include/arm_nnsupportfunctions.h +++ b/Include/arm_nnsupportfunctions.h @@ -21,8 +21,8 @@ * Title: arm_nnsupportfunctions.h * Description: Public header file of support functions for CMSIS NN Library * - * $Date: 19 January 2024 - * $Revision: V.18.0.0 + * $Date: 31 January 2024 + * $Revision: V.18.1.0 * * Target : Arm(R) M-Profile Architecture * -------------------------------------------------------------------- */ @@ -920,6 +920,44 @@ __STATIC_FORCEINLINE const int8_t *read_and_pad(const int8_t *source, int32_t *o return source; } +/** + * @brief read and expand one s8 word into two s16 words with ordering and addition. + */ +__STATIC_FORCEINLINE void read_pad_and_add_s8(const int8_t *source, int32_t *out1, int32_t *out2, const uint32_t add) +{ + int32_t inA = arm_nn_read_s8x4(source); + int32_t inAbuf1 = SXTAB16_RORn(add, (uint32_t)inA, 8); + int32_t inAbuf2 = SXTAB16(add, inA); + + #ifndef ARM_MATH_BIG_ENDIAN + *out2 = (int32_t)(PKHTB(inAbuf1, inAbuf2, 16)); + *out1 = (int32_t)(PKHBT(inAbuf2, inAbuf1, 16)); + #else + *out1 = (int32_t)(PKHTB(inAbuf1, inAbuf2, 16)); + *out2 = (int32_t)(PKHBT(inAbuf2, inAbuf1, 16)); + #endif +} + +/** + * @brief read and expand two bytes into one word with ordering. + */ +__STATIC_FORCEINLINE void read_and_pad_s8x2(const int8_t *source, int32_t *out) +{ + int16_t in = arm_nn_read_s8x2(source); + int32_t inA = (in & 0x00FF) | ((in & 0xFF00) << 8); + *out = SXTB16(inA); +} + +/** + * @brief read and expand two bytes into one word with ordering and addition. 
+ */ +__STATIC_FORCEINLINE void read_pad_and_add_s8x2(const int8_t *source, int32_t *out, const uint32_t add) +{ + int16_t in = arm_nn_read_s8x2(source); + int32_t inA = (in & 0x00FF) | ((in & 0xFF00) << 8); + *out = SXTAB16(add, inA); +} + /** * @brief read and expand one s8 word into two s16 words with no additional ordering. */ diff --git a/README.md b/README.md index 8e0e985f..51840024 100644 --- a/README.md +++ b/README.md @@ -24,10 +24,10 @@ Processors with Arm Helium Technology use the Arm M-profile Vector Extension(MVE Examples are Cortex-M55 or Cortex-M85 configured with MVE. | Operator | C
int8 | C
int16 | C
int4* | DSP
int8 | DSP
int16 | DSP
int4* | MVE
int8 | MVE
int16 | -| --------------- | ----------- | ---------- |------------| ------------| -------------|--------------| ------------| -------------| +| --------------- | ----------- | ---------- |------------|-------------| -------------|--------------|-------------| -------------| | Conv2D | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | | DepthwiseConv2D | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | -| TransposeConv2D | Yes | No | No | No | No | No | No | No | +| TransposeConv2D | Yes | No | No | Yes | No | No | Yes | No | | Fully Connected | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | | Add | Yes | Yes | N/A | Yes | Yes | N/A | Yes | Yes | | Mul | Yes | Yes | N/A | Yes | Yes | N/A | Yes | Yes | diff --git a/Source/ConvolutionFunctions/arm_transpose_conv_s8.c b/Source/ConvolutionFunctions/arm_transpose_conv_s8.c index e866c6fc..7a5f3660 100644 --- a/Source/ConvolutionFunctions/arm_transpose_conv_s8.c +++ b/Source/ConvolutionFunctions/arm_transpose_conv_s8.c @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright 2023 Arm Limited and/or its affiliates + * SPDX-FileCopyrightText: Copyright 2023-2024 Arm Limited and/or its affiliates * * SPDX-License-Identifier: Apache-2.0 * @@ -19,10 +19,10 @@ /* ---------------------------------------------------------------------- * Project: CMSIS NN Library * Title: arm_transpose_conv_s8.c - * Description: s8 version of convolution using symmetric quantization. + * Description: s8 version of transpose convolution using symmetric quantization. 
* - * $Date: 5 October 2023 - * $Revision: V.1.0.0 + * $Date: 31 January 2024 + * $Revision: V.1.1.0 * * Target : Arm(R) M-Profile Architecture * @@ -172,11 +172,30 @@ arm_cmsis_nn_status arm_transpose_conv_s8(const cmsis_nn_context *ctx, } } } - img_data = img_buf_ptr; for (int i = 0; i < output_x * output_y; i++) { - for (int i_output_ch = 0; i_output_ch < output_ch; i_output_ch++) +#if defined(ARM_MATH_MVEI) + int output_ch_idx = 0; + int8_t *ip_out_data = output_data_ptr; + for (int32_t i_channel_rmdr = output_ch; i_channel_rmdr > 0; i_channel_rmdr -= 4) + { + mve_pred16_t p = vctp32q((uint32_t)i_channel_rmdr); + int32x4_t result = vldrwq_z_s32(&img_data[output_ch_idx], p); + result = arm_requantize_mve_32x4(result, + vldrwq_z_s32(&output_multiplier[output_ch_idx], p), + vldrwq_z_s32(&output_shift[output_ch_idx], p)); + result = vaddq_n_s32(result, out_offset); + result = vmaxq_s32(result, vdupq_n_s32(activation_min)); + result = vminq_s32(result, vdupq_n_s32(activation_max)); + vstrbq_p_s32(ip_out_data, result, p); + ip_out_data += 4; + output_ch_idx += 4; + } + output_data_ptr += output_ch; +#else + int i_output_ch = 0; + for (; i_output_ch < output_ch; i_output_ch++) { int32_t result = arm_nn_requantize(img_data[i_output_ch], output_multiplier[i_output_ch], output_shift[i_output_ch]); @@ -185,13 +204,12 @@ arm_cmsis_nn_status arm_transpose_conv_s8(const cmsis_nn_context *ctx, result = MIN(result, activation_max); *output_data_ptr++ = (int8_t)result; } +#endif img_data += output_ch; } - input_data_ptr += (input_size * input_ch); batch_cnt--; } - /* Return to application */ return ARM_CMSIS_NN_SUCCESS; } diff --git a/Source/NNSupportFunctions/arm_nn_mat_mult_nt_t_s8_s32.c b/Source/NNSupportFunctions/arm_nn_mat_mult_nt_t_s8_s32.c index fea8a292..a140dc6f 100644 --- a/Source/NNSupportFunctions/arm_nn_mat_mult_nt_t_s8_s32.c +++ b/Source/NNSupportFunctions/arm_nn_mat_mult_nt_t_s8_s32.c @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright 2023 Arm Limited 
and/or its affiliates + * SPDX-FileCopyrightText: Copyright 2023-2024 Arm Limited and/or its affiliates * * SPDX-License-Identifier: Apache-2.0 * @@ -18,11 +18,11 @@ /* ---------------------------------------------------------------------- * Project: CMSIS NN Library - * Title: arm_nn_mat_mult_s8_nt_t_s8_s32 + * Title: arm_nn_mat_mult_nt_t_s8_s32 * Description: Matrix multiplication support function with the right-hand-side (rhs) matrix transposed * - * $Date: 5 October 2023 - * $Revision: V.1.0.0 + * $Date: 31 January 2024 + * $Revision: V.1.1.0 * * Target : Arm(R) M-Profile Architecture * @@ -54,9 +54,332 @@ arm_cmsis_nn_status arm_nn_mat_mult_nt_t_s8_s32(const int8_t *lhs, const int32_t lhs_offset, const int32_t dst_idx_offset) { + int32_t rhs_rows_idx = rhs_rows; const int32_t dst_idx_col_offset = dst_idx_offset * rhs_cols; +#if defined(ARM_MATH_MVEI) + for (; rhs_rows_idx >= 16; rhs_rows_idx -= 16) + { + int32_t *dst_ptr = &dst[0]; + const int8_t *lhs_ptr = &lhs[0]; + int32_t lhs_rows_idx = lhs_rows; + + for (; lhs_rows_idx >= 4; lhs_rows_idx -= 4) + { + const int8_t *rhs_ptr = &rhs[0]; + int8x16_t v_lhs0 = vldrbq_s8(lhs_ptr); + lhs_ptr += rhs_rows; + int8x16_t v_lhs1 = vldrbq_s8(lhs_ptr); + lhs_ptr += rhs_rows; + int8x16_t v_lhs2 = vldrbq_s8(lhs_ptr); + lhs_ptr += rhs_rows; + int8x16_t v_lhs3 = vldrbq_s8(lhs_ptr); + lhs_ptr += rhs_rows; + + for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--) + { + int32_t *ip_dst = dst_ptr; + + int8x16_t v_rhs0 = vldrbq_s8(rhs_ptr); + int32_t rhs_sum = vaddvq_s8(v_rhs0); + rhs_sum *= lhs_offset; + + *ip_dst += rhs_sum; + *ip_dst = vmladavaq_s8(*ip_dst, v_lhs0, v_rhs0); + ip_dst += dst_idx_col_offset; + + *ip_dst += rhs_sum; + *ip_dst = vmladavaq_s8(*ip_dst, v_lhs1, v_rhs0); + ip_dst += dst_idx_col_offset; + + *ip_dst += rhs_sum; + *ip_dst = vmladavaq_s8(*ip_dst, v_lhs2, v_rhs0); + ip_dst += dst_idx_col_offset; + + *ip_dst += rhs_sum; + *ip_dst = vmladavaq_s8(*ip_dst, v_lhs3, v_rhs0); + + dst_ptr += 
dst_idx_offset; + rhs_ptr += rhs_rows; + } + + dst_ptr += 3 * dst_idx_col_offset; + } + for (; lhs_rows_idx > 0; lhs_rows_idx--) + { + const int8_t *rhs_ptr = &rhs[0]; + int8x16_t v_lhs0 = vldrbq_s8(lhs_ptr); + + for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--) + { + int8x16_t v_rhs0 = vldrbq_s8(rhs_ptr); + + int32_t offset_sum = vaddvq_s8(v_rhs0); + *dst_ptr += offset_sum * lhs_offset; + + *dst_ptr = vmladavaq_s8(*dst_ptr, v_lhs0, v_rhs0); + + dst_ptr += dst_idx_offset; + rhs_ptr += rhs_rows; + } + lhs_ptr += rhs_rows; + } + + rhs += 16; + lhs += 16; + } + if (rhs_rows_idx) + { + mve_pred16_t rmdr = (1 << rhs_rows_idx) - 1; + int32_t *dst_ptr = &dst[0]; + const int8_t *lhs_ptr = &lhs[0]; + int32_t lhs_rows_idx = lhs_rows; + + for (; lhs_rows_idx >= 4; lhs_rows_idx -= 4) + { + const int8_t *rhs_ptr = &rhs[0]; + int8x16_t v_lhs0 = vldrbq_z_s8(lhs_ptr, rmdr); + lhs_ptr += rhs_rows; + int8x16_t v_lhs1 = vldrbq_z_s8(lhs_ptr, rmdr); + lhs_ptr += rhs_rows; + int8x16_t v_lhs2 = vldrbq_z_s8(lhs_ptr, rmdr); + lhs_ptr += rhs_rows; + int8x16_t v_lhs3 = vldrbq_z_s8(lhs_ptr, rmdr); + lhs_ptr += rhs_rows; + + for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--) + { + int32_t *ip_dst = dst_ptr; + int8x16_t v_rhs0 = vldrbq_z_s8(rhs_ptr, rmdr); + + int32_t rhs_sum = vaddvq_p_s8(v_rhs0, rmdr); + rhs_sum *= lhs_offset; + + *ip_dst += rhs_sum; + *ip_dst = vmladavaq_p_s8(*ip_dst, v_lhs0, v_rhs0, rmdr); + ip_dst += dst_idx_col_offset; + + *ip_dst += rhs_sum; + *ip_dst = vmladavaq_p_s8(*ip_dst, v_lhs1, v_rhs0, rmdr); + ip_dst += dst_idx_col_offset; + + *ip_dst += rhs_sum; + *ip_dst = vmladavaq_p_s8(*ip_dst, v_lhs2, v_rhs0, rmdr); + ip_dst += dst_idx_col_offset; + + *ip_dst += rhs_sum; + *ip_dst = vmladavaq_p_s8(*ip_dst, v_lhs3, v_rhs0, rmdr); + + dst_ptr += dst_idx_offset; + rhs_ptr += rhs_rows; + } + + dst_ptr += 3 * dst_idx_col_offset; + } + for (; lhs_rows_idx > 0; lhs_rows_idx--) + { + const int8_t *rhs_ptr = &rhs[0]; + int8x16_t v_lhs0 = 
vldrbq_z_s8(lhs_ptr, rmdr); + + for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--) + { + int8x16_t v_rhs0 = vldrbq_z_s8(rhs_ptr, rmdr); + + int32_t rhs_sum = vaddvq_p_s8(v_rhs0, rmdr); + *dst_ptr += rhs_sum * lhs_offset; + + *dst_ptr = vmladavaq_p_s8(*dst_ptr, v_lhs0, v_rhs0, rmdr); + + dst_ptr += dst_idx_offset; + rhs_ptr += rhs_rows; + } + lhs_ptr += rhs_rows; + } + } + +#elif defined(ARM_MATH_DSP) + int16_t lhs_offset_s16 = (int16_t)lhs_offset; + const uint32_t lhs_offset_s16x2 = PKHBT(lhs_offset_s16, lhs_offset_s16, 16); + for (; rhs_rows_idx >= 8; rhs_rows_idx -= 8) + { + int32_t *dst_ptr = &dst[0]; + const int8_t *lhs_ptr = &lhs[0]; + int32_t lhs_rows_idx = lhs_rows >> 1; + + while (lhs_rows_idx) + { + const int8_t *rhs_ptr = &rhs[0]; + + int32_t lhs000, lhs001, lhs010, lhs011, lhs100, lhs101, lhs110, lhs111; + read_pad_and_add_s8(lhs_ptr, &lhs000, &lhs001, lhs_offset_s16x2); + read_pad_and_add_s8(&lhs_ptr[4], &lhs010, &lhs011, lhs_offset_s16x2); + read_pad_and_add_s8(&lhs_ptr[rhs_rows], &lhs100, &lhs101, lhs_offset_s16x2); + read_pad_and_add_s8(&lhs_ptr[rhs_rows + 4], &lhs110, &lhs111, lhs_offset_s16x2); + + for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--) + { + int32_t rhs_val00, rhs_val01; + read_and_pad(rhs_ptr, &rhs_val00, &rhs_val01); + + dst_ptr[0] = SMLAD(lhs000, rhs_val00, dst_ptr[0]); + dst_ptr[0] = SMLAD(lhs001, rhs_val01, dst_ptr[0]); + dst_ptr[dst_idx_col_offset] = SMLAD(lhs100, rhs_val00, dst_ptr[dst_idx_col_offset]); + dst_ptr[dst_idx_col_offset] = SMLAD(lhs101, rhs_val01, dst_ptr[dst_idx_col_offset]); + + read_and_pad(&rhs_ptr[4], &rhs_val00, &rhs_val01); + + dst_ptr[0] = SMLAD(lhs010, rhs_val00, dst_ptr[0]); + dst_ptr[0] = SMLAD(lhs011, rhs_val01, dst_ptr[0]); + dst_ptr[dst_idx_col_offset] = SMLAD(lhs110, rhs_val00, dst_ptr[dst_idx_col_offset]); + dst_ptr[dst_idx_col_offset] = SMLAD(lhs111, rhs_val01, dst_ptr[dst_idx_col_offset]); + + dst_ptr += dst_idx_offset; + rhs_ptr += rhs_rows; + } + dst_ptr 
+= dst_idx_col_offset; + + lhs_ptr += rhs_rows << 1; + + lhs_rows_idx--; + } + // Left-over rows + if (lhs_rows % 2) + { + const int8_t *rhs_ptr = &rhs[0]; + int32_t lhs00, lhs01, lhs10, lhs11; + read_pad_and_add_s8(lhs_ptr, &lhs00, &lhs01, lhs_offset_s16x2); + read_pad_and_add_s8(&lhs_ptr[4], &lhs10, &lhs11, lhs_offset_s16x2); + + for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--) + { + int32_t rhs_val00, rhs_val01, rhs_val10, rhs_val11; + read_and_pad(rhs_ptr, &rhs_val00, &rhs_val01); + read_and_pad(&rhs_ptr[4], &rhs_val10, &rhs_val11); + + dst_ptr[0] = SMLAD(lhs00, rhs_val00, dst_ptr[0]); + dst_ptr[0] = SMLAD(lhs01, rhs_val01, dst_ptr[0]); + dst_ptr[0] = SMLAD(lhs10, rhs_val10, dst_ptr[0]); + dst_ptr[0] = SMLAD(lhs11, rhs_val11, dst_ptr[0]); + + dst_ptr += dst_idx_offset; + rhs_ptr += rhs_rows; + } + } + + rhs += 8; + lhs += 8; + } + for (; rhs_rows_idx >= 4; rhs_rows_idx -= 4) + { + int32_t *dst_ptr = &dst[0]; + const int8_t *lhs_ptr = &lhs[0]; + + int32_t lhs_rows_idx = lhs_rows >> 1; + + while (lhs_rows_idx) + { + const int8_t *rhs_ptr = &rhs[0]; + + int32_t lhs00, lhs01, lhs10, lhs11; + read_pad_and_add_s8(lhs_ptr, &lhs00, &lhs01, lhs_offset_s16x2); + read_pad_and_add_s8(&lhs_ptr[rhs_rows], &lhs10, &lhs11, lhs_offset_s16x2); + + for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--) + { + int32_t rhs_val0, rhs_val1; + read_and_pad(rhs_ptr, &rhs_val0, &rhs_val1); + + dst_ptr[0] = SMLAD(lhs00, rhs_val0, dst_ptr[0]); + dst_ptr[0] = SMLAD(lhs01, rhs_val1, dst_ptr[0]); + dst_ptr[dst_idx_col_offset] = SMLAD(lhs10, rhs_val0, dst_ptr[dst_idx_col_offset]); + dst_ptr[dst_idx_col_offset] = SMLAD(lhs11, rhs_val1, dst_ptr[dst_idx_col_offset]); + dst_ptr += dst_idx_offset; + rhs_ptr += rhs_rows; + } + dst_ptr += dst_idx_col_offset; + + lhs_ptr += rhs_rows << 1; + + lhs_rows_idx--; + } + // Left-over rows + if (lhs_rows % 2) + { + const int8_t *rhs_ptr = &rhs[0]; + int32_t lhs00, lhs01; + read_pad_and_add_s8(lhs_ptr, &lhs00, &lhs01, 
lhs_offset_s16x2); + + for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--) + { + int32_t rhs_val0, rhs_val1; + read_and_pad(rhs_ptr, &rhs_val0, &rhs_val1); + + dst_ptr[0] = SMLAD(lhs00, rhs_val0, dst_ptr[0]); + dst_ptr[0] = SMLAD(lhs01, rhs_val1, dst_ptr[0]); - for (int32_t rhs_rows_idx = 0; rhs_rows_idx <= (rhs_rows - 2); rhs_rows_idx += 2) + dst_ptr += dst_idx_offset; + rhs_ptr += rhs_rows; + } + } + + rhs += 4; + lhs += 4; + } + for (; rhs_rows_idx >= 2; rhs_rows_idx -= 2) + { + int32_t *dst_ptr = &dst[0]; + const int8_t *lhs_ptr = &lhs[0]; + + int32_t lhs_rows_idx = lhs_rows >> 1; + + while (lhs_rows_idx) + { + const int8_t *rhs_ptr = &rhs[0]; + + int32_t lhs0, lhs1; + read_pad_and_add_s8x2(lhs_ptr, &lhs0, lhs_offset_s16x2); + read_pad_and_add_s8x2(&lhs_ptr[rhs_rows], &lhs1, lhs_offset_s16x2); + + for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--) + { + int32_t rhs_val; + read_and_pad_s8x2(rhs_ptr, &rhs_val); + + dst_ptr[0] = SMLAD(lhs0, rhs_val, dst_ptr[0]); + dst_ptr[dst_idx_col_offset] = SMLAD(lhs1, rhs_val, dst_ptr[dst_idx_col_offset]); + + dst_ptr += dst_idx_offset; + rhs_ptr += rhs_rows; + } + dst_ptr += dst_idx_col_offset; + + lhs_ptr += rhs_rows << 1; + + lhs_rows_idx--; + } + // Left-over rows + if (lhs_rows % 2) + { + const int8_t *rhs_ptr = &rhs[0]; + const int32_t lhs_value = lhs_ptr[0] + lhs_offset; + const int32_t lhs_value01 = lhs_ptr[1] + lhs_offset; + + for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--) + { + const int32_t rhs_value0 = rhs_ptr[0]; + const int32_t rhs_value01 = rhs_ptr[1]; + + dst_ptr[0] += lhs_value * rhs_value0; + dst_ptr[0] += lhs_value01 * rhs_value01; + dst_ptr += dst_idx_offset; + rhs_ptr += rhs_rows; + } + } + + rhs += 2; + lhs += 2; + } +#else + for (; rhs_rows_idx >= 2; rhs_rows_idx -= 2) { int32_t *dst_ptr = &dst[0]; const int8_t *lhs_ptr = &lhs[0]; @@ -78,16 +401,11 @@ arm_cmsis_nn_status arm_nn_mat_mult_nt_t_s8_s32(const int8_t *lhs, const int32_t 
rhs_value0 = rhs_ptr[0]; const int32_t rhs_value1 = rhs_ptr[1]; - const int32_t res00 = lhs_value00 * rhs_value0; - const int32_t res10 = lhs_value01 * rhs_value1; - - const int32_t res01 = lhs_value10 * rhs_value0; - const int32_t res11 = lhs_value11 * rhs_value1; + dst_ptr[0] += lhs_value00 * rhs_value0; + dst_ptr[0] += lhs_value01 * rhs_value1; - dst_ptr[0] += res00; - dst_ptr[0] += res10; - dst_ptr[dst_idx_col_offset] += res01; - dst_ptr[dst_idx_col_offset] += res11; + dst_ptr[dst_idx_col_offset] += lhs_value10 * rhs_value0; + dst_ptr[dst_idx_col_offset] += lhs_value11 * rhs_value1; dst_ptr += dst_idx_offset; rhs_ptr += rhs_rows; } @@ -97,7 +415,6 @@ arm_cmsis_nn_status arm_nn_mat_mult_nt_t_s8_s32(const int8_t *lhs, lhs_rows_idx--; } - // Left-over rows if (lhs_rows % 2) { @@ -109,12 +426,9 @@ arm_cmsis_nn_status arm_nn_mat_mult_nt_t_s8_s32(const int8_t *lhs, { const int32_t rhs_value0 = rhs_ptr[0]; const int32_t rhs_value01 = rhs_ptr[1]; - const int32_t res00 = lhs_value * rhs_value0; - const int32_t res01 = lhs_value01 * rhs_value01; - - dst_ptr[0] += res00; - dst_ptr[0] += res01; + dst_ptr[0] += lhs_value * rhs_value0; + dst_ptr[0] += lhs_value01 * rhs_value01; dst_ptr += dst_idx_offset; rhs_ptr += rhs_rows; } @@ -123,8 +437,9 @@ arm_cmsis_nn_status arm_nn_mat_mult_nt_t_s8_s32(const int8_t *lhs, rhs += 2; lhs += 2; } - - if (rhs_rows % 2) +#endif +#if !defined(ARM_MATH_MVEI) + if (rhs_rows_idx) { const int8_t *lhs_ptr = &lhs[0]; int32_t *dst_ptr = &dst[0]; @@ -132,15 +447,13 @@ arm_cmsis_nn_status arm_nn_mat_mult_nt_t_s8_s32(const int8_t *lhs, for (int32_t lhs_rows_idx = 0; lhs_rows_idx < lhs_rows; ++lhs_rows_idx) { const int8_t *rhs_ptr = &rhs[0]; - int32_t lhs_value = lhs_ptr[0] + lhs_offset; + const int32_t lhs_value = lhs_ptr[0] + lhs_offset; for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--) { - int32_t rhs_value = rhs_ptr[0]; - - int32_t res00 = lhs_value * rhs_value; + const int32_t rhs_value = rhs_ptr[0]; - *dst_ptr += res00; + 
*dst_ptr += lhs_value * rhs_value; dst_ptr += dst_idx_offset; rhs_ptr += rhs_rows; @@ -148,7 +461,7 @@ arm_cmsis_nn_status arm_nn_mat_mult_nt_t_s8_s32(const int8_t *lhs, lhs_ptr += rhs_rows; } } - +#endif return ARM_CMSIS_NN_SUCCESS; } diff --git a/Tests/UnitTest/TestCases/TestData/transpose_conv_4/biases_data.h b/Tests/UnitTest/TestCases/TestData/transpose_conv_4/biases_data.h new file mode 100644 index 00000000..bed65a6e --- /dev/null +++ b/Tests/UnitTest/TestCases/TestData/transpose_conv_4/biases_data.h @@ -0,0 +1,6 @@ +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tensorflow version 2.15.0 and revision v2.15.0-rc1-8-g6887368d6d4. +#pragma once +#include + +const int32_t *transpose_conv_4_biases = NULL; diff --git a/Tests/UnitTest/TestCases/TestData/transpose_conv_4/config_data.h b/Tests/UnitTest/TestCases/TestData/transpose_conv_4/config_data.h new file mode 100644 index 00000000..183d36c7 --- /dev/null +++ b/Tests/UnitTest/TestCases/TestData/transpose_conv_4/config_data.h @@ -0,0 +1,26 @@ +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tensorflow version 2.15.0 and revision v2.15.0-rc1-8-g6887368d6d4. 
+#pragma once +#define TRANSPOSE_CONV_4_OUT_CH 5 +#define TRANSPOSE_CONV_4_IN_CH 32 +#define TRANSPOSE_CONV_4_INPUT_W 1 +#define TRANSPOSE_CONV_4_INPUT_H 7 +#define TRANSPOSE_CONV_4_DST_SIZE 405 +#define TRANSPOSE_CONV_4_INPUT_SIZE 224 +#define TRANSPOSE_CONV_4_OUT_ACTIVATION_MIN -128 +#define TRANSPOSE_CONV_4_OUT_ACTIVATION_MAX 127 +#define TRANSPOSE_CONV_4_INPUT_BATCHES 3 +#define TRANSPOSE_CONV_4_FILTER_X 3 +#define TRANSPOSE_CONV_4_FILTER_Y 3 +#define TRANSPOSE_CONV_4_STRIDE_X 3 +#define TRANSPOSE_CONV_4_STRIDE_Y 1 +#define TRANSPOSE_CONV_4_PAD_X 0 +#define TRANSPOSE_CONV_4_PAD_Y 0 +#define TRANSPOSE_CONV_4_OUTPUT_W 3 +#define TRANSPOSE_CONV_4_OUTPUT_H 9 +#define TRANSPOSE_CONV_4_INPUT_OFFSET 128 +#define TRANSPOSE_CONV_4_OUTPUT_OFFSET 9 +#define TRANSPOSE_CONV_4_DILATION_X 1 +#define TRANSPOSE_CONV_4_DILATION_Y 1 +#define TRANSPOSE_CONV_4_PAD_X_WITH_OFFSET 0 +#define TRANSPOSE_CONV_4_PAD_Y_WITH_OFFSET 0 diff --git a/Tests/UnitTest/TestCases/TestData/transpose_conv_4/input_data.h b/Tests/UnitTest/TestCases/TestData/transpose_conv_4/input_data.h new file mode 100644 index 00000000..15cb0251 --- /dev/null +++ b/Tests/UnitTest/TestCases/TestData/transpose_conv_4/input_data.h @@ -0,0 +1,42 @@ +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tensorflow version 2.15.0 and revision v2.15.0-rc1-8-g6887368d6d4. 
+#pragma once +#include + +const int8_t transpose_conv_4_input[672] = { + -62, 2, -77, -117, -82, 98, 120, -25, -45, -18, -62, 13, -2, -59, 120, 23, -84, 6, -67, + -119, 63, 53, 38, -53, -39, 109, 86, -38, -123, -99, 25, 4, -27, -63, -119, -24, -79, -47, + 1, 39, -60, 93, -47, 31, -23, 30, -72, 85, 66, -94, 69, 108, 121, -108, -34, 29, 0, + -105, -69, 118, -7, 56, -68, -119, 25, 63, 114, 18, 14, -3, -69, 114, 4, 112, 71, -3, + 75, 55, -93, 61, 27, 108, 35, 62, -121, -40, 126, -52, -106, 38, -18, -73, 46, -16, 94, + 102, 113, 125, -16, 65, -69, -118, 123, -47, -119, -99, 66, 51, 84, 2, 95, 102, 5, 72, + -83, -68, 126, 7, 92, -124, 33, 45, 17, 76, -7, 55, 107, 23, 121, -55, -79, 119, 60, + -54, -116, 10, 115, 77, 112, 80, -102, 57, -123, 40, 73, 27, -101, -101, 35, -36, 96, -22, + 58, 99, -60, -2, -3, 33, 101, 54, -108, -9, 5, 68, 82, -3, -74, -11, -12, -53, -95, + 64, -42, -51, 5, -90, -85, 110, -106, 50, 85, -100, -43, 17, -49, -113, -43, 107, 63, 57, + -122, 90, 105, 9, 113, -88, -65, 1, 2, 93, -117, 30, -4, -37, -101, 35, 79, 17, 34, + 26, -114, -124, 85, 93, -111, -90, -84, 56, 112, 37, 1, -93, 22, 21, 57, -108, 51, -28, + 126, -71, -17, -111, 1, -29, 100, 67, -120, -73, -43, 4, -53, -25, 39, -105, -121, -116, 105, + 62, 121, 42, 77, 73, 123, 101, 102, 70, -31, -58, -102, -62, -9, -82, -114, -56, 107, 125, + 79, 61, 63, 2, 81, 67, 116, 54, 32, -120, 62, 41, 106, 24, 114, -122, 49, 72, -8, + -95, 9, -109, -71, -55, 120, -74, -113, -57, -102, -62, -28, -35, 116, 115, 2, 77, -65, -66, + -41, 51, -13, 27, 71, 106, 45, -67, -8, -8, 59, 113, 86, 46, -30, -50, -34, 96, 43, + 16, 13, 64, 20, 63, -87, 36, -97, -67, -113, -43, 34, 76, -10, -75, 65, 55, 79, 96, + -58, 18, 111, 80, 73, -87, -85, 15, 103, -116, -101, -13, -16, 47, -99, 83, 125, -99, -122, + 87, 51, -117, -46, 34, -5, 95, -12, -41, 110, 54, 121, 5, -95, 69, -1, -128, 26, 12, + -16, 93, -87, 14, 19, -59, 11, -89, -2, -57, 107, 66, -102, -91, 28, -36, -12, 110, 17, + 47, -29, -83, 124, 103, 103, -57, 29, 24, 
-28, -48, -126, 77, 22, -120, -61, -11, 55, 94, + 47, 2, -87, -108, -112, -104, -87, 58, -74, 97, 55, -102, 76, -57, 117, 66, 31, -15, -126, + 104, 72, 73, 55, 38, -28, -68, -30, -87, -13, -84, -12, 8, 98, 80, 111, 111, 37, -128, + 42, -97, -73, 57, 20, -119, 122, 96, -82, -20, 111, 52, 31, 121, 25, -40, -109, 68, 118, + 16, -39, -34, -2, -112, -120, 94, 27, 68, -98, -51, -128, 120, 10, -17, 66, -20, -38, -119, + -25, -20, -90, -103, -105, -122, 74, -98, -128, 36, -104, 21, -124, 33, -19, -55, -106, -40, -4, + -13, 14, -26, 96, -119, -118, -26, 7, -86, 29, -31, 2, 9, -29, 64, 89, -111, 100, -110, + -26, 83, -78, -6, -68, -54, 48, -29, -86, 17, -77, 58, -52, -38, 50, 15, 110, 68, -29, + -67, 117, -99, -25, 71, -118, 30, -126, -80, -56, 38, 90, 93, 29, -112, 12, 13, -52, 64, + -46, 93, -29, -26, 116, 97, 74, -119, -90, -83, 27, -33, 87, -65, 67, 35, 109, -105, 7, + -125, 23, 16, -8, -116, -95, -90, 26, -83, 87, 78, -116, 22, 112, -86, 113, -31, -84, -60, + -81, -90, -113, 88, 48, -1, 2, 106, -4, 46, -88, 19, 111, 107, 97, 70, -104, 54, -112, + 28, -126, -107, -29, -8, 29, 106, 112, -89, 92, -57, -3, -50, -103, 69, -46, 70, 99, -20, + 21, 71, -56, -39, 110, 28, 4, -78, -69, -111, 75, -84, -51, -110, -50, -45, 69, 43, -111, + 115, -104, -95, 90, -67, 119, -118}; diff --git a/Tests/UnitTest/TestCases/TestData/transpose_conv_4/output_mult_data.h b/Tests/UnitTest/TestCases/TestData/transpose_conv_4/output_mult_data.h new file mode 100644 index 00000000..8481bb87 --- /dev/null +++ b/Tests/UnitTest/TestCases/TestData/transpose_conv_4/output_mult_data.h @@ -0,0 +1,6 @@ +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tensorflow version 2.15.0 and revision v2.15.0-rc1-8-g6887368d6d4. 
+#pragma once +#include + +const int32_t transpose_conv_4_output_mult[5] = {1860509430, 1864407918, 1864306093, 1851526341, 1864151479}; diff --git a/Tests/UnitTest/TestCases/TestData/transpose_conv_4/output_ref_data.h b/Tests/UnitTest/TestCases/TestData/transpose_conv_4/output_ref_data.h new file mode 100644 index 00000000..4e63b193 --- /dev/null +++ b/Tests/UnitTest/TestCases/TestData/transpose_conv_4/output_ref_data.h @@ -0,0 +1,24 @@ +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tensorflow version 2.15.0 and revision v2.15.0-rc1-8-g6887368d6d4. +#pragma once +#include + +const int8_t transpose_conv_4_output_ref[405] = { + 10, 10, -37, -12, -26, 16, 9, 11, -23, -39, 12, -1, 28, 46, -28, -6, 41, -57, -17, 6, 49, -28, 6, + -5, 10, 17, -6, 46, 22, 17, -11, 63, -53, -87, -44, 16, 39, 48, 18, -11, 51, -62, 11, -21, -20, 16, + 85, -66, -43, -3, 57, 42, -21, 12, 8, 76, -59, -26, -11, -46, 3, 80, -98, -63, -24, 108, 43, 37, -57, + 14, 74, -60, -1, 6, -63, 12, 66, -74, -55, 27, 61, 43, 38, -4, 20, 48, -74, 2, -38, -84, 21, 34, + -31, -52, 11, 76, 47, 35, -41, -6, 66, -42, 10, -40, -42, 19, 76, -35, 9, 11, 74, -20, 3, 2, 73, + 27, -26, 5, -42, -21, 35, 13, -11, -4, -6, 36, 43, 62, 30, 24, 29, -7, -11, -27, -13, 10, 20, -15, + -27, -15, -8, 32, 13, -27, -37, 20, -14, 10, 58, 2, -16, 58, -34, -49, 13, 65, 43, -14, -42, -30, 35, + -25, 8, 0, -61, 14, 70, -53, -69, 11, 64, 67, 4, 7, 49, 78, -2, -47, -34, -34, 22, 41, -55, -52, + -16, 50, 45, 1, -35, -12, 64, -47, 7, -36, -18, -29, 79, -85, -49, -18, 23, 31, 20, 25, -2, 36, -22, + 17, -19, -51, 14, 51, -38, -51, 0, 70, 34, 50, -8, 18, 29, -28, 23, -24, -9, -8, 40, -30, -69, -29, + 24, 12, -7, 15, 44, 81, -41, 45, 20, -51, 34, 79, -26, -46, -2, 61, 33, -14, 59, 37, 48, -29, -32, + -48, -9, 55, 22, 8, -19, 21, 6, 63, 33, 13, 39, 26, 11, 15, -46, -10, 9, -8, -33, -23, -13, -1, + 8, 12, -39, -40, 10, -5, 38, 56, -9, -18, 40, -62, -42, -3, 43, -14, -28, 9, -31, 82, 
-44, 33, -16, + -8, 28, 54, -37, -17, 21, 7, 51, 68, -22, -13, 30, -50, 16, -32, -38, 52, 72, -17, -61, -5, 57, 43, + 12, -37, -22, 22, -63, 9, 20, -38, -21, 97, -40, -58, 44, 29, 64, 12, -47, -21, 71, -40, 30, -58, -73, + -5, 54, -102, -61, -46, 77, 42, 47, -21, 49, 16, 3, 13, 12, -49, 53, 61, -68, -49, -10, 54, 65, -18, + -50, 18, 68, -50, -10, -7, -52, 18, 73, 2, -2, 78, 91, 44, 25, -13, 27, 34, -37, -13, -80, -29, 52, + 17, -5, 17, -1, 43, 61, 51, -13, 22, 29, 30, 1, -29, -10}; diff --git a/Tests/UnitTest/TestCases/TestData/transpose_conv_4/output_shift_data.h b/Tests/UnitTest/TestCases/TestData/transpose_conv_4/output_shift_data.h new file mode 100644 index 00000000..5a7723b3 --- /dev/null +++ b/Tests/UnitTest/TestCases/TestData/transpose_conv_4/output_shift_data.h @@ -0,0 +1,6 @@ +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tensorflow version 2.15.0 and revision v2.15.0-rc1-8-g6887368d6d4. +#pragma once +#include + +const int32_t transpose_conv_4_output_shift[5] = {-11, -11, -11, -11, -11}; diff --git a/Tests/UnitTest/TestCases/TestData/transpose_conv_4/test_data.h b/Tests/UnitTest/TestCases/TestData/transpose_conv_4/test_data.h new file mode 100644 index 00000000..7de543ed --- /dev/null +++ b/Tests/UnitTest/TestCases/TestData/transpose_conv_4/test_data.h @@ -0,0 +1,9 @@ +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tensorflow version 2.15.0 and revision v2.15.0-rc1-8-g6887368d6d4. 
+#include "biases_data.h" +#include "config_data.h" +#include "input_data.h" +#include "output_mult_data.h" +#include "output_ref_data.h" +#include "output_shift_data.h" +#include "weights_data.h" diff --git a/Tests/UnitTest/TestCases/TestData/transpose_conv_4/weights_data.h b/Tests/UnitTest/TestCases/TestData/transpose_conv_4/weights_data.h new file mode 100644 index 00000000..eea9ab06 --- /dev/null +++ b/Tests/UnitTest/TestCases/TestData/transpose_conv_4/weights_data.h @@ -0,0 +1,82 @@ +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tensorflow version 2.15.0 and revision v2.15.0-rc1-8-g6887368d6d4. +#pragma once +#include + +const int8_t transpose_conv_4_weights[1440] = { + -98, 47, -84, -108, 21, 63, 28, -24, 20, -117, -1, 123, -78, -73, -109, -126, 119, -48, 78, + -69, 102, -1, -114, 24, -52, 119, 59, 91, 109, 14, -4, -57, -125, -9, -25, -120, 0, 88, + -65, 90, 21, -121, -4, -101, -47, 123, 50, -49, 73, 66, 74, -11, 111, -93, 80, -10, 36, + 88, -18, 17, -7, -39, -53, -84, 62, -95, -61, 93, -109, 34, -122, 68, 99, -109, 98, 12, + -113, -75, 62, 67, -34, -17, -66, -109, 42, 118, 10, -111, 60, 121, -112, 21, 79, 94, 23, + -40, 118, -88, 30, -125, -21, -27, -31, 115, -112, 68, -81, -33, 97, -21, -77, -98, -6, -2, + 78, -56, -81, 123, -3, 20, -28, 57, -92, -14, 63, 38, -93, 87, -39, 70, 22, 60, 71, + -17, 121, -116, -70, -49, 124, 49, -43, 103, -14, -8, 51, 123, -73, -95, -47, -30, 118, -68, + -16, 49, 73, 121, -92, 74, 80, -61, -14, -71, 84, 65, -40, 9, 73, 123, 63, -69, -54, + 105, 16, 61, 55, 18, -31, -35, 101, 127, -32, -16, 115, 12, -66, 5, -33, 121, 88, -91, + -86, -83, 123, 111, 85, -8, 11, 44, 64, -20, 99, 122, -99, 50, -39, -93, 46, 62, 58, + -43, -78, 23, 3, -96, 77, 116, -54, 72, -76, -33, 26, -124, 81, -88, 120, -49, -108, 53, + 67, 82, 46, 97, 17, 108, -41, -90, -52, 29, -117, -51, -23, 127, -32, 107, 26, -61, -110, + 55, -65, 110, -98, 59, 123, -33, 101, 42, -64, 86, -12, -90, -31, -36, 63, 
25, -113, -84, + 74, 14, -76, 4, -15, 80, 117, -39, -37, -60, -76, 84, 64, -40, 107, 40, -1, 78, 110, + 120, -69, -7, 110, 48, 14, -125, -75, 56, 80, -101, 99, -68, 86, 1, -2, 37, -52, -124, + -91, 12, 106, -104, 66, -125, 81, -14, -35, -56, 45, -112, 119, -76, -48, 114, 7, -37, -107, + -3, 74, -49, -30, -114, 114, 117, 26, 32, 125, 89, 15, -29, -104, -25, 6, -41, 54, -5, + 52, 0, 66, -17, -127, -27, 88, 44, 13, 6, -72, -111, -90, -40, -119, 1, 51, 13, 102, + -35, -8, -116, 43, -30, 114, 117, -105, -62, 101, -30, 108, -17, -21, -106, 81, -113, -127, 67, + 103, 61, 49, -105, 87, 102, -7, -28, -107, -56, 101, 47, 24, 15, -67, 7, -15, 53, -35, + 29, 32, -96, 118, -67, 44, 60, 120, 25, 67, 124, 72, 8, 90, -19, 110, 39, -120, -12, + 31, 16, 102, -40, -89, -87, -52, -50, -27, 78, 45, -56, -110, -6, 122, -27, 103, 44, -77, + -76, -107, 0, 54, 15, -63, -99, 90, 71, -94, 61, 27, -56, -108, -112, -24, 33, -77, 52, + -26, -104, 62, 43, 111, 10, 15, -105, 45, -122, 68, 24, 76, -119, -27, 22, 105, -100, 74, + -87, -27, -110, -44, -100, 98, -24, -61, -49, 50, 45, -49, -62, 103, 2, -36, 105, 102, -11, + -35, 103, 41, -39, 16, -1, 10, -23, -18, 11, -94, 112, -46, 110, -25, 0, -69, -44, 110, + -44, 50, 46, 75, -31, 93, -23, 118, 105, -18, 93, 73, -60, -70, -42, 58, -82, 72, -91, + -45, 7, 104, -10, 69, 30, 106, 78, 107, -58, 9, -56, -23, -121, -106, 87, 107, -1, -22, + -116, -116, 45, 75, -6, -99, -98, -127, -67, 103, 125, 60, 83, 33, 48, 14, 67, -112, 44, + -55, -70, 24, -36, 74, 122, -94, -122, 125, -114, 20, -70, 20, -46, -65, -11, 88, -60, -118, + -83, -47, 38, 58, 46, 45, -38, -63, 36, -123, 67, 10, -103, -68, 100, -67, -91, 103, -106, + 17, -68, 33, 63, -52, -100, -99, 91, -2, -86, -105, -84, -34, 15, 19, 79, 19, 21, -110, + -61, 59, 89, 78, 31, -87, -83, 127, 43, 29, 54, 107, 35, 57, 38, -43, -4, 40, 56, + -70, 125, -19, 81, -70, -72, -4, -96, 80, 48, -24, 44, 96, -30, -19, 34, 101, 82, -20, + -15, 86, 10, 7, -73, -126, -92, -110, 63, -123, -48, 80, 1, -23, 3, 121, -42, -62, 
84, + -92, 105, -57, -35, 68, -46, 77, -54, -124, -87, 17, 7, -10, -3, -118, 125, -10, -66, -106, + 17, -119, 117, -102, 26, -78, -25, 112, 21, -51, 6, 76, -101, -110, 0, -87, 0, -13, -91, + -92, -74, 67, 17, 68, 110, 22, 87, 13, -69, -109, 38, -41, -35, -42, -106, -116, 107, 27, + 125, 94, 82, -71, -93, -67, -12, 86, 78, -15, 91, -30, -64, -30, -100, 119, 113, -69, 47, + -28, 103, -125, 2, -62, 52, -58, -24, 4, 102, -108, -25, -10, 114, -92, 115, 45, 69, 53, + 90, -33, -97, 43, 57, -89, 7, 39, 123, -2, -54, -82, 7, 115, -65, -72, -119, -91, 14, + -75, 13, -45, 120, 68, -11, 100, 58, -6, 46, 50, 6, 63, 9, -44, -30, 81, -50, 31, + 23, -17, -89, 91, 88, -73, -63, -13, 98, 8, -64, -108, -50, 115, 86, -108, 118, -100, 66, + -90, -96, -76, -14, -21, -70, 12, 65, -97, -46, -10, 95, 119, 55, 55, 41, -47, 87, -81, + 44, -84, -17, -94, 16, -6, -35, -25, 66, -59, 73, -54, -126, 43, -5, -103, -33, 5, -90, + 69, 93, 25, -45, 88, 91, 45, -127, -108, -37, 67, -37, -115, -101, 117, -48, -42, -20, -103, + 93, -60, -49, -61, -24, -67, 37, -88, -105, -70, 121, 57, 124, -120, -86, 101, -62, -70, 94, + -23, 38, -70, 58, 1, -30, 35, -54, 104, -123, -77, 76, 3, 7, -78, 54, -45, 68, -117, + 20, 111, 110, -68, 17, -10, -17, 91, -6, 64, -14, 5, -70, -19, 120, 100, 120, -71, 50, + 109, -5, 104, -102, 114, 59, -67, -54, 57, 122, -102, -46, -3, -9, 26, 20, 52, 25, -111, + 40, -13, -127, -100, 78, -127, -72, -23, 81, -88, -70, 59, 73, 3, -105, -89, 101, -33, -26, + -125, 30, 63, 50, -21, -18, 9, 30, 13, 6, -81, 74, -58, 61, -127, 112, 115, -93, 5, + 15, 77, -42, 125, 109, -107, 87, -120, -33, 20, -126, 4, 5, -116, -121, -59, -98, -15, -12, + -62, 46, -70, -3, 111, -107, -107, -123, -64, -114, 3, -11, -34, 61, 38, 30, -125, -50, 45, + 75, 116, -84, -98, -105, 96, -104, -75, -50, -53, 37, -85, -88, 20, 0, 18, 113, 6, 57, + 109, -108, -32, -38, -51, -27, 3, 48, -96, -28, -72, -46, 32, -4, 74, 106, -27, 30, -34, + -12, 51, 24, 24, -42, 64, -44, -3, -74, -105, 52, 54, -85, -97, -52, -84, 72, -25, 
104, + 0, 58, 36, 5, 89, 38, 81, -95, -72, -55, 109, 18, 31, 126, 60, 102, -22, 1, 21, + -126, 63, 94, -88, 26, 53, -109, 65, 55, 111, -63, -6, 21, -100, 35, -32, -98, 5, 66, + 102, -116, -15, -29, -118, -86, -2, -64, -71, 19, -93, 16, 23, -95, 106, -90, -94, -102, 28, + 0, 76, -107, -24, 42, -100, -125, -54, -6, -32, -43, 121, 43, 119, 50, -2, 85, -49, -116, + -104, 33, 31, -77, -87, 70, -52, -108, 1, -38, -54, -103, -32, 62, -78, -71, -121, 28, 73, + 83, 59, -53, 38, 9, -8, 51, -47, -63, 31, -8, -36, -16, -97, -86, 120, 68, -71, -116, + 10, -50, -31, -117, -86, -83, -108, -31, -113, -70, -92, 92, -78, 7, 33, 118, -75, -123, 109, + 87, -100, -76, -103, 6, -72, 124, 5, 90, 108, 116, 74, -127, -107, -20, -3, 85, 80, -19, + 1, 2, 75, 28, 29, 83, -95, -38, -93, 27, -73, -9, -105, -56, 68, -118, 72, 0, 28, + 101, -22, 2, 77, -121, 98, 59, 63, -84, -44, -75, -40, 121, 55, 98, -43, 45, 44, -85, + 27, 55, 69, 30, 66, -13, 5, -114, -4, 102, -23, -60, -44, -88, -105, 55, 90, 38, 6, + 78, -104, 42, 32, -85, -50, -94, 60, 30, -98, 5, -22, 99, 53, -50, 43, 70, 76, 12, + 61, -4, 42, -44, -115, -72, -125, 28, 83, 42, -50, 59, 57, -126, 40, 21, -112, -12, 52, + -114, -59, 32, 107, 88, -87, 119, 83, 68, 10, 77, -108, -110, 22, 48, -101, -2, -67, 97, + 31, 51, 10, -2, -40, 21, -88, 56, 112, -115, 65, 47, -90, 25, 21, -15, -24, -30, 33, + -108, 1, -26, -30, -43, 93, 121, 88, 48, -74, -50, 105, 61, 6, 20, -41, 87, 106, 70, + -71, 77, 94, -61, -105, -74, -74, 108, -48, 48, -12, -43, -64, 66, -125, -35, 59, 23, 16, + -75, 50, -73, 62, -24, 109, -36, 42, 11, -106, -85, -82, 28, 50, -75}; diff --git a/Tests/UnitTest/TestCases/test_arm_transpose_conv_s8/Unity/unity_test_arm_transpose_conv_s8.c b/Tests/UnitTest/TestCases/test_arm_transpose_conv_s8/Unity/unity_test_arm_transpose_conv_s8.c index b671a216..fd51015c 100644 --- a/Tests/UnitTest/TestCases/test_arm_transpose_conv_s8/Unity/unity_test_arm_transpose_conv_s8.c +++ 
b/Tests/UnitTest/TestCases/test_arm_transpose_conv_s8/Unity/unity_test_arm_transpose_conv_s8.c
@@ -1,5 +1,5 @@
 /*
- * SPDX-FileCopyrightText: Copyright 2023 Arm Limited and/or its affiliates
+ * SPDX-FileCopyrightText: Copyright 2023-2024 Arm Limited and/or its affiliates
  *
  * SPDX-License-Identifier: Apache-2.0
  *
@@ -46,3 +46,4 @@ void tearDown(void) {}
 void test_transpose_conv_1_arm_transpose_conv_s8(void) { transpose_conv_1_arm_transpose_conv_s8(); }
 void test_transpose_conv_2_arm_transpose_conv_s8(void) { transpose_conv_2_arm_transpose_conv_s8(); }
 void test_transpose_conv_3_arm_transpose_conv_s8(void) { transpose_conv_3_arm_transpose_conv_s8(); }
+void test_transpose_conv_4_arm_transpose_conv_s8(void) { transpose_conv_4_arm_transpose_conv_s8(); }
diff --git a/Tests/UnitTest/TestCases/test_arm_transpose_conv_s8/test_arm_transpose_conv_s8.c b/Tests/UnitTest/TestCases/test_arm_transpose_conv_s8/test_arm_transpose_conv_s8.c
index 8d82b045..5b6e57a8 100644
--- a/Tests/UnitTest/TestCases/test_arm_transpose_conv_s8/test_arm_transpose_conv_s8.c
+++ b/Tests/UnitTest/TestCases/test_arm_transpose_conv_s8/test_arm_transpose_conv_s8.c
@@ -1,5 +1,5 @@
 /*
- * SPDX-FileCopyrightText: Copyright 2023 Arm Limited and/or its affiliates
+ * SPDX-FileCopyrightText: Copyright 2023-2024 Arm Limited and/or its affiliates
  *
  * SPDX-License-Identifier: Apache-2.0
  *
@@ -22,6 +22,7 @@
 #include "../TestData/transpose_conv_1/test_data.h"
 #include "../TestData/transpose_conv_2/test_data.h"
 #include "../TestData/transpose_conv_3/test_data.h"
+#include "../TestData/transpose_conv_4/test_data.h"
 #include "../Utils/utils.h"
 #include "../Utils/validate.h"
 
@@ -280,3 +281,92 @@ void transpose_conv_3_arm_transpose_conv_s8(void)
     TEST_ASSERT_EQUAL(expected, result);
     TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
 }
+
+/*
+ * Runs arm_transpose_conv_s8 on the transpose_conv_4 data set and checks the
+ * returned status and the produced output against the reference output.
+ */
+void transpose_conv_4_arm_transpose_conv_s8(void)
+{
+    const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS;
+    int8_t output[TRANSPOSE_CONV_4_DST_SIZE] = {0};
+
+    cmsis_nn_context ctx;
+    cmsis_nn_context output_ctx;
+    cmsis_nn_transpose_conv_params transpose_conv_params;
+    cmsis_nn_per_channel_quant_params quant_params;
+    cmsis_nn_dims input_dims;
+    cmsis_nn_dims filter_dims;
+    cmsis_nn_dims bias_dims = {0};
+    cmsis_nn_dims output_dims;
+
+    const int32_t *bias_data = transpose_conv_4_biases;
+    const int8_t *kernel_data = transpose_conv_4_weights;
+    const int8_t *input_data = transpose_conv_4_input;
+    const int8_t *output_ref = transpose_conv_4_output_ref;
+    const int32_t output_ref_size = TRANSPOSE_CONV_4_DST_SIZE;
+
+    input_dims.n = TRANSPOSE_CONV_4_INPUT_BATCHES;
+    input_dims.w = TRANSPOSE_CONV_4_INPUT_W;
+    input_dims.h = TRANSPOSE_CONV_4_INPUT_H;
+    input_dims.c = TRANSPOSE_CONV_4_IN_CH;
+    filter_dims.w = TRANSPOSE_CONV_4_FILTER_X;
+    filter_dims.h = TRANSPOSE_CONV_4_FILTER_Y;
+    output_dims.n = TRANSPOSE_CONV_4_INPUT_BATCHES;
+    output_dims.w = TRANSPOSE_CONV_4_OUTPUT_W;
+    output_dims.h = TRANSPOSE_CONV_4_OUTPUT_H;
+    output_dims.c = TRANSPOSE_CONV_4_OUT_CH;
+
+    output_ctx.size = output_dims.w * output_dims.h * output_dims.c * sizeof(int32_t);
+    output_ctx.buf = malloc(output_ctx.size);
+
+    transpose_conv_params.padding.w = TRANSPOSE_CONV_4_PAD_X;
+    transpose_conv_params.padding.h = TRANSPOSE_CONV_4_PAD_Y;
+    transpose_conv_params.padding_offsets.w = TRANSPOSE_CONV_4_PAD_X_WITH_OFFSET;
+    transpose_conv_params.padding_offsets.h = TRANSPOSE_CONV_4_PAD_Y_WITH_OFFSET;
+
+    transpose_conv_params.stride.w = TRANSPOSE_CONV_4_STRIDE_X;
+    transpose_conv_params.stride.h = TRANSPOSE_CONV_4_STRIDE_Y;
+    transpose_conv_params.dilation.w = TRANSPOSE_CONV_4_DILATION_X;
+    transpose_conv_params.dilation.h = TRANSPOSE_CONV_4_DILATION_Y;
+
+    transpose_conv_params.input_offset = TRANSPOSE_CONV_4_INPUT_OFFSET;
+    transpose_conv_params.output_offset = TRANSPOSE_CONV_4_OUTPUT_OFFSET;
+    transpose_conv_params.activation.min = TRANSPOSE_CONV_4_OUT_ACTIVATION_MIN;
+    transpose_conv_params.activation.max = TRANSPOSE_CONV_4_OUT_ACTIVATION_MAX;
+    quant_params.multiplier = (int32_t *)transpose_conv_4_output_mult;
+    quant_params.shift = (int32_t *)transpose_conv_4_output_shift;
+
+    const int32_t buf_size = arm_transpose_conv_s8_get_buffer_size(&input_dims, &filter_dims, &output_dims);
+    ctx.buf = malloc(buf_size);
+    ctx.size = buf_size;
+
+    arm_cmsis_nn_status result = arm_transpose_conv_s8(&ctx,
+                                                       &output_ctx,
+                                                       &transpose_conv_params,
+                                                       &quant_params,
+                                                       &input_dims,
+                                                       input_data,
+                                                       &filter_dims,
+                                                       kernel_data,
+                                                       &bias_dims,
+                                                       bias_data,
+                                                       &output_dims,
+                                                       output);
+
+    if (output_ctx.buf)
+    {
+        // The caller is responsible to clear the scratch buffers for security reasons if applicable.
+        memset(output_ctx.buf, 0, output_ctx.size);
+        free(output_ctx.buf);
+    }
+
+    if (ctx.buf)
+    {
+        // The caller is responsible to clear the scratch buffers for security reasons if applicable.
+        memset(ctx.buf, 0, buf_size);
+        free(ctx.buf);
+    }
+    TEST_ASSERT_EQUAL(expected, result);
+    TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
+}
diff --git a/Tests/UnitTest/generate_test_data.py b/Tests/UnitTest/generate_test_data.py
index c1e6ca6d..c4790136 100755
--- a/Tests/UnitTest/generate_test_data.py
+++ b/Tests/UnitTest/generate_test_data.py
@@ -1336,6 +1336,25 @@ def load_testdata_sets(regenerate_input, regenerate_weights, regenerate_biases,
                                           stride_y=5,
                                           pad=True,
                                           interpreter=interpreter)
+    dataset = 'transpose_conv_4'
+    testdata_sets[dataset] = ConvSettings(dataset,
+                                          type_of_test,
+                                          regenerate_weights,
+                                          regenerate_input,
+                                          regenerate_biases,
+                                          schema_file,
+                                          in_ch=32,
+                                          batches=3,
+                                          out_ch=5,
+                                          x_in=1,
+                                          y_in=7,
+                                          w_x=3,
+                                          w_y=3,
+                                          generate_bias=False,
+                                          stride_x=3,
+                                          stride_y=1,
+                                          pad=False,
+                                          interpreter=interpreter)
 
     type_of_test = 'depthwise_conv'
     dataset = 'depthwise_2'