From 601d96c63a4e1437d9a90c0f917684405e60d8a1 Mon Sep 17 00:00:00 2001 From: Adrian Lundell <36153706+AdrianLundell@users.noreply.github.com> Date: Wed, 7 Feb 2024 16:09:56 +0100 Subject: [PATCH] Reimplement arm_lstm_unidirectional_s8 (#102) - API changes - Optimized for scalar, DSP and MVE - Bit exact to TFLM reference kernel - Less scratch-buffer usage --- ARM.CMSIS-NN.pdsc | 9 +- Include/arm_nn_types.h | 134 ++--- Include/arm_nnfunctions.h | 101 ++-- Include/arm_nnsupportfunctions.h | 215 ++++---- .../arm_nn_activation_s16.c | 38 +- .../arm_elementwise_mul_acc_s16.c | 167 ++++++ .../arm_elementwise_mul_s16_s8.c | 100 ++-- .../arm_vector_sum_s8.c | 97 ++-- Source/LSTMFunctions/CMakeLists.txt | 6 +- .../arm_lstm_unidirectional_s8.c | 80 +++ .../arm_lstm_unidirectional_s8_s16.c | 184 ------- .../arm_nn_lstm_calculate_gate_s8_s16.c | 80 ++- .../NNSupportFunctions/arm_nn_lstm_step_s8.c | 110 ++++ .../arm_nn_lstm_step_s8_s16.c | 154 ------ .../arm_nn_lstm_update_cell_state_s16.c | 124 ----- .../arm_nn_lstm_update_output_s8_s16.c | 81 --- ...=> arm_nn_vec_mat_mul_result_acc_s8_s16.c} | 220 ++++---- Tests/UnitTest/CMakeLists.txt | 2 +- Tests/UnitTest/README.md | 16 +- .../TestData/lstm_1/cell_gate_bias_data.h | 6 +- .../TestData/lstm_1/cell_norm_coeff_data.h | 4 +- .../TestData/lstm_1/cell_state_data.h | 4 +- .../TestData/lstm_1/cell_to_forget_data.h | 4 +- .../TestData/lstm_1/cell_to_input_data.h | 4 +- .../TestData/lstm_1/cell_to_output_data.h | 4 +- .../TestCases/TestData/lstm_1/config_data.h | 31 +- .../TestData/lstm_1/forget_gate_bias_data.h | 6 +- .../TestData/lstm_1/forget_norm_coeff_data.h | 4 +- .../TestCases/TestData/lstm_1/input_data.h | 26 +- .../TestData/lstm_1/input_gate_bias_data.h | 7 +- .../TestData/lstm_1/input_norm_coeff_data.h | 4 +- .../lstm_1/input_to_cell_eff_bias_data.h | 6 +- .../TestData/lstm_1/input_to_cell_w_data.h | 29 +- .../lstm_1/input_to_forget_eff_bias_data.h | 6 +- .../TestData/lstm_1/input_to_forget_w_data.h | 30 +- .../lstm_1/input_to_input_eff_bias_data.h | 6 +- .../TestData/lstm_1/input_to_input_w_data.h | 30 +- .../lstm_1/input_to_output_eff_bias_data.h | 6 +- .../TestData/lstm_1/input_to_output_w_data.h | 30 +- .../TestData/lstm_1/output_gate_bias_data.h | 6 +- .../TestData/lstm_1/output_norm_coeff_data.h | 4 +- .../TestData/lstm_1/output_ref_data.h | 14 +- .../TestData/lstm_1/output_state_data.h | 6 +- .../TestData/lstm_1/projection_bias_data.h | 4 +- .../TestData/lstm_1/projection_weights_data.h | 4 +- .../lstm_1/recurrent_input_to_cell_w_data.h | 17 +- .../lstm_1/recurrent_input_to_forget_w_data.h | 16 +- .../lstm_1/recurrent_input_to_input_w_data.h | 16 +- .../lstm_1/recurrent_input_to_output_w_data.h | 17 +- .../lstm_1/recurrent_to_cell_eff_bias_data.h | 6 +- .../recurrent_to_forget_eff_bias_data.h | 6 +- .../lstm_1/recurrent_to_input_eff_bias_data.h | 6 +- .../recurrent_to_output_eff_bias_data.h | 6 +- .../TestCases/TestData/lstm_1/test_data.h | 4 +- .../TestData/lstm_2/cell_gate_bias_data.h | 6 +- .../TestData/lstm_2/cell_norm_coeff_data.h | 4 +- .../TestData/lstm_2/cell_state_data.h | 6 +- .../TestData/lstm_2/cell_to_forget_data.h | 4 +- .../TestData/lstm_2/cell_to_input_data.h | 4 +- .../TestData/lstm_2/cell_to_output_data.h | 4 +- .../TestCases/TestData/lstm_2/config_data.h | 43 +- .../TestData/lstm_2/forget_gate_bias_data.h | 6 +- .../TestData/lstm_2/forget_norm_coeff_data.h | 4 +- .../TestCases/TestData/lstm_2/input_data.h | 15 +- .../TestData/lstm_2/input_gate_bias_data.h | 6 +- .../TestData/lstm_2/input_norm_coeff_data.h | 4 +- 
.../lstm_2/input_to_cell_eff_bias_data.h | 6 +- .../TestData/lstm_2/input_to_cell_w_data.h | 10 +- .../lstm_2/input_to_forget_eff_bias_data.h | 6 +- .../TestData/lstm_2/input_to_forget_w_data.h | 10 +- .../lstm_2/input_to_input_eff_bias_data.h | 6 +- .../TestData/lstm_2/input_to_input_w_data.h | 10 +- .../lstm_2/input_to_output_eff_bias_data.h | 6 +- .../TestData/lstm_2/input_to_output_w_data.h | 10 +- .../TestData/lstm_2/output_gate_bias_data.h | 6 +- .../TestData/lstm_2/output_norm_coeff_data.h | 4 +- .../TestData/lstm_2/output_ref_data.h | 15 +- .../TestData/lstm_2/output_state_data.h | 6 +- .../TestData/lstm_2/projection_bias_data.h | 4 +- .../TestData/lstm_2/projection_weights_data.h | 4 +- .../lstm_2/recurrent_input_to_cell_w_data.h | 12 +- .../lstm_2/recurrent_input_to_forget_w_data.h | 10 +- .../lstm_2/recurrent_input_to_input_w_data.h | 10 +- .../lstm_2/recurrent_input_to_output_w_data.h | 10 +- .../lstm_2/recurrent_to_cell_eff_bias_data.h | 6 +- .../recurrent_to_forget_eff_bias_data.h | 6 +- .../lstm_2/recurrent_to_input_eff_bias_data.h | 6 +- .../recurrent_to_output_eff_bias_data.h | 6 +- .../TestCases/TestData/lstm_2/test_data.h | 4 +- .../lstm_one_time_step/cell_gate_bias_data.h | 6 +- .../lstm_one_time_step/cell_norm_coeff_data.h | 4 +- .../lstm_one_time_step/cell_state_data.h | 4 +- .../lstm_one_time_step/cell_to_forget_data.h | 4 +- .../lstm_one_time_step/cell_to_input_data.h | 4 +- .../lstm_one_time_step/cell_to_output_data.h | 4 +- .../TestData/lstm_one_time_step/config_data.h | 37 +- .../forget_gate_bias_data.h | 6 +- .../forget_norm_coeff_data.h | 4 +- .../TestData/lstm_one_time_step/input_data.h | 11 +- .../lstm_one_time_step/input_gate_bias_data.h | 6 +- .../input_norm_coeff_data.h | 4 +- .../input_to_cell_eff_bias_data.h | 6 +- .../lstm_one_time_step/input_to_cell_w_data.h | 10 +- .../input_to_forget_eff_bias_data.h | 6 +- .../input_to_forget_w_data.h | 10 +- .../input_to_input_eff_bias_data.h | 6 +- .../input_to_input_w_data.h | 10 +- .../input_to_output_eff_bias_data.h | 6 +- .../input_to_output_w_data.h | 11 +- .../output_gate_bias_data.h | 6 +- .../output_norm_coeff_data.h | 4 +- .../lstm_one_time_step/output_ref_data.h | 6 +- .../lstm_one_time_step/output_state_data.h | 6 +- .../lstm_one_time_step/projection_bias_data.h | 4 +- .../projection_weights_data.h | 4 +- .../recurrent_input_to_cell_w_data.h | 6 +- .../recurrent_input_to_forget_w_data.h | 6 +- .../recurrent_input_to_input_w_data.h | 6 +- .../recurrent_input_to_output_w_data.h | 6 +- .../recurrent_to_cell_eff_bias_data.h | 6 +- .../recurrent_to_forget_eff_bias_data.h | 6 +- .../recurrent_to_input_eff_bias_data.h | 6 +- .../recurrent_to_output_eff_bias_data.h | 6 +- .../TestData/lstm_one_time_step/test_data.h | 4 +- .../test_arm_ds_cnn_l_s8.c | 2 +- .../test_arm_fully_connected_s8.c | 12 +- .../test_arm_lstm_unidirectional_s16_s8.c | 328 ------------ .../CMakeLists.txt | 10 +- .../unity_test_arm_lstm_unidirectional_s8.c} | 13 +- .../test_arm_lstm_unidirectional_s8.c | 495 ++++++++++++++++++ .../test_arm_svdf_s8/test_arm_svdf_s8.c | 6 +- Tests/UnitTest/generate_test_data.py | 5 +- Tests/UnitTest/lstm_settings.py | 31 +- 133 files changed, 1830 insertions(+), 1889 deletions(-) create mode 100644 Source/BasicMathFunctions/arm_elementwise_mul_acc_s16.c create mode 100644 Source/LSTMFunctions/arm_lstm_unidirectional_s8.c delete mode 100644 Source/LSTMFunctions/arm_lstm_unidirectional_s8_s16.c create mode 100644 Source/NNSupportFunctions/arm_nn_lstm_step_s8.c delete mode 100644 
Source/NNSupportFunctions/arm_nn_lstm_step_s8_s16.c delete mode 100644 Source/NNSupportFunctions/arm_nn_lstm_update_cell_state_s16.c delete mode 100644 Source/NNSupportFunctions/arm_nn_lstm_update_output_s8_s16.c rename Source/NNSupportFunctions/{arm_nn_vec_mat_mul_result_acc_s8.c => arm_nn_vec_mat_mul_result_acc_s8_s16.c} (60%) delete mode 100644 Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s16_s8/test_arm_lstm_unidirectional_s16_s8.c rename Tests/UnitTest/TestCases/{test_arm_lstm_unidirectional_s16_s8 => test_arm_lstm_unidirectional_s8}/CMakeLists.txt (67%) rename Tests/UnitTest/TestCases/{test_arm_lstm_unidirectional_s16_s8/Unity/unity_test_arm_lstm_unidirectional_s16_s8.c => test_arm_lstm_unidirectional_s8/Unity/unity_test_arm_lstm_unidirectional_s8.c} (69%) create mode 100644 Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s8/test_arm_lstm_unidirectional_s8.c diff --git a/ARM.CMSIS-NN.pdsc b/ARM.CMSIS-NN.pdsc index f216d8cc..03d3783d 100644 --- a/ARM.CMSIS-NN.pdsc +++ b/ARM.CMSIS-NN.pdsc @@ -83,6 +83,7 @@ + @@ -107,18 +108,16 @@ - - - + - + - + diff --git a/Include/arm_nn_types.h b/Include/arm_nn_types.h index 257442da..c567f0c1 100644 --- a/Include/arm_nn_types.h +++ b/Include/arm_nn_types.h @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates + * SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates * * SPDX-License-Identifier: Apache-2.0 * @@ -22,8 +22,8 @@ * Description: Public header file to contain the CMSIS-NN structs for the * TensorFlowLite micro compliant functions * - * $Date: 9 January 2024 - * $Revision: V.2.6.2 + * $Date: 19 January 2024 + * $Revision: V.3.0.0 * * Target : Arm(R) M-Profile Architecture * -------------------------------------------------------------------- */ @@ -40,7 +40,6 @@ * @{ */ - /** Enum for specifying activation function types */ typedef enum { @@ -180,31 +179,6 @@ typedef struct const int16_t *one_by_one_lut; } cmsis_nn_softmax_lut_s16; -/** LSTM guard parameters */ -typedef struct -{ - int32_t input_variance; - int32_t forget_variance; - int32_t cell_variance; - int32_t output_variance; -} cmsis_nn_lstm_guard_params; - -/** LSTM scratch buffer container */ -typedef struct -{ - int16_t *input_gate; - int16_t *forget_gate; - int16_t *cell_gate; - int16_t *output_gate; -} cmsis_nn_lstm_context; - -/** Quantized clip value for cell and projection of LSTM input. Zero value means no clipping. 
*/ -typedef struct -{ - int16_t cell; - int8_t projection; -} cmsis_nn_lstm_clip_params; - /** CMSIS-NN object for quantization parameters */ typedef struct { @@ -212,70 +186,60 @@ typedef struct int32_t shift; /**< Shift value */ } cmsis_nn_scaling; -/** CMSIS-NN norm layer coefficients */ +/** CMSIS-NN object for LSTM gate parameters*/ typedef struct { - int16_t *input_weight; - int16_t *forget_weight; - int16_t *cell_weight; - int16_t *output_weight; -} cmsis_nn_layer_norm; + int32_t input_multiplier; + int32_t input_shift; + const int8_t *input_weights; + const int32_t *input_effective_bias; /**< Bias added with precomputed kernel_sum * lhs_offset*/ -/** Parameters for integer LSTM, as defined in TFLM */ + int32_t hidden_multiplier; + int32_t hidden_shift; + const int8_t *hidden_weights; + const int32_t *hidden_effective_bias; /**< Precomputed kernel_sum * lhs_offset*/ + + const int32_t *bias; + arm_nn_activation_type activation_type; +} cmsis_nn_lstm_gate; + +/** CMSIS-NN object for LSTM parameters*/ typedef struct { - int32_t time_major; /**< Nonzero (true) if first row of data is timestamps for input */ - cmsis_nn_scaling input_to_input_scaling; - cmsis_nn_scaling input_to_forget_scaling; - cmsis_nn_scaling input_to_cell_scaling; - cmsis_nn_scaling input_to_output_scaling; - cmsis_nn_scaling recurrent_to_input_scaling; - cmsis_nn_scaling recurrent_to_forget_scaling; - cmsis_nn_scaling recurrent_to_cell_scaling; - cmsis_nn_scaling recurrent_to_output_scaling; - cmsis_nn_scaling cell_to_input_scaling; - cmsis_nn_scaling cell_to_forget_scaling; - cmsis_nn_scaling cell_to_output_scaling; - cmsis_nn_scaling projection_scaling; - cmsis_nn_scaling hidden_scaling; - cmsis_nn_scaling layer_norm_input_scaling; /**< layer normalization for input layer */ - cmsis_nn_scaling layer_norm_forget_scaling; /**< layer normalization for forget gate */ - cmsis_nn_scaling layer_norm_cell_scaling; /**< layer normalization for cell */ - cmsis_nn_scaling layer_norm_output_scaling; /**< layer normalization for outpus layer */ - - int32_t cell_state_shift; - int32_t hidden_offset; - int32_t output_state_offset; - - cmsis_nn_lstm_clip_params clip; - cmsis_nn_lstm_guard_params guard; - cmsis_nn_layer_norm layer_norm; - - /* Effective bias is precalculated as bias + zero_point * weight. 
- Only applicable to when input/output are s8 and weights are s16 */ - const int32_t *i2i_effective_bias; /**< input to input effective bias */ - const int32_t *i2f_effective_bias; /**< input to forget gate effective bias */ - const int32_t *i2c_effective_bias; /**< input to cell effective bias */ - const int32_t *i2o_effective_bias; /**< input to output effective bias */ - - const int32_t *r2i_effective_bias; /**< recurrent gate to input effective bias */ - const int32_t *r2f_effective_bias; /**< recurrent gate to forget gate effective bias */ - const int32_t *r2c_effective_bias; /**< recurrent gate to cell effective bias */ - const int32_t *r2o_effective_bias; /**< recurrent gate to output effective bias */ - - const int32_t *projection_effective_bias; - - /* Not precalculated bias */ - const int32_t *input_gate_bias; - const int32_t *forget_gate_bias; - const int32_t *cell_gate_bias; - const int32_t *output_gate_bias; - - /* Activation min and max */ - cmsis_nn_activation activation; + int32_t time_major; /**< 0 if first dimension is batch, else first dimension is time */ + int32_t batch_size; + int32_t time_steps; + int32_t input_size; /**< Size of new data input into the LSTM cell*/ + int32_t + hidden_size; /**< Size of output from the LSTM cell, used as output and recursively into the next time step*/ + + int32_t input_offset; + + int32_t forget_to_cell_multiplier; + int32_t forget_to_cell_shift; + int32_t input_to_cell_multiplier; + int32_t input_to_cell_shift; + int32_t cell_clip; /**< Min/max value of cell output*/ + int32_t cell_scale_power; + int32_t output_multiplier; + int32_t output_shift; + int32_t output_offset; + + cmsis_nn_lstm_gate forget_gate; + cmsis_nn_lstm_gate input_gate; + cmsis_nn_lstm_gate cell_gate; + cmsis_nn_lstm_gate output_gate; } cmsis_nn_lstm_params; +/** CMSIS-NN object for LSTM scratch buffers*/ +typedef struct +{ + void *temp1; + void *temp2; + void *cell_state; +} cmsis_nn_lstm_context; + /** * @} // end group genPubTypes */ diff --git a/Include/arm_nnfunctions.h b/Include/arm_nnfunctions.h index ade3417c..70a106d9 100644 --- a/Include/arm_nnfunctions.h +++ b/Include/arm_nnfunctions.h @@ -21,8 +21,9 @@ * Title: arm_nnfunctions.h * Description: Public header file for CMSIS NN Library * - * $Date: 11 January 2024 - * $Revision: V.12.6.0 + * $Date: 19 January 2024 + * $Revision: V.13.0.0 + * * Target : Arm(R) M-Profile Architecture * -------------------------------------------------------------------- */ @@ -1514,11 +1515,13 @@ arm_cmsis_nn_status arm_fully_connected_s8(const cmsis_nn_context *ctx, int8_t *output_data); /** - * @brief Calculate vector sums that may be required by arm_fully_connected_s8(). + * @brief Calculate the sum of each row in vector_data, multiply by lhs_offset and optionally add bias_data. * @param[in, out] vector_sum_buf Buffer for vector sums * @param[in] vector_cols Number of vector columns * @param[in] vector_rows Number of vector rows - * @param[in] vector_data Vector or weigths data + * @param[in] vector_data Vector of weights data + * @param[in] lhs_offset Constant multiplied with each sum + * @param[in] bias_data Vector of bias data, added to each sum. * @return The function returns * ARM_CMSIS_NN_SUCCESS - Successful operation * ARM_CMSIS_NN_ARG_ERROR - If not for Arm(R) Helium Architecture case.
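The lhs_offset/bias_data extension above is what produces the "effective bias" consumed by the new LSTM gate struct (input_effective_bias / hidden_effective_bias). As a rough scalar reference of what the extended arm_vector_sum_s8() computes per row (a sketch, not part of the patch; the function name and layout are illustrative):

static void effective_bias_ref(int32_t *effective_bias,
                               const int8_t *weights,
                               const int32_t *bias,
                               const int32_t rows,
                               const int32_t cols,
                               const int32_t lhs_offset)
{
    for (int32_t r = 0; r < rows; r++)
    {
        int32_t row_sum = 0;
        for (int32_t c = 0; c < cols; c++)
        {
            row_sum += weights[r * cols + c]; /* sum of one weight row */
        }
        /* effective_bias[r] = bias[r] + lhs_offset * row_sum (bias may be omitted) */
        effective_bias[r] = (bias ? bias[r] : 0) + lhs_offset * row_sum;
    }
}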
@@ -1526,7 +1529,9 @@ arm_cmsis_nn_status arm_fully_connected_s8(const cmsis_nn_context *ctx, arm_cmsis_nn_status arm_vector_sum_s8(int32_t *vector_sum_buf, const int32_t vector_cols, const int32_t vector_rows, - const int8_t *vector_data); + const int8_t *vector_data, + const int32_t lhs_offset, + const int32_t *bias_data); /** * @brief Get size of additional buffer required by arm_fully_connected_s8(). @@ -1802,21 +1807,23 @@ void arm_relu_q15(int16_t *data, uint16_t size); /** * @brief s16 neural network activation function using direct table look-up - * @param[in] input pointer to input data + * @param[in] input pointer to input data * @param[out] output pointer to output * @param[in] size number of elements - * @param[in] left_shift bit-width of the integer part, assume to be smaller than 3 + * @param[in] left_shift bit-width of the integer part, assumed to be smaller than 3. * @param[in] type type of activation functions + * @return The function returns ARM_CMSIS_NN_SUCCESS + * * @details Supported framework: TensorFlow Lite for Microcontrollers. - * This activation function must be bit precise congruent with the corresponding TFLM tanh and sigmoid actication + * This activation function must be bit precise congruent with the corresponding TFLM tanh and sigmoid activation * functions */ -void arm_nn_activation_s16(const int16_t *input, - int16_t *output, - const uint16_t size, - const uint16_t left_shift, - const arm_nn_activation_type type); +arm_cmsis_nn_status arm_nn_activation_s16(const int16_t *input, + int16_t *output, + const int32_t size, + const int32_t left_shift, + const arm_nn_activation_type type); /** * @defgroup Pooling Pooling Functions @@ -2441,67 +2448,25 @@ arm_cmsis_nn_status arm_svdf_state_s16_s8(const cmsis_nn_context *input_ctx, */ /** - * @brief LSTM unidirectional function with 8 bit input and output and 16 bit gate output - * Peephole connections, projection, clipping, combined input/forget gate and layer normalization are not supported. - * - * @param[in] scratch_buffers Struct containing scratch buffers - * Expected size for each scratch buffer is - * lstm_dims->num_batches * lstm_dims->num_outputs. - * @param[in] input_data Pointer to input data - * @param[in] lstm_dims LSTM input parameters related to dimensions - * @param[in] input_to_input_weights Input to input weights - * @param[in] input_to_forget_weights Input to forget weights - * @param[in] input_to_cell_weights Input to cell weights - * @param[in] input_to_output_weights Input to output weights - * @param[in] recurrent_to_input_weights Recurrent to input weights - * @param[in] recurrent_to_forget_weights Recurrent to forget weights - * @param[in] recurrent_to_cell_weights Recurrent to cell weights - * @param[in] recurrent_to_output_weights Recurrent to output weights - * @param[in] cell_to_input_weights Cell to input weights. Not used. - * @param[in] cell_to_forget_weights Cell to forget weights. Not used. - * @param[in] cell_to_output_weights Cell to output weights. Not used. - * @param[in] projection_weights Projection weights. Not used. - * @param[in] lstm LSTM parameters. See struct declaration - * @param[in] output_state Pointer to (recurrent) output state - * @param[in] cell_state Pointer to cell state - * @param[in] output_data Pointer to output state - * - * @note Following assumptions are done based on LSTM functionality as supported by - * Keras version 2.9.0 at the time of development. 
As stated here, - * https://github.com/tensorflow/community/blob/master/rfcs/20180920-unify-rnn-interface.md - * Keras's LSTMCell is equivalent to TensorFlow's BasicLSTMCell, - * which does not support peephole, clipping or projection. - * Layer normalization and combined input/forget gate are not supported either. - * - * 1 Input to input weight can not be nullptr. Otherwise nullptr for combined input/forgat gate. - * 2 Cell weights are not used and should be nullptr. Otherwise needed for peephole connections. - * 3 Projection weight is not used and should be nullpr. Otherwise needed for projection. + * @brief LSTM unidirectional function with 8 bit input and output and 16 bit gate output. + * + * @param[in] input Pointer to input data + * @param[out] output Pointer to output data + * @param[in] params Struct containing all information about the lstm operator, see arm_nn_types. + * @param[in] buffers Struct containing pointers to all temporary scratch buffers needed for the + * lstm operator, see arm_nn_types. + * * * @return The function returns ARM_CMSIS_NN_SUCCESS * * @details - * 1. Supported framework: TensorFlow Lite micro + * 1. Supported framework: TensorFlow Lite Micro * */ -arm_cmsis_nn_status arm_lstm_unidirectional_s16_s8(cmsis_nn_lstm_context *scratch_buffers, - const int8_t *input_data, - const cmsis_nn_lstm_dims *lstm_dims, - const int8_t *input_to_input_weights, - const int8_t *input_to_forget_weights, - const int8_t *input_to_cell_weights, - const int8_t *input_to_output_weights, - const int8_t *recurrent_to_input_weights, - const int8_t *recurrent_to_forget_weights, - const int8_t *recurrent_to_cell_weights, - const int8_t *recurrent_to_output_weights, - const int16_t *cell_to_input_weights, - const int16_t *cell_to_forget_weights, - const int16_t *cell_to_output_weights, - const int8_t *projection_weights, - const cmsis_nn_lstm_params *lstm, - int8_t *output_state, - int16_t *cell_state, - int8_t *output_data); +arm_cmsis_nn_status arm_lstm_unidirectional_s8(const int8_t *input, + int8_t *output, + const cmsis_nn_lstm_params *params, + cmsis_nn_lstm_context *buffers); /** * @brief Get size of additional buffer required by arm_svdf_s8(). 
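For orientation, a minimal caller-side sketch of the reworked API (not part of the patch; BATCH_SIZE, HIDDEN_SIZE and run_lstm_s8 are illustrative, and the weight/bias/quantization fields of cmsis_nn_lstm_params are assumed to be filled in elsewhere). Buffer sizing follows the usage in arm_lstm_unidirectional_s8.c and arm_nn_lstm_step_s8.c below: temp1, temp2 and cell_state each hold batch_size * hidden_size int16_t values.

#include "arm_nnfunctions.h"

#define BATCH_SIZE 1   /* example value */
#define HIDDEN_SIZE 16 /* example value */

static int16_t lstm_temp1[BATCH_SIZE * HIDDEN_SIZE];
static int16_t lstm_temp2[BATCH_SIZE * HIDDEN_SIZE];
static int16_t lstm_cell_state[BATCH_SIZE * HIDDEN_SIZE];

arm_cmsis_nn_status run_lstm_s8(const int8_t *input, int8_t *output, const cmsis_nn_lstm_params *params)
{
    cmsis_nn_lstm_context buffers;
    buffers.temp1 = lstm_temp1;
    buffers.temp2 = lstm_temp2;
    buffers.cell_state = lstm_cell_state; /* cleared by arm_lstm_unidirectional_s8() itself */

    return arm_lstm_unidirectional_s8(input, output, params, &buffers);
}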
diff --git a/Include/arm_nnsupportfunctions.h b/Include/arm_nnsupportfunctions.h index 20cbfd38..f2393db0 100644 --- a/Include/arm_nnsupportfunctions.h +++ b/Include/arm_nnsupportfunctions.h @@ -21,8 +21,8 @@ * Title: arm_nnsupportfunctions.h * Description: Public header file of support functions for CMSIS NN Library * - * $Date: 11 January 2024 - * $Revision: V.17.7.0 + * $Date: 19 January 2024 + * $Revision: V.18.0.0 * * Target : Arm(R) M-Profile Architecture * -------------------------------------------------------------------- */ @@ -1453,142 +1453,77 @@ __STATIC_FORCEINLINE void arm_nn_write_s8x2_ia(int8_t **dst, int16_t src) /** * @brief Update LSTM function for an iteration step * - * param[in] input Input data - * param[in] input_to_input_weight Input to input gate weights - * param[in] input_to_forget_weight Input to forget gate weights - * param[in] input_to_cell_weight Input to cell gate weights - * param[in] input_to_output_weight Input to output weights - * param[in] recurrent_to_input_weight Recurrent signal to input weights - * param[in] recurrent_to_forget_weight Recurrent signal to forget gate weights - * param[in] recurrent_to_cell_weight Recurrent signal to cell gate weighst - * param[in] recurrent_to_output_weight Recurrent signal to output weights - * param[in] lstm LSTM parameters - * param[in] n_batch Batch size - * param[in] n_cell Cell size - * param[in] n_input Input size - * param[in] n_output Output size - * param[out] output_state Output state - * param[out] cell_state Internal state - * param[out] output Output signal - * param[in] *scratch_buffers Struct containing scratch buffers - */ -arm_cmsis_nn_status arm_nn_lstm_step_s8_s16(const int8_t *input, - const int8_t *input_to_input_weight, - const int8_t *input_to_forget_weight, - const int8_t *input_to_cell_weight, - const int8_t *input_to_output_weight, - const int8_t *recurrent_to_input_weight, - const int8_t *recurrent_to_forget_weight, - const int8_t *recurrent_to_cell_weight, - const int8_t *recurrent_to_output_weight, - const cmsis_nn_lstm_params *lstm, - const int n_batch, - const int n_cell, - const int n_input, - const int n_output, - int8_t *output_state, - int16_t *cell_state, - int8_t *output, - cmsis_nn_lstm_context *scratch_buffers); + * @param[in] data_in Data input pointer + * @param[in] hidden_in Hidden state/ recurrent input pointer + * @param[out] hidden_out Hidden state/ recurrent output pointer + * @param[in] params Struct containing all information about the lstm operator, see + * arm_nn_types. + * @param[in] buffers Struct containing pointers to all temporary scratch buffers needed for the + * lstm operator, see arm_nn_types. + * @param[in] batch_offset Number of timesteps between consecutive batches. + * E.g. for params->time_major = true, all batches for t=0 are stored sequentially, so batch offset = 1. + * For params->time_major = false, all time steps are stored continuously before the next batch, so + * batch offset = params->time_steps. + * @return The function returns ARM_CMSIS_NN_SUCCESS -/** - * @brief Updates a LSTM gate for an iteration step of LSTM function, int8x8_16 version.
- * - * param[in] input Input data - * param[in] input_to_gate_weights Input to gate weights - * param[in] input_to_gate_bias Input to gate weights - * param[in] input_to_gate_scaling Input to gate scaling - * param[in] activation Actival min and max values - * param[in] output_state Output state - * param[in] recurrent_to_gate_weights Recurrent to gate weights - * param[in] recurrent_to_gate_bias Recurrent to gate bias - * param[in] recurrent_to_gate_scaling Recurrent to gate scaling - * param[in] n_batch Batch size - * param[in] n_input Input size - * param[out] n_output Output size - * param[in] activation_type Activation type (sigmoid or tanh) - * param[out] n_cell Cell size - */ -void arm_nn_lstm_calculate_gate_s8_s16(const int8_t *input, - const int8_t *input_to_gate_weights, - const int32_t *input_to_gate_bias, - const cmsis_nn_scaling input_to_gate_scaling, - const int8_t *output_state, - const int8_t *recurrent_to_gate_weights, - const int32_t *recurrent_to_gate_bias, - const cmsis_nn_scaling recurrent_to_gate_scaling, - const int32_t n_batch, - const int32_t n_input, - const int32_t n_output, - const int32_t n_cell, - const arm_nn_activation_type activation_type, - int16_t *gate); - -/** - * @brief Update cell state for a single LSTM iteration step, int8x8_16 version. - * @param[in] n_block total number of cells for all batches - * @param[in] cell_state_scale Scaling factor of cell state - * @param[in] cell_state Input/output vector, size n_batch*n_cell - * @param[in] input_gate Input vector of size n_block - * @param[in] forget_gate Input/scratch vector of size n_block, always modified - * @param[in] cell_gate Input vector of size, n_block */ -void arm_nn_lstm_update_cell_state_s16(const int32_t n_block, - const int32_t cell_state_scale, - int16_t *cell_state, - const int16_t *input_gate, - const int16_t *forget_gate, - const int16_t *cell_gate); +arm_cmsis_nn_status arm_nn_lstm_step_s8(const int8_t *data_in, + const int8_t *hidden_in, + int8_t *hidden_out, + const cmsis_nn_lstm_params *params, + cmsis_nn_lstm_context *buffers, + const int32_t batch_offset); /** - * @brief Calculate the output state tensor of an LSTM step, s8 input/output and s16 weight version. - * - * @param[in] n_batch The number of distinct vectors in each array - * @param[in] n_cell Number of cells - * @param[in,out] cell_state Cell state, size n_batch*n_cell - * @param[in] cell_state_scale Scaling of cell_state - * @param[in] output_gate Output gate - * @param[in] hidden_scale Effective scaling of cell_state .* output_gate - * @param[in] hidden_offset Zero point for cell_state .* output_gate - * @param[out] output_state Output state - * @param[in] cell_gate_scratch Scratch buffer + * @brief Updates a LSTM gate for an iteration step of LSTM function, int8x8_16 version. + * + * @param[in] data_in Data input pointer + * @param[in] hidden_in Hidden state/ recurrent input pointer + * @param[in] gate_data Struct containing all information about the gate calculation, see + * arm_nn_types. + * @param[in] params Struct containing all information about the lstm operator, see + * arm_nn_types + * @param[out] output Hidden state/ recurrent output pointer + * @param[in] batch_offset Number of timesteps between consecutive batches, see + * arm_nn_lstm_step_s8.
+ * @return The function returns ARM_CMSIS_NN_SUCCESS */ -void arm_nn_lstm_update_output_s8_s16(const int n_batch, - const int n_cell, - int16_t *cell_state, - const int32_t cell_state_scale, - const int16_t *output_gate, - const cmsis_nn_scaling hidden_scale, - const int32_t hidden_offset, - int8_t *output_state, - int16_t *cell_gate_scratch); +arm_cmsis_nn_status arm_nn_lstm_calculate_gate_s8_s16(const int8_t *data_in, + const int8_t *hidden_in, + const cmsis_nn_lstm_gate *gate_data, + const cmsis_nn_lstm_params *params, + int16_t *output, + const int32_t batch_offset); /** * @brief The result of the multiplication is accumulated to the passed result buffer. * Multiplies a matrix by a "batched" vector (i.e. a matrix with a batch dimension composed by input vectors independent * from each other). * - * @param[in] lhs_in Batched vector - * @param[in] rhs_in Weights - input matrix (H(Rows)xW(Columns)) - * @param[in] bias Bias vector + * @param[in] lhs Batched vector + * @param[in] rhs Weights - input matrix (H(Rows)xW(Columns)) + * @param[in] effective_bias Bias + lhs_offset * kernel_sum term precalculated into a constant vector. * @param[out] dst Output - * @param[in] dst_offset Output offset * @param[in] dst_multiplier Multiplier for quantization * @param[in] dst_shift Shift for quantization * @param[in] rhs_cols Vector/matarix column length * @param[in] rhs_rows Row count of matrix - * @param[in] batch Batch size + * @param[in] batches Batch size + * @param[in] batch_offset Number of timesteps between consecutive batches in input, see arm_nn_lstm_step_s8. Note + that the output is always stored with sequential batches. + * @return The function returns ARM_CMSIS_NN_SUCCESS + */ -void arm_nn_vec_mat_mul_result_acc_s8(const int8_t *lhs_in, - const int8_t *rhs_in, - const int32_t *bias, - int16_t *dst, - const int32_t dst_offset, - const int32_t dst_multiplier, - const int32_t dst_shift, - const int32_t rhs_cols, - const int32_t rhs_rows, - const int32_t batch); +arm_cmsis_nn_status arm_nn_vec_mat_mul_result_acc_s8_s16(const int8_t *lhs, + const int8_t *rhs, + const int32_t *effective_bias, + int16_t *dst, + const int32_t dst_multiplier, + const int32_t dst_shift, + const int32_t rhs_cols, + const int32_t rhs_rows, + const int32_t batches, + const int32_t batch_offset); /** * @brief s16 elementwise multiplication with s8 output @@ -1598,7 +1533,10 @@ void arm_nn_vec_mat_mul_result_acc_s8(const int8_t *lhs_in, * @param[in] out_offset output offset * @param[in] out_mult output multiplier * @param[in] out_shift output shift - * @param[in] block_size number of samples + * @param[in] block_size number of samples per batch + * @param[in] batch_size number of batches + * @param[in] batch_offset Number of timesteps between consecutive batches in output, see + * arm_nn_lstm_step_s8. Note that it is assumed that the input is stored with sequential batches. * @return The function returns ARM_CMSIS_NN_SUCCESS * * @details Supported framework: TensorFlow Lite micro @@ -1609,7 +1547,38 @@ arm_cmsis_nn_status arm_elementwise_mul_s16_s8(const int16_t *input_1_vect, const int32_t out_offset, const int32_t out_mult, const int32_t out_shift, - const int32_t block_size); + const int32_t block_size, + const int32_t batch_size, + const int32_t batch_offset); + +/** + * @brief s16 elementwise multiplication. The result of the multiplication is accumulated to the passed result buffer.
+ * @param[in] input_1_vect pointer to input vector 1 + * @param[in] input_2_vect pointer to input vector 2 + * @param[in] input_1_offset offset for input 1. Not used. + * @param[in] input_2_offset offset for input 2. Not used. + * @param[in,out] output pointer to output vector + * @param[in] out_offset output offset. Not used. + * @param[in] out_mult output multiplier + * @param[in] out_shift output shift + * @param[in] out_activation_min minimum value to clamp output to. Min: -32768 + * @param[in] out_activation_max maximum value to clamp output to. Max: 32767 + * @param[in] block_size number of samples + * @return The function returns ARM_CMSIS_NN_SUCCESS + * + * @details Supported framework: TensorFlow Lite micro + */ +arm_cmsis_nn_status arm_elementwise_mul_acc_s16(const int16_t *input_1_vect, + const int16_t *input_2_vect, + const int32_t input_1_offset, + const int32_t input_2_offset, + int16_t *output, + const int32_t out_offset, + const int32_t out_mult, + const int32_t out_shift, + const int32_t out_activation_min, + const int32_t out_activation_max, + const int32_t block_size); #ifdef __cplusplus } diff --git a/Source/ActivationFunctions/arm_nn_activation_s16.c b/Source/ActivationFunctions/arm_nn_activation_s16.c index 99c6dbf6..bb05ae4b 100644 --- a/Source/ActivationFunctions/arm_nn_activation_s16.c +++ b/Source/ActivationFunctions/arm_nn_activation_s16.c @@ -1,5 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright 2010-2020, 2022 Arm Limited and/or its affiliates + * SPDX-FileCopyrightText: Copyright 2010-2020, 2022, 2024 Arm Limited and/or its affiliates + * * * SPDX-License-Identifier: Apache-2.0 * @@ -18,11 +19,11 @@ /* ---------------------------------------------------------------------- * Project: CMSIS NN Library - * Title: arm_nn_activations_q15.c + * Title: arm_nn_activation_s16.c * Description: Q15 neural network activation function using direct table look-up * - * $Date: 8 September 2022 - * $Revision: V.1.0.0 + * $Date: 19 January 2024 + * $Revision: V.2.0.0 * * Target Processor: Cortex-M cores * @@ -47,11 +48,11 @@ * */ -void arm_nn_activation_s16(const int16_t *input, - int16_t *output, - const uint16_t size, - const uint16_t left_shift, - const arm_nn_activation_type type) +arm_cmsis_nn_status arm_nn_activation_s16(const int16_t *input, + int16_t *output, + const int32_t size, + const int32_t left_shift, + const arm_nn_activation_type type) { uint32_t abs_input_shift, max_saturation; switch (type) @@ -67,18 +68,17 @@ void arm_nn_activation_s16(const int16_t *input, break; } + const int32_t input_multiplier = (left_shift < 0) ? 3 : 3 << left_shift; + const int32_t abs_left_shift = (left_shift < 0) ? -left_shift : 0; + const int32_t rounding = (abs_left_shift > 0) ? 1 << (abs_left_shift - 1) : 0; // Use the LUT for sigmoid and take into account, that // tanh(x) = 2*sigmoid(2*x) - 1 - int32_t input_multiplier = ((int32_t)3) << left_shift; for (int i = 0; i < size; ++i, input++, output++) { - int32_t input_data = ((*input) * input_multiplier); - - uint32_t abs_input_data = input_data > 0 ? input_data : -input_data; - - uint32_t uh = abs_input_data >> abs_input_shift; - + const int32_t input_data = ((*input) * input_multiplier + rounding) >> abs_left_shift; + const uint32_t abs_input_data = input_data > 0 ? 
input_data : -input_data; + const uint32_t uh = abs_input_data >> abs_input_shift; uint32_t result; if (uh >= 255) @@ -87,8 +87,8 @@ void arm_nn_activation_s16(const int16_t *input, } else { - uint32_t ua = sigmoid_table_uint16[uh]; - uint32_t ub = sigmoid_table_uint16[uh + 1]; + const uint32_t ua = sigmoid_table_uint16[uh]; + const uint32_t ub = sigmoid_table_uint16[uh + 1]; uint32_t ut; if (type == ARM_SIGMOID) { @@ -112,6 +112,8 @@ void arm_nn_activation_s16(const int16_t *input, } *output = (int16_t)result; } + + return ARM_CMSIS_NN_SUCCESS; } /** diff --git a/Source/BasicMathFunctions/arm_elementwise_mul_acc_s16.c b/Source/BasicMathFunctions/arm_elementwise_mul_acc_s16.c new file mode 100644 index 00000000..28277922 --- /dev/null +++ b/Source/BasicMathFunctions/arm_elementwise_mul_acc_s16.c @@ -0,0 +1,167 @@ +/* + * SPDX-FileCopyrightText: Copyright 2022, 2024 Arm Limited and/or its affiliates + * + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the License); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* ---------------------------------------------------------------------- + * Project: CMSIS NN Library + * Title: arm_elementwise_mul_acc_s16 + * Description: Accumulative element wise multiplication + * + * $Date: 19 January 2024 + * $Revision: V.1.0.0 + * + * Target : Arm(R) M-Profile Architecture + * + * -------------------------------------------------------------------- */ + +#include "arm_nnfunctions.h" +#include "arm_nnsupportfunctions.h" + +/** + * @ingroup Public + */ + +/** + * @addtogroup groupElementwise + * @{ + */ + +/** + * @brief s16 element wise accumulative multiplication of two vectors + * + * @note Refer header file for details. + * + */ +arm_cmsis_nn_status arm_elementwise_mul_acc_s16(const int16_t *input_1_vect, + const int16_t *input_2_vect, + const int32_t input_1_offset, + const int32_t input_2_offset, + int16_t *output, + const int32_t out_offset, + const int32_t out_mult, + const int32_t out_shift, + const int32_t out_activation_min, + const int32_t out_activation_max, + const int32_t block_size) +{ + (void)input_1_offset; + (void)input_2_offset; + (void)out_offset; + int32_t loop_count; + + const int32_t activation_max = (out_activation_max > 0) ? out_activation_max : NN_Q15_MAX; + const int32_t activation_min = (out_activation_max > 0) ? 
out_activation_min : NN_Q15_MIN; + +#if defined(ARM_MATH_MVEI) + + loop_count = block_size; + + while (loop_count > 0) + { + mve_pred16_t pred = vctp32q(loop_count); + + int32x4_t input_1 = vldrhq_z_s32(input_1_vect, pred); + int32x4_t input_2 = vldrhq_z_s32(input_2_vect, pred); + + int32x4_t res_0 = vmulq_s32(input_1, input_2); + + res_0 = arm_requantize_mve_32x4(res_0, vdupq_n_s32(out_mult), vdupq_n_s32(out_shift)); + + res_0 = vaddq_s32(res_0, vldrhq_z_s32(output, pred)); + + res_0 = vmaxq_s32(res_0, vdupq_n_s32(activation_min)); + res_0 = vminq_s32(res_0, vdupq_n_s32(activation_max)); + + vstrhq_p_s32(output, res_0, pred); + input_1_vect += 4; + input_2_vect += 4; + + output += 4; + loop_count -= 4; + } + +#else + int32_t input_1; + int32_t input_2; + int32_t mul_res; + int32_t two_halfword_1, two_halfword_2; + int16_t mul_1, mul_2; + loop_count = block_size / 2; + + while (loop_count > 0) + { + two_halfword_1 = arm_nn_read_q15x2_ia(&input_1_vect); + two_halfword_2 = arm_nn_read_q15x2_ia(&input_2_vect); + + #if defined(ARM_MATH_DSP) + mul_res = SMULBB(two_halfword_1, two_halfword_2); + #else + input_1 = (int16_t)(two_halfword_1 & 0xFFFF); + input_2 = (int16_t)(two_halfword_2 & 0xFFFF); + mul_res = input_1 * input_2; + #endif + mul_res = arm_nn_requantize(mul_res, out_mult, out_shift); + mul_res += output[0]; + + mul_res = MAX(mul_res, activation_min); + mul_res = MIN(mul_res, activation_max); + mul_1 = (int16_t)mul_res; + + #if defined(ARM_MATH_DSP) + mul_res = SMULTT(two_halfword_1, two_halfword_2); + #else + input_1 = (int16_t)(two_halfword_1 >> 16); + input_2 = (int16_t)(two_halfword_2 >> 16); + mul_res = input_1 * input_2; + #endif + mul_res = arm_nn_requantize(mul_res, out_mult, out_shift); + mul_res += output[1]; + mul_res = MAX(mul_res, activation_min); + mul_res = MIN(mul_res, activation_max); + mul_2 = (int16_t)mul_res; + + arm_nn_write_q15x2_ia(&output, PACK_Q15x2_32x1(mul_1, mul_2)); + + loop_count--; + } + loop_count = block_size & 0x1; + + while (loop_count > 0) + { + + input_1 = *input_1_vect++; + input_2 = *input_2_vect++; + + mul_res = input_1 * input_2; + + mul_res = arm_nn_requantize(mul_res, out_mult, out_shift); + mul_res += output[0]; + + mul_res = MAX(mul_res, activation_min); + mul_res = MIN(mul_res, activation_max); + + *output++ = (int16_t)mul_res; + + loop_count--; + } +#endif // #if defined(ARM_MATH_MVEI) + return ARM_CMSIS_NN_SUCCESS; +} + +/** + * @} end of Doxygen group + */ diff --git a/Source/BasicMathFunctions/arm_elementwise_mul_s16_s8.c b/Source/BasicMathFunctions/arm_elementwise_mul_s16_s8.c index 94f19dfd..f53c2e5b 100644 --- a/Source/BasicMathFunctions/arm_elementwise_mul_s16_s8.c +++ b/Source/BasicMathFunctions/arm_elementwise_mul_s16_s8.c @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright 2022-2023 Arm Limited and/or its affiliates + * SPDX-FileCopyrightText: Copyright 2022-2024 Arm Limited and/or its affiliates * * SPDX-License-Identifier: Apache-2.0 * @@ -22,7 +22,7 @@ * Description: Elementwise multiplication of 16 bit input with 8 bit output * * $Date: 20 January 2023 - * $Revision: V.1.2.0 + * $Revision: V.2.0.0 * * Target : Arm(R) M-Profile Architecture * @@ -51,70 +51,84 @@ arm_cmsis_nn_status arm_elementwise_mul_s16_s8(const int16_t *input_1_vect, const int32_t out_offset, const int32_t out_mult, const int32_t out_shift, - const int32_t block_size) + const int32_t block_size, + const int32_t batch_size, + const int32_t batch_offset) { - int32_t loop_count = block_size; + for (int i = 0; i < batch_size; i++) + { + int32_t loop_count 
= block_size; #if defined(ARM_MATH_MVEI) - while (loop_count > 0) - { - mve_pred16_t pred = vctp32q(loop_count); + const int16_t *input_1_ptr = input_1_vect; + const int16_t *input_2_ptr = input_2_vect; + int8_t *output_ptr = output; - int32x4_t input_1 = vldrhq_z_s32(input_1_vect, pred); - int32x4_t input_2 = vldrhq_z_s32(input_2_vect, pred); + while (loop_count > 0) + { + mve_pred16_t pred = vctp32q(loop_count); - int32x4_t res_0 = vmulq_s32(input_1, input_2); + int32x4_t input_1 = vldrhq_z_s32(input_1_ptr, pred); + int32x4_t input_2 = vldrhq_z_s32(input_2_ptr, pred); - res_0 = arm_requantize_mve_32x4(res_0, vdupq_n_s32(out_mult), vdupq_n_s32(out_shift)); - res_0 = vaddq_n_s32(res_0, out_offset); + int32x4_t res_0 = vmulq_s32(input_1, input_2); - res_0 = vmaxq_s32(res_0, vdupq_n_s32(NN_Q7_MIN)); - res_0 = vminq_s32(res_0, vdupq_n_s32(NN_Q7_MAX)); + res_0 = arm_requantize_mve_32x4(res_0, vdupq_n_s32(out_mult), vdupq_n_s32(out_shift)); + res_0 = vaddq_n_s32(res_0, out_offset); - vstrbq_p_s32(output, res_0, pred); - input_1_vect += 4; - input_2_vect += 4; + res_0 = vmaxq_s32(res_0, vdupq_n_s32(NN_Q7_MIN)); + res_0 = vminq_s32(res_0, vdupq_n_s32(NN_Q7_MAX)); - output += 4; - loop_count -= 4; - } + vstrbq_p_s32(output_ptr, res_0, pred); + input_1_ptr += 4; + input_2_ptr += 4; + + output_ptr += 4; + loop_count -= 4; + } + + input_1_vect += block_size; + input_2_vect += block_size; + output += block_size; #else #if defined(ARM_MATH_DSP) - while (loop_count > 1) - { - int32_t input_1 = arm_nn_read_q15x2_ia(&input_1_vect); - int32_t input_2 = arm_nn_read_q15x2_ia(&input_2_vect); + while (loop_count > 1) + { + int32_t input_1 = arm_nn_read_q15x2_ia(&input_1_vect); + int32_t input_2 = arm_nn_read_q15x2_ia(&input_2_vect); - int32_t mul_res = SMULBB(input_1, input_2); - mul_res = arm_nn_requantize(mul_res, out_mult, out_shift) + out_offset; - mul_res = CLAMP(mul_res, NN_Q7_MAX, NN_Q7_MIN); - int32_t mul = (int16_t)(mul_res & 0xFF); + int32_t mul_res = SMULBB(input_1, input_2); + mul_res = arm_nn_requantize(mul_res, out_mult, out_shift) + out_offset; + mul_res = CLAMP(mul_res, NN_Q7_MAX, NN_Q7_MIN); + int32_t mul = (int16_t)(mul_res & 0xFF); - mul_res = SMULTT(input_1, input_2); - mul_res = arm_nn_requantize(mul_res, out_mult, out_shift) + out_offset; - mul_res = CLAMP(mul_res, NN_Q7_MAX, NN_Q7_MIN); - mul |= (int16_t)mul_res << 8; + mul_res = SMULTT(input_1, input_2); + mul_res = arm_nn_requantize(mul_res, out_mult, out_shift) + out_offset; + mul_res = CLAMP(mul_res, NN_Q7_MAX, NN_Q7_MIN); + mul |= (int16_t)mul_res << 8; - arm_nn_write_s8x2_ia(&output, mul); - loop_count -= 2; - } + arm_nn_write_s8x2_ia(&output, mul); + loop_count -= 2; + } #endif - for (int i = 0; i < loop_count; i++) - { - /* C = A * B */ - int32_t mul_res = input_1_vect[i] * input_2_vect[i]; - mul_res = arm_nn_requantize(mul_res, out_mult, out_shift) + out_offset; + for (int j = 0; j < loop_count; j++, input_1_vect++, input_2_vect++, output++) + { + /* C = A * B */ + int32_t mul_res = (*input_1_vect) * (*input_2_vect); + mul_res = arm_nn_requantize(mul_res, out_mult, out_shift) + out_offset; - mul_res = CLAMP(mul_res, NN_Q7_MAX, NN_Q7_MIN); + mul_res = CLAMP(mul_res, NN_Q7_MAX, NN_Q7_MIN); - output[i] = (int8_t)mul_res; - } + *output = (int8_t)mul_res; + } #endif + output += (batch_offset - 1) * block_size; + } return ARM_CMSIS_NN_SUCCESS; } /** diff --git a/Source/FullyConnectedFunctions/arm_vector_sum_s8.c b/Source/FullyConnectedFunctions/arm_vector_sum_s8.c index 0120ba1d..a2564da5 100644 --- 
a/Source/FullyConnectedFunctions/arm_vector_sum_s8.c +++ b/Source/FullyConnectedFunctions/arm_vector_sum_s8.c @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright 2023 Arm Limited and/or its affiliates + * SPDX-FileCopyrightText: Copyright 2023-2024 Arm Limited and/or its affiliates * * SPDX-License-Identifier: Apache-2.0 * @@ -21,8 +21,8 @@ * Title: arm_vector_sum_s8 * Description: Generic function for calculating vector sums * - * $Date: 5 September 2023 - * $Revision: V.1.0.0 + * $Date: 26 January 2024 + * $Revision: V.2.0.0 * * Target : Arm(R) M-Profile Architecture * @@ -49,93 +49,122 @@ arm_cmsis_nn_status arm_vector_sum_s8(int32_t *vector_sum_buf, const int32_t vector_cols, const int32_t vector_rows, - const int8_t *vector_data) + const int8_t *vector_data, + const int32_t lhs_offset, + const int32_t *bias_data) { #if defined(ARM_MATH_MVEI) - const int32_t row_loop_cnt = vector_rows / 4; - + const int32_t row_loop_cnt = vector_rows / 5; for (int i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++) { const int32_t col_loop_cnt = (vector_cols + 15) / 16; - const int8_t *vector_0 = vector_data; const int8_t *vector_1 = vector_data + vector_cols; const int8_t *vector_2 = vector_data + 2 * vector_cols; const int8_t *vector_3 = vector_data + 3 * vector_cols; - + const int8_t *vector_4 = vector_data + 4 * vector_cols; int32_t vector_sum_0 = 0; int32_t vector_sum_1 = 0; int32_t vector_sum_2 = 0; int32_t vector_sum_3 = 0; - + int32_t vector_sum_4 = 0; uint32_t col_cnt = (uint32_t)vector_cols; - for (int i = 0; i < col_loop_cnt; i++) { mve_pred16_t p = vctp8q(col_cnt); col_cnt -= 16; - const int8x16_t ker_0 = vldrbq_z_s8(vector_0, p); vector_sum_0 = vaddvaq_s8(vector_sum_0, ker_0); - const int8x16_t ker_1 = vldrbq_z_s8(vector_1, p); vector_sum_1 = vaddvaq_s8(vector_sum_1, ker_1); - const int8x16_t ker_2 = vldrbq_z_s8(vector_2, p); vector_sum_2 = vaddvaq_s8(vector_sum_2, ker_2); - const int8x16_t ker_3 = vldrbq_z_s8(vector_3, p); vector_sum_3 = vaddvaq_s8(vector_sum_3, ker_3); - + const int8x16_t ker_4 = vldrbq_z_s8(vector_4, p); + vector_sum_4 = vaddvaq_s8(vector_sum_4, ker_4); vector_0 += 16; vector_1 += 16; vector_2 += 16; vector_3 += 16; + vector_4 += 16; + } + vector_data += 5 * vector_cols; + if (lhs_offset) + { + vector_sum_0 *= lhs_offset; + vector_sum_1 *= lhs_offset; + vector_sum_2 *= lhs_offset; + vector_sum_3 *= lhs_offset; + vector_sum_4 *= lhs_offset; + } + if (bias_data) + { + vector_sum_0 += *bias_data++; + vector_sum_1 += *bias_data++; + vector_sum_2 += *bias_data++; + vector_sum_3 += *bias_data++; + vector_sum_4 += *bias_data++; } - vector_data += 4 * vector_cols; - vector_sum_buf[0] = vector_sum_0; vector_sum_buf[1] = vector_sum_1; vector_sum_buf[2] = vector_sum_2; vector_sum_buf[3] = vector_sum_3; - vector_sum_buf += 4; + vector_sum_buf[4] = vector_sum_4; + vector_sum_buf += 5; } - - const int32_t loop_cnt = vector_rows % 4; - + const int32_t loop_cnt = vector_rows % 5; for (int i_row_loop_cnt = 0; i_row_loop_cnt < loop_cnt; i_row_loop_cnt++) { const int32_t col_loop_cnt = (vector_cols + 15) / 16; - const int8_t *vector_0 = vector_data; - int32_t vector_sum_0 = 0; - uint32_t col_cnt = (uint32_t)vector_cols; - for (int i = 0; i < col_loop_cnt; i++) { mve_pred16_t p = vctp8q(col_cnt); col_cnt -= 16; - const int8x16_t ker_0 = vldrbq_z_s8(vector_0, p); vector_sum_0 = vaddvaq_s8(vector_sum_0, ker_0); - vector_0 += 16; } vector_data += vector_cols; - + if (lhs_offset) + { + vector_sum_0 *= lhs_offset; + } + if (bias_data) + { + vector_sum_0 += *bias_data++; + 
} vector_sum_buf[i_row_loop_cnt] = vector_sum_0; } - return (ARM_CMSIS_NN_SUCCESS); #else - (void)vector_sum_buf; - (void)vector_rows; - (void)vector_cols; - (void)vector_data; - return (ARM_CMSIS_NN_NO_IMPL_ERROR); + if (bias_data) + { + memcpy(vector_sum_buf, bias_data, vector_rows * sizeof(int32_t)); + } + else + { + memset(vector_sum_buf, 0, vector_rows * sizeof(int32_t)); + } + + if (lhs_offset) + { + for (int i = 0; i < vector_rows; i++) + { + int32_t sum = 0; + for (int j = 0; j < vector_cols; j++) + { + sum += *vector_data++; + } + *vector_sum_buf++ += sum * lhs_offset; + } + } + return (ARM_CMSIS_NN_SUCCESS); + #endif } diff --git a/Source/LSTMFunctions/CMakeLists.txt b/Source/LSTMFunctions/CMakeLists.txt index e20d0963..eed27265 100644 --- a/Source/LSTMFunctions/CMakeLists.txt +++ b/Source/LSTMFunctions/CMakeLists.txt @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright 2019-2022 Arm Limited and/or its affiliates +# SPDX-FileCopyrightText: Copyright 2019-2022, 2024 Arm Limited and/or its affiliates # # SPDX-License-Identifier: Apache-2.0 # @@ -16,5 +16,5 @@ # limitations under the License. # -file(GLOB SRC "./*_s16.c") -target_sources(cmsis-nn PRIVATE ${SRC}) +file(GLOB SRC_S8 "./*_s8.c") +target_sources(cmsis-nn PRIVATE ${SRC_S8}) diff --git a/Source/LSTMFunctions/arm_lstm_unidirectional_s8.c b/Source/LSTMFunctions/arm_lstm_unidirectional_s8.c new file mode 100644 index 00000000..c2868b3a --- /dev/null +++ b/Source/LSTMFunctions/arm_lstm_unidirectional_s8.c @@ -0,0 +1,80 @@ +/* + * SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates + * + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the License); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* ---------------------------------------------------------------------- + * Project: CMSIS NN Library + * Title: arm_lstm_unidirectional_s8.c + * Description: S8 LSTM function with S16 gate output + * + * $Date: 19 January 2024 + * $Revision: V.1.0.0 + * + * Target Processor: Cortex-M processors + * + * -------------------------------------------------------------------- */ + +#include "arm_nnfunctions.h" +#include "arm_nnsupportfunctions.h" +/** + * @ingroup Public + */ + +/** + * @addtogroup LSTM + * @{ + */ + +/* + * S8 LSTM function for TensorFlow Lite with S16 gate output + * + * Refer to header file for details. + * + */ + +arm_cmsis_nn_status arm_lstm_unidirectional_s8(const int8_t *input, + int8_t *output, + const cmsis_nn_lstm_params *params, + cmsis_nn_lstm_context *buffers) +{ + + int8_t *hidden_in = NULL; + memset(buffers->cell_state, 0, params->batch_size * params->hidden_size * sizeof(int16_t)); + const int32_t batch_offset = (params->time_major) ? 
1 : params->time_steps; + + for (int t = 0; t < params->time_steps; t++) + { + const int8_t *data_in = input + (t * params->batch_size * params->input_size); + int8_t *hidden_out = output + (t * params->batch_size * params->hidden_size); + + arm_cmsis_nn_status status = arm_nn_lstm_step_s8(data_in, hidden_in, hidden_out, params, buffers, batch_offset); + + if (status != ARM_CMSIS_NN_SUCCESS) + { + return status; + } + + // Output is used as recurrent input/hidden state for the next timestep. + hidden_in = &hidden_out[0]; + } + + return ARM_CMSIS_NN_SUCCESS; +} + +/** + * @} end of LSTM group + */ diff --git a/Source/LSTMFunctions/arm_lstm_unidirectional_s8_s16.c b/Source/LSTMFunctions/arm_lstm_unidirectional_s8_s16.c deleted file mode 100644 index 6f658869..00000000 --- a/Source/LSTMFunctions/arm_lstm_unidirectional_s8_s16.c +++ /dev/null @@ -1,184 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright 2022 Arm Limited and/or its affiliates - * - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the License); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* ---------------------------------------------------------------------- - * Project: CMSIS NN Library - * Title: arm_lstm_unidirectional_s16_s8.c - * Description: S8 LSTM function with S16 gate output - * - * $Date: 4 November 2022 - * $Revision: V.1.0.0 - * - * Target Processor: Cortex-M processors - * - * -------------------------------------------------------------------- */ - -#include "arm_nnfunctions.h" -#include "arm_nnsupportfunctions.h" - -/** - * @ingroup Public - */ - -/** - * @addtogroup LSTM - * @{ - */ - -/* - * S8 LSTM function for TensorFlow Lite with S16 gate output - * - * Refer to header file for details. - * - */ - -#include "arm_nnfunctions.h" -#include "arm_nnsupportfunctions.h" - -/* - * LSTM unidirectional function with 8 bit input and output and 16 bit weights - * - * Refer header file for details. 
- * - */ -arm_cmsis_nn_status arm_lstm_unidirectional_s16_s8(cmsis_nn_lstm_context *scratch_buffers, - const int8_t *input_data, - const cmsis_nn_lstm_dims *lstm_dims, - const int8_t *in_to_in_weights, - const int8_t *in_to_forget_weights, - const int8_t *in_to_cell_weights, - const int8_t *in_to_out_weights, - const int8_t *recurrent_to_in_weights, - const int8_t *recurrent_to_forget_weights, - const int8_t *recurrent_to_cell_weights, - const int8_t *recurrent_to_out_weights, - const int16_t *cell_to_in_weights, - const int16_t *cell_to_forget_weights, - const int16_t *cell_to_out_weights, - const int8_t *projection_weights, - const cmsis_nn_lstm_params *lstm, - int8_t *output_state, - int16_t *cell_state, - int8_t *output_data) -{ - (void)cell_to_in_weights; - (void)cell_to_forget_weights; - (void)cell_to_out_weights; - - const int32_t num_batch = lstm_dims->num_batches; - const int32_t num_input = lstm_dims->num_inputs; - const int32_t max_time = lstm_dims->max_time; - - const int32_t num_output = lstm_dims->num_outputs; - const int32_t out_batch_leading_dim = num_output; - - // num_cell = num_output is considered in the code under the assumption that projection is NULL. - const int32_t num_cell = num_output; - - if (projection_weights != NULL) - { - return ARM_CMSIS_NN_ARG_ERROR; - } - - if (lstm->i2f_effective_bias == NULL || lstm->i2c_effective_bias == NULL || lstm->i2o_effective_bias == NULL) - { - return ARM_CMSIS_NN_ARG_ERROR; - } - - if (lstm->r2f_effective_bias == NULL || lstm->r2c_effective_bias == NULL || lstm->r2o_effective_bias == NULL) - { - return ARM_CMSIS_NN_ARG_ERROR; - } - - if (lstm->i2i_effective_bias == NULL || lstm->r2i_effective_bias == NULL) - { - return ARM_CMSIS_NN_ARG_ERROR; - } - - if (lstm->time_major) - { - const int32_t in_step = num_batch * num_input; - const int32_t out_step = num_batch * out_batch_leading_dim; - for (int i_max_time = 0; i_max_time < max_time; i_max_time++) - { - arm_cmsis_nn_status status = arm_nn_lstm_step_s8_s16(input_data + i_max_time * in_step, - in_to_in_weights, - in_to_forget_weights, - in_to_cell_weights, - in_to_out_weights, - recurrent_to_in_weights, - recurrent_to_forget_weights, - recurrent_to_cell_weights, - recurrent_to_out_weights, - lstm, - num_batch, - num_cell, - num_input, - num_output, - output_state, - cell_state, - output_data + i_max_time * out_step, - scratch_buffers); - if (status != ARM_CMSIS_NN_SUCCESS) - { - return status; - } - } - } - else - { - for (int i_num_batch = 0; i_num_batch < num_batch; i_num_batch++) - { - const int32_t in_step = num_input; - const int32_t out_step = out_batch_leading_dim; - for (int i_max_time = 0; i_max_time < max_time; i_max_time++) - { - const int32_t time_offset = i_num_batch * max_time + i_max_time; - - arm_cmsis_nn_status status = arm_nn_lstm_step_s8_s16(input_data + time_offset * in_step, - in_to_in_weights, - in_to_forget_weights, - in_to_cell_weights, - in_to_out_weights, - recurrent_to_in_weights, - recurrent_to_forget_weights, - recurrent_to_cell_weights, - recurrent_to_out_weights, - lstm, - /*num_batch=*/1, - num_cell, - num_input, - num_output, - output_state + i_num_batch * out_batch_leading_dim, - cell_state + i_num_batch * num_cell, - output_data + time_offset * out_step, - scratch_buffers); - if (status != ARM_CMSIS_NN_SUCCESS) - { - return status; - } - } - } - } - - return ARM_CMSIS_NN_SUCCESS; -} - -/** - * @} end of LSTM group - */ diff --git a/Source/NNSupportFunctions/arm_nn_lstm_calculate_gate_s8_s16.c 
b/Source/NNSupportFunctions/arm_nn_lstm_calculate_gate_s8_s16.c index 37c7d6d9..b1bf9550 100644 --- a/Source/NNSupportFunctions/arm_nn_lstm_calculate_gate_s8_s16.c +++ b/Source/NNSupportFunctions/arm_nn_lstm_calculate_gate_s8_s16.c @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright 2022 Arm Limited and/or its affiliates + * SPDX-FileCopyrightText: Copyright 2022, 2024 Arm Limited and/or its affiliates * * SPDX-License-Identifier: Apache-2.0 * @@ -21,8 +21,8 @@ * Title: arm_nn_lstm_calculate_gate_s8_s16.c * Description: Update single gate for an incremental step of LSTM function. * - * $Date: 8 September 2022 - * $Revision: V.1.0.0 + * $Date: 19 January 2024 + * $Revision: V.2.0.0 * * Target Processor: Cortex-M cores * @@ -31,7 +31,6 @@ #include "arm_nn_tables.h" #include "arm_nnfunctions.h" #include "arm_nnsupportfunctions.h" - /** * @ingroup groupSupport */ @@ -52,48 +51,45 @@ * Calculates a single LSTM gate, int8x8_16 version. * Refer to header file for details */ -void arm_nn_lstm_calculate_gate_s8_s16(const int8_t *input, - const int8_t *input_to_gate_weights, - const int32_t *input_to_gate_bias, - const cmsis_nn_scaling input_to_gate_scaling, - const int8_t *output_state, - const int8_t *recurrent_to_gate_weights, - const int32_t *recurrent_to_gate_bias, - const cmsis_nn_scaling recurrent_to_gate, - const int32_t n_batch, - const int32_t n_input, - const int32_t n_output, - const int32_t n_cell, - const arm_nn_activation_type activation_type, - int16_t *gate) +arm_cmsis_nn_status arm_nn_lstm_calculate_gate_s8_s16(const int8_t *data_in, + const int8_t *hidden_in, + const cmsis_nn_lstm_gate *gate, + const cmsis_nn_lstm_params *params, + int16_t *output, + const int32_t batch_offset) { - const int32_t n_block = n_batch * n_cell; - memset(gate, 0, n_block * sizeof(int16_t)); - arm_nn_vec_mat_mul_result_acc_s8(input, - input_to_gate_weights, - input_to_gate_bias, - gate, - 0, - input_to_gate_scaling.multiplier, - input_to_gate_scaling.shift, - n_input, - n_cell, - n_batch); + memset(output, 0, params->hidden_size * params->batch_size * sizeof(int16_t)); + + arm_nn_vec_mat_mul_result_acc_s8_s16(data_in, + gate->input_weights, + gate->input_effective_bias, + output, + gate->input_multiplier, + gate->input_shift, + params->input_size, + params->hidden_size, + params->batch_size, + batch_offset); - arm_nn_vec_mat_mul_result_acc_s8(output_state, - recurrent_to_gate_weights, - recurrent_to_gate_bias, - gate, - 0, - recurrent_to_gate.multiplier, - recurrent_to_gate.shift, - n_output, - n_cell, - n_batch); + if (hidden_in) + { + arm_nn_vec_mat_mul_result_acc_s8_s16(hidden_in, + gate->hidden_weights, + gate->hidden_effective_bias, + output, + gate->hidden_multiplier, + gate->hidden_shift, + params->hidden_size, + params->hidden_size, + params->batch_size, + batch_offset); + } - arm_nn_activation_s16(gate, gate, n_block, 0, activation_type); + arm_nn_activation_s16(output, output, params->hidden_size * params->batch_size, 0, gate->activation_type); + + return ARM_CMSIS_NN_SUCCESS; } /** * @} end of supportLSTM group - */ + */ \ No newline at end of file diff --git a/Source/NNSupportFunctions/arm_nn_lstm_step_s8.c b/Source/NNSupportFunctions/arm_nn_lstm_step_s8.c new file mode 100644 index 00000000..4b25081a --- /dev/null +++ b/Source/NNSupportFunctions/arm_nn_lstm_step_s8.c @@ -0,0 +1,110 @@ +/* + * SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates + * + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the License); you may + 
* not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* ---------------------------------------------------------------------- + * Project: CMSIS NN Library + * Title: arm_nn_lstm_step_s8.c + * Description: Update LSTM function for a single iteration step. + * + * $Date: 19 January 2024 + * $Revision: V.1.0.0 + * + * Target : Arm(R) M-Profile Architecture + * + * -------------------------------------------------------------------- */ +#include "arm_nnfunctions.h" +#include "arm_nnsupportfunctions.h" +/** + * @ingroup groupSupport + */ + +/** + * @addtogroup supportLSTM + * @{ + */ + +/* + * Calculate the output state tensor of an LSTM step, s8 input/output/weights and s16 internal buffers version. + * Refer to header file for details. + */ +arm_cmsis_nn_status arm_nn_lstm_step_s8(const int8_t *data_in, + const int8_t *hidden_in, + int8_t *hidden_out, + const cmsis_nn_lstm_params *params, + cmsis_nn_lstm_context *buffers, + const int32_t batch_offset) +{ + int16_t *forget_gate = buffers->temp1; + int16_t *input_gate = buffers->temp1; + int16_t *cell_gate = buffers->temp2; + int16_t *output_gate = buffers->temp1; + int16_t *hidden_temp = buffers->temp2; + + int16_t *cell_state = buffers->cell_state; + + arm_nn_lstm_calculate_gate_s8_s16(data_in, hidden_in, &params->forget_gate, params, forget_gate, batch_offset); + + // Calculate first term of cell state in place early to maximise reuse of scratch-buffers + arm_elementwise_mul_s16(forget_gate, + cell_state, + 0, + 0, + cell_state, + 0, + params->forget_to_cell_multiplier, + params->forget_to_cell_shift, + NN_Q15_MIN, + NN_Q15_MAX, + params->hidden_size * params->batch_size); + + arm_nn_lstm_calculate_gate_s8_s16(data_in, hidden_in, &params->input_gate, params, input_gate, batch_offset); + arm_nn_lstm_calculate_gate_s8_s16(data_in, hidden_in, &params->cell_gate, params, cell_gate, batch_offset); + + // Remainder of the cell state calculation: multiply and add to the previous result. + arm_elementwise_mul_acc_s16(forget_gate, + cell_gate, + 0, + 0, + cell_state, + 0, + params->input_to_cell_multiplier, + params->input_to_cell_shift, + -params->cell_clip, + params->cell_clip, + params->hidden_size * params->batch_size); + + arm_nn_lstm_calculate_gate_s8_s16(data_in, hidden_in, &params->output_gate, params, output_gate, batch_offset); + + // Calculate hidden state directly to output.
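+ // The (cell_scale_power + 12) shift below rescales the s16 cell state before the fixed-point tanh,
+ // and arm_elementwise_mul_s16_s8 then multiplies the result by the output gate and requantizes it
+ // straight into the s8 hidden_out buffer using the output multiplier, shift and offset.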
+ arm_nn_activation_s16( + cell_state, hidden_temp, params->hidden_size * params->batch_size, params->cell_scale_power + 12, ARM_TANH); + arm_elementwise_mul_s16_s8(output_gate, + hidden_temp, + hidden_out, + params->output_offset, + params->output_multiplier, + params->output_shift, + params->hidden_size, + params->batch_size, + batch_offset); + + return ARM_CMSIS_NN_SUCCESS; +} +/** + * @} end of supportLSTM group + */ diff --git a/Source/NNSupportFunctions/arm_nn_lstm_step_s8_s16.c b/Source/NNSupportFunctions/arm_nn_lstm_step_s8_s16.c deleted file mode 100644 index 41e3071f..00000000 --- a/Source/NNSupportFunctions/arm_nn_lstm_step_s8_s16.c +++ /dev/null @@ -1,154 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright 2022-2023 Arm Limited and/or its affiliates - * - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the License); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* ---------------------------------------------------------------------- - * Project: CMSIS NN Library - * Title: arm_nn_lstm_step_s8_s16.c - * Description: Update LSTM function for a single iteration step. - * - * $Date: 9 Februari 2023 - * $Revision: V.1.1.0 - * - * Target : Arm(R) M-Profile Architecture - * - * -------------------------------------------------------------------- */ -#include "arm_nnsupportfunctions.h" -/** - * @ingroup groupSupport - */ - -/** - * @addtogroup supportLSTM - * @{ - */ - -/* - * Calculate the output state tensor of an LSTM step, s8 input/output and s16 weight version. - * Refer to header file for details. 
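Taken together, the kernels above implement the standard LSTM cell update: per element, and leaving the fixed-point requantization details aside, c(t) = clip(f(t) * c(t-1) + i(t) * g(t), -cell_clip, cell_clip) followed by h(t) = o(t) * tanh(c(t)), with h(t) then written back as s8 using the output multiplier, shift and offset.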
- */ -arm_cmsis_nn_status arm_nn_lstm_step_s8_s16(const int8_t *input, - const int8_t *input_to_input_weight, - const int8_t *input_to_forget_weight, - const int8_t *input_to_cell_weight, - const int8_t *input_to_output_weight, - const int8_t *recurrent_to_input_weight, - const int8_t *recurrent_to_forget_weight, - const int8_t *recurrent_to_cell_weight, - const int8_t *recurrent_to_output_weight, - const cmsis_nn_lstm_params *lstm, - const int n_batch, - const int n_cell, - const int n_input, - const int n_output, - int8_t *output_state, - int16_t *cell_state, - int8_t *output, - cmsis_nn_lstm_context *scratch_buffers) -{ - const int32_t n_block = n_batch * n_cell; - - // Calculate the input gate - arm_nn_lstm_calculate_gate_s8_s16(input, - input_to_input_weight, - lstm->i2i_effective_bias, - lstm->input_to_input_scaling, - output_state, - recurrent_to_input_weight, - lstm->r2i_effective_bias, - lstm->recurrent_to_input_scaling, - n_batch, - n_input, - n_output, - n_cell, - ARM_SIGMOID, - scratch_buffers->input_gate); - - // Calculate the forget gate - arm_nn_lstm_calculate_gate_s8_s16(input, - input_to_forget_weight, - lstm->i2f_effective_bias, - lstm->input_to_forget_scaling, - output_state, - recurrent_to_forget_weight, - lstm->r2f_effective_bias, - lstm->recurrent_to_forget_scaling, - n_batch, - n_input, - n_output, - n_cell, - ARM_SIGMOID, - scratch_buffers->forget_gate); - - // Calculate the cell update gate - arm_nn_lstm_calculate_gate_s8_s16(input, - input_to_cell_weight, - lstm->i2c_effective_bias, - lstm->input_to_cell_scaling, - output_state, - recurrent_to_cell_weight, - lstm->r2c_effective_bias, - lstm->recurrent_to_cell_scaling, - n_batch, - n_input, - n_output, - n_cell, - ARM_TANH, - scratch_buffers->cell_gate); - - // Update the cell state - arm_nn_lstm_update_cell_state_s16(n_block, - lstm->cell_state_shift, - cell_state, - scratch_buffers->input_gate, - scratch_buffers->forget_gate, - scratch_buffers->cell_gate); - - // Calculate the output gate - arm_nn_lstm_calculate_gate_s8_s16(input, - input_to_output_weight, - lstm->i2o_effective_bias, - lstm->input_to_output_scaling, - output_state, - recurrent_to_output_weight, - lstm->r2o_effective_bias, - lstm->recurrent_to_output_scaling, - n_batch, - n_input, - n_output, - n_cell, - ARM_SIGMOID, - scratch_buffers->output_gate); - - // Update the output state - arm_nn_lstm_update_output_s8_s16(n_batch, - n_cell, - cell_state, - lstm->cell_state_shift, - scratch_buffers->output_gate, - lstm->hidden_scaling, - lstm->hidden_offset, - output_state, - scratch_buffers->input_gate); - - arm_memcpy_s8(output, output_state, n_batch * n_output * sizeof(int8_t)); - - return ARM_CMSIS_NN_SUCCESS; -} -/** - * @} end of supportLSTM group - */ diff --git a/Source/NNSupportFunctions/arm_nn_lstm_update_cell_state_s16.c b/Source/NNSupportFunctions/arm_nn_lstm_update_cell_state_s16.c deleted file mode 100644 index b142773e..00000000 --- a/Source/NNSupportFunctions/arm_nn_lstm_update_cell_state_s16.c +++ /dev/null @@ -1,124 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright 2022-2023 Arm Limited and/or its affiliates - * - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the License); you may - * not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* ---------------------------------------------------------------------- - * Project: CMSIS NN Library - * Title: arm_nn_lstm_update_cell_state_s16.c - * Description: Update cell state for an incremental step of LSTM function. - * - * $Date: 20 January 2023 - * $Revision: V.1.2.0 - * - * Target : Arm(R) M-Profile Architecture - * - * -------------------------------------------------------------------- */ - -#include "arm_nnsupportfunctions.h" -/** - * @ingroup groupSupport - */ - -/** - * @addtogroup supportLSTM - * @{ - */ - -/* - * Update cell state for a single LSTM iteration step, int8x8_16 version. - * - * Refer to header file for more details - */ -void arm_nn_lstm_update_cell_state_s16(const int32_t n_block, - const int32_t cell_state_scale, - int16_t *cell_state, - const int16_t *input_gate, - const int16_t *forget_gate, - const int16_t *cell_gate) -{ - const int32_t cell_scale = 30 + cell_state_scale; - int32_t loop_count = n_block; - -#if defined(ARM_MATH_MVEI) - - while (loop_count > 0) - { - mve_pred16_t p = vctp32q(loop_count); - loop_count -= 4; - - int32x4_t res_1 = vmulq_s32(vldrhq_z_s32(cell_state, p), vldrhq_z_s32(forget_gate, p)); - forget_gate += 4; - res_1 = arm_divide_by_power_of_two_mve(res_1, 15); - int32x4_t res_2 = vmulq_s32(vldrhq_z_s32(input_gate, p), vldrhq_z_s32(cell_gate, p)); - input_gate += 4; - cell_gate += 4; - - res_2 = arm_divide_by_power_of_two_mve(res_2, cell_scale); - res_1 += res_2; - - res_1 = vmaxq_s32(res_1, vdupq_n_s32(NN_Q15_MIN)); - res_1 = vminq_s32(res_1, vdupq_n_s32(NN_Q15_MAX)); - - vstrhq_p_s32(cell_state, res_1, p); - cell_state += 4; - } -#else - #if defined(ARM_MATH_DSP) - while (loop_count > 1) - { - int32_t cell_state_01 = arm_nn_read_s16x2(cell_state); - int32_t forget_gate_01 = arm_nn_read_q15x2_ia(&forget_gate); - - int32_t value_00 = SMULBB(cell_state_01, forget_gate_01); - int32_t value_01 = SMULTT(cell_state_01, forget_gate_01); - value_00 = arm_nn_divide_by_power_of_two(value_00, 15); - value_01 = arm_nn_divide_by_power_of_two(value_01, 15); - - int32_t input_gate_01 = arm_nn_read_q15x2_ia(&input_gate); - int32_t cell_gate_01 = arm_nn_read_q15x2_ia(&cell_gate); - - int32_t value_10 = SMULBB(input_gate_01, cell_gate_01); - int32_t value_11 = SMULTT(input_gate_01, cell_gate_01); - - value_10 = arm_nn_divide_by_power_of_two(value_10, cell_scale); - value_11 = arm_nn_divide_by_power_of_two(value_11, cell_scale); - - value_00 += value_10; - value_01 += value_11; - - value_00 = CLAMP(value_00, NN_Q15_MAX, NN_Q15_MIN); - value_01 = CLAMP(value_01, NN_Q15_MAX, NN_Q15_MIN); - - arm_nn_write_q15x2_ia(&cell_state, PACK_Q15x2_32x1(value_00, value_01)); - loop_count -= 2; - } - #endif - for (int i = 0; i < loop_count; i++) - { - int32_t value = cell_state[i] * forget_gate[i]; - int32_t value_1 = input_gate[i] * cell_gate[i]; - - value = arm_nn_divide_by_power_of_two(value, 15); - value_1 = arm_nn_divide_by_power_of_two(value_1, cell_scale); - - cell_state[i] = CLAMP(value + value_1, NN_Q15_MAX, NN_Q15_MIN); - } -#endif // #if defined(ARM_MATH_MVEI) -} -/** - * @} end of supportLSTM group - */ diff --git 
a/Source/NNSupportFunctions/arm_nn_lstm_update_output_s8_s16.c b/Source/NNSupportFunctions/arm_nn_lstm_update_output_s8_s16.c deleted file mode 100644 index c742f354..00000000 --- a/Source/NNSupportFunctions/arm_nn_lstm_update_output_s8_s16.c +++ /dev/null @@ -1,81 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright 2022-2023 Arm Limited and/or its affiliates - * - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the License); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* ---------------------------------------------------------------------- - * Project: CMSIS NN Library - * Title: arm_nn_lstm_update_output_s8_s16.c - * Description: Update output gate for an incremental step of LSTM function. - * - * $Date: 13 Februari 2023 - * $Revision: V.2.0.0 - * - * Target : Arm(R) M-Profile Architecture - * - * -------------------------------------------------------------------- */ - -#include "arm_nnfunctions.h" -#include "arm_nnsupportfunctions.h" - -/** - * @ingroup groupSupport - */ - -/** - * @addtogroup supportLSTM - * @{ - */ - -/* - * Calculate the output state tensor of an LSTM step, s8 input/output and s16 weight version. - * Refer to header files for details - */ -void arm_nn_lstm_update_output_s8_s16(const int n_batch, - const int n_cell, - int16_t *cell_state, - const int32_t cell_state_scale, - const int16_t *output_gate, - const cmsis_nn_scaling hidden_scaling, - const int32_t hidden_offset, - int8_t *output_state, - int16_t *cell_gate_scratch) -{ - const int32_t size = n_batch * n_cell; - - int32_t tanh_input_left_shift = (15 + cell_state_scale) - 3; - if (tanh_input_left_shift < 0) - { - tanh_input_left_shift = -tanh_input_left_shift; - for (int32_t i = 0; i < size; i++) - { - cell_state[i] = cell_state[i] >> tanh_input_left_shift; - } - tanh_input_left_shift = 0; - } - arm_nn_activation_s16(cell_state, cell_gate_scratch, size, tanh_input_left_shift, ARM_TANH); - - arm_elementwise_mul_s16_s8(output_gate, - cell_gate_scratch, - output_state, - hidden_offset, - hidden_scaling.multiplier, - hidden_scaling.shift, - size); -} -/** - * @} end of supportLSTM group - */ diff --git a/Source/NNSupportFunctions/arm_nn_vec_mat_mul_result_acc_s8.c b/Source/NNSupportFunctions/arm_nn_vec_mat_mul_result_acc_s8_s16.c similarity index 60% rename from Source/NNSupportFunctions/arm_nn_vec_mat_mul_result_acc_s8.c rename to Source/NNSupportFunctions/arm_nn_vec_mat_mul_result_acc_s8_s16.c index d1eba6e2..17b9a63f 100644 --- a/Source/NNSupportFunctions/arm_nn_vec_mat_mul_result_acc_s8.c +++ b/Source/NNSupportFunctions/arm_nn_vec_mat_mul_result_acc_s8_s16.c @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright 2022-2023 Arm Limited and/or its affiliates + * SPDX-FileCopyrightText: Copyright 2022-2024 Arm Limited and/or its affiliates * * SPDX-License-Identifier: Apache-2.0 * @@ -18,18 +18,16 @@ /* ---------------------------------------------------------------------- * Project: CMSIS NN Library - * Title: arm_nn_vec_mat_mul_result_acc_s8.c + * Title: arm_nn_vec_mat_mul_result_acc_s8_s16.c * Description: Multiplies a 
matrix by a vector and accumulate with output. * - * $Date: 20 January 2023 - * $Revision: V.1.2.0 + * $Date: 19 January 2024 + * $Revision: V.2.0.0 * * Target : Arm(R) M-Profile Architecture * * -------------------------------------------------------------------- */ - #include "arm_nnsupportfunctions.h" - /** * @ingroup groupSupport */ @@ -42,41 +40,47 @@ /* * Refer to header file for details. */ -void arm_nn_vec_mat_mul_result_acc_s8(const int8_t *lhs_in, - const int8_t *rhs_in, - const int32_t *bias, - int16_t *dst, - const int32_t dst_offset, - const int32_t dst_multiplier, - const int32_t dst_shift, - const int32_t rhs_cols, - const int32_t rhs_rows, - const int32_t batch) +arm_cmsis_nn_status arm_nn_vec_mat_mul_result_acc_s8_s16(const int8_t *lhs, + const int8_t *rhs, + const int32_t *effective_bias, + int16_t *dst, + const int32_t dst_multiplier, + const int32_t dst_shift, + const int32_t rhs_cols, + const int32_t rhs_rows, + const int32_t batches, + const int32_t batch_offset) { - for (int i_batch = 0; i_batch < batch; ++i_batch) + + for (int batch = 0; batch < batches; batch++) { - const int8_t *rhs = rhs_in; - const int8_t *lhs = lhs_in + i_batch * rhs_cols; + const int8_t *rhs_ptr = &rhs[0]; + const int32_t *effective_bias_ptr = &effective_bias[0]; #if defined(ARM_MATH_MVEI) - const int32_t row_loop_cnt = rhs_rows / 4; - for (int i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++) + for (size_t row_loop_cnt = rhs_rows / 4; row_loop_cnt != 0; --row_loop_cnt) { - int32_t acc_0 = 0; - int32_t acc_1 = 0; - int32_t acc_2 = 0; - int32_t acc_3 = 0; + const int32_t col_loop_cnt = (rhs_cols + 15) / 16; const int8_t *lhs_vec = lhs; - const int8_t *rhs_0 = rhs; - const int8_t *rhs_1 = rhs + rhs_cols; - const int8_t *rhs_2 = rhs + 2 * rhs_cols; - const int8_t *rhs_3 = rhs + 3 * rhs_cols; - - int32_t col_cnt = rhs_cols; - - while (col_cnt > 0) + const int8_t *rhs_0 = rhs_ptr; + rhs_ptr += rhs_cols; + const int8_t *rhs_1 = rhs_ptr; + rhs_ptr += rhs_cols; + const int8_t *rhs_2 = rhs_ptr; + rhs_ptr += rhs_cols; + const int8_t *rhs_3 = rhs_ptr; + rhs_ptr += rhs_cols; + + int32_t acc_0 = *effective_bias_ptr++; + int32_t acc_1 = *effective_bias_ptr++; + int32_t acc_2 = *effective_bias_ptr++; + int32_t acc_3 = *effective_bias_ptr++; + + uint32_t col_cnt = (uint32_t)rhs_cols; + + for (int i = 0; i < col_loop_cnt; i++) { mve_pred16_t p = vctp8q(col_cnt); col_cnt -= 16; @@ -84,16 +88,16 @@ void arm_nn_vec_mat_mul_result_acc_s8(const int8_t *lhs_in, const int8x16_t input = vldrbq_z_s8(lhs_vec, p); const int8x16_t ker_0 = vldrbq_z_s8(rhs_0, p); - acc_0 = vmladavaq_p_s8(acc_0, ker_0, input, p); + acc_0 = vmladavaq_s8(acc_0, ker_0, input); const int8x16_t ker_1 = vldrbq_z_s8(rhs_1, p); - acc_1 = vmladavaq_p_s8(acc_1, ker_1, input, p); + acc_1 = vmladavaq_s8(acc_1, ker_1, input); const int8x16_t ker_2 = vldrbq_z_s8(rhs_2, p); - acc_2 = vmladavaq_p_s8(acc_2, ker_2, input, p); + acc_2 = vmladavaq_s8(acc_2, ker_2, input); const int8x16_t ker_3 = vldrbq_z_s8(rhs_3, p); - acc_3 = vmladavaq_p_s8(acc_3, ker_3, input, p); + acc_3 = vmladavaq_s8(acc_3, ker_3, input); lhs_vec += 16; rhs_0 += 16; @@ -101,16 +105,10 @@ void arm_nn_vec_mat_mul_result_acc_s8(const int8_t *lhs_in, rhs_2 += 16; rhs_3 += 16; } - rhs += 4 * rhs_cols; int32x4_t acc = {acc_0, acc_1, acc_2, acc_3}; - int32x4_t b = vldrwq_s32(bias); - acc = vaddq_s32(acc, b); - bias += 4; acc = arm_requantize_mve(acc, dst_multiplier, dst_shift); - acc = vaddq_s32(acc, vdupq_n_s32(dst_offset)); - acc = vaddq_s32(acc, vldrhq_s32(dst)); acc = 
vmaxq_s32(acc, vdupq_n_s32(NN_Q15_MIN)); @@ -120,79 +118,77 @@ void arm_nn_vec_mat_mul_result_acc_s8(const int8_t *lhs_in, dst += 4; } - const int loop_cnt = rhs_rows % 4; - for (int i_row_loop_cnt = 0; i_row_loop_cnt < loop_cnt; i_row_loop_cnt++) + for (size_t row_loop_cnt = rhs_rows % 4; row_loop_cnt != 0; --row_loop_cnt) { - int32_t acc_0 = 0; + int32_t acc_0 = *effective_bias_ptr++; + + const int32_t col_loop_cnt = (rhs_cols + 15) / 16; const int8_t *lhs_vec = lhs; - const int8_t *rhs_0 = rhs; - int32_t col_cnt = rhs_cols; + const int8_t *rhs_0 = rhs_ptr; + uint32_t col_cnt = (uint32_t)rhs_cols; - while (col_cnt > 0) + for (int i = 0; i < col_loop_cnt; i++) { mve_pred16_t p = vctp8q(col_cnt); col_cnt -= 16; const int8x16_t input = vldrbq_z_s8(lhs_vec, p); const int8x16_t ker_0 = vldrbq_z_s8(rhs_0, p); - acc_0 = vmladavaq_p_s8(acc_0, ker_0, input, p); + acc_0 = vmladavaq_s8(acc_0, ker_0, input); lhs_vec += 16; rhs_0 += 16; } - rhs += rhs_cols; - - acc_0 += *bias; - bias++; + rhs_ptr += rhs_cols; acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift); - acc_0 += dst_offset + *dst; + acc_0 += *dst; // Clamp the result - acc_0 = CLAMP(acc_0, NN_Q15_MAX, NN_Q15_MIN); + acc_0 = MAX(acc_0, NN_Q15_MIN); + acc_0 = MIN(acc_0, NN_Q15_MAX); *dst++ = (int16_t)acc_0; } #elif defined(ARM_MATH_DSP) - const int32_t row_loop_cnt = rhs_rows / 2; - for (int32_t i = 0; i < row_loop_cnt; i++) + for (int32_t row_loop_cnt = rhs_rows / 2; row_loop_cnt != 0; --row_loop_cnt) { - int32_t acc_0 = *bias++; - int32_t acc_1 = *bias++; + int32_t acc_0 = *effective_bias_ptr++; + int32_t acc_1 = *effective_bias_ptr++; const int32_t col_loop_cnt = rhs_cols / 4; const int8_t *lhs_vec = lhs; - const int8_t *rhs_0 = rhs; - const int8_t *rhs_1 = rhs + rhs_cols; - rhs += 2 * rhs_cols; + const int8_t *rhs_0 = rhs_ptr; + rhs_ptr += rhs_cols; + const int8_t *rhs_1 = rhs_ptr; + rhs_ptr += rhs_cols; for (int j = col_loop_cnt; j != 0; j--) { int32_t vec_0 = arm_nn_read_s8x4_ia(&lhs_vec); int32_t vec_1 = SXTB16_RORn((uint32_t)vec_0, 8); - vec_0 = SXTB16(vec_0); int32_t ker_0 = arm_nn_read_s8x4_ia(&rhs_0); int32_t ker_1 = SXTB16_RORn((uint32_t)ker_0, 8); - acc_0 = SMLAD(ker_1, vec_1, acc_0); - ker_0 = SXTB16(ker_0); + + acc_0 = SMLAD(ker_1, vec_1, acc_0); acc_0 = SMLAD(ker_0, vec_0, acc_0); ker_0 = arm_nn_read_s8x4_ia(&rhs_1); ker_1 = SXTB16_RORn((uint32_t)ker_0, 8); - acc_1 = SMLAD(ker_1, vec_1, acc_1); - ker_0 = SXTB16(ker_0); + + acc_1 = SMLAD(ker_1, vec_1, acc_1); acc_1 = SMLAD(ker_0, vec_0, acc_1); } for (int k = col_loop_cnt * 4; k < rhs_cols; k++) { - const int32_t lhs_temp = *lhs_vec; + const int32_t lhs_temp = (*lhs_vec); lhs_vec++; acc_0 += lhs_temp * (*rhs_0); rhs_0++; @@ -204,24 +200,26 @@ void arm_nn_vec_mat_mul_result_acc_s8(const int8_t *lhs_in, acc_1 = arm_nn_requantize(acc_1, dst_multiplier, dst_shift); // Add offset - acc_0 += dst_offset + *dst; - acc_1 += dst_offset + dst[1]; + acc_0 += *dst; // Clamp the result - acc_0 = CLAMP(acc_0, NN_Q15_MAX, NN_Q15_MIN); - acc_1 = CLAMP(acc_1, NN_Q15_MAX, NN_Q15_MIN); - + acc_0 = MAX(acc_0, NN_Q15_MIN); + acc_0 = MIN(acc_0, NN_Q15_MAX); *dst++ = (int16_t)acc_0; + + acc_1 += *dst; + acc_1 = MAX(acc_1, NN_Q15_MIN); + acc_1 = MIN(acc_1, NN_Q15_MAX); + *dst++ = (int16_t)acc_1; } if (rhs_rows & 0x1) { - int32_t acc_0 = *bias++; - + int32_t acc_0 = *effective_bias_ptr++; const int32_t col_loop_cnt = rhs_cols / 4; const int8_t *lhs_vec = lhs; - const int8_t *rhs_0 = rhs; + const int8_t *rhs_0 = rhs_ptr; for (int i = col_loop_cnt; i != 0; i--) { @@ -239,16 +237,19 @@ void 
arm_nn_vec_mat_mul_result_acc_s8(const int8_t *lhs_in, for (int j = col_loop_cnt * 4; j < rhs_cols; j++) { - const int32_t lhs_temp = *lhs_vec++; - acc_0 += lhs_temp * (*rhs_0++); + const int32_t lhs_temp = (*lhs_vec); + lhs_vec++; + acc_0 += lhs_temp * (*rhs_0); + rhs_0++; } acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift); - // Add offset - acc_0 += dst_offset + *dst; + // Accumulate + acc_0 += dst[0]; // Clamp the result - acc_0 = CLAMP(acc_0, NN_Q15_MAX, NN_Q15_MIN); + acc_0 = MAX(acc_0, NN_Q15_MIN); + acc_0 = MIN(acc_0, NN_Q15_MAX); *dst++ = (int16_t)acc_0; } @@ -259,13 +260,16 @@ void arm_nn_vec_mat_mul_result_acc_s8(const int8_t *lhs_in, for (int i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++) { const int8_t *lhs_ptr = lhs; - const int8_t *rhs_ptr_0 = &rhs[0]; - const int8_t *rhs_ptr_1 = &rhs[rhs_cols]; - const int8_t *rhs_ptr_2 = &rhs[rhs_cols * 2]; + const int8_t *rhs_ptr_0 = rhs_ptr; + rhs_ptr += rhs_cols; + const int8_t *rhs_ptr_1 = rhs_ptr; + rhs_ptr += rhs_cols; + const int8_t *rhs_ptr_2 = rhs_ptr; + rhs_ptr += rhs_cols; - int32_t res00 = *bias++; - int32_t res01 = *bias++; - int32_t res02 = *bias++; + int32_t res00 = *effective_bias_ptr++; + int32_t res01 = *effective_bias_ptr++; + int32_t res02 = *effective_bias_ptr++; for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx) { @@ -289,21 +293,17 @@ void arm_nn_vec_mat_mul_result_acc_s8(const int8_t *lhs_in, res02 = arm_nn_requantize(res02, dst_multiplier, dst_shift); // Add offset - res00 += dst_offset + *dst; - res01 += dst_offset + dst[1]; - res02 += dst_offset + dst[2]; - - // Clamp the result + res00 += (int32_t)*dst; res00 = CLAMP(res00, NN_Q15_MAX, NN_Q15_MIN); - res01 = CLAMP(res01, NN_Q15_MAX, NN_Q15_MIN); - res02 = CLAMP(res02, NN_Q15_MAX, NN_Q15_MIN); + *dst++ = (int16_t)res00; - dst[0] = (int16_t)res00; - dst[1] = (int16_t)res01; - dst[2] = (int16_t)res02; - dst += 3; + res01 += (int32_t)*dst; + res01 = CLAMP(res01, NN_Q15_MAX, NN_Q15_MIN); + *dst++ = (int16_t)res01; - rhs += 3 * rhs_cols; + res02 += (int32_t)*dst; + res02 = CLAMP(res02, NN_Q15_MAX, NN_Q15_MIN); + *dst++ = (int16_t)res02; } const int loop_cnt = rhs_rows % 3; @@ -311,37 +311,41 @@ void arm_nn_vec_mat_mul_result_acc_s8(const int8_t *lhs_in, for (int i_loop_cnt = 0; i_loop_cnt < loop_cnt; i_loop_cnt++) { const int8_t *lhs_ptr = &lhs[0]; - const int8_t *rhs_ptr = &rhs[0]; + const int8_t *rhs_ptr_0 = &rhs_ptr[0]; - int32_t res00 = *bias++; + int32_t res00 = *effective_bias_ptr++; for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx) { - int32_t rhs_value0 = (int8_t)rhs_ptr[0]; + int32_t rhs_value0 = (int8_t)rhs_ptr_0[0]; int32_t lhs_value = (int8_t)lhs_ptr[0]; res00 += lhs_value * rhs_value0; - ++rhs_ptr; + ++rhs_ptr_0; ++lhs_ptr; } // Quantize down res00 = arm_nn_requantize(res00, dst_multiplier, dst_shift); - // Add offset - res00 += dst_offset + *dst; + // Accumulate + res00 += (int32_t)dst[0]; // Clamp the result res00 = CLAMP(res00, NN_Q15_MAX, NN_Q15_MIN); *dst++ = (int16_t)res00; - rhs += rhs_cols; + rhs_ptr += rhs_cols; } #endif + + lhs += rhs_cols * batch_offset; } + + return ARM_CMSIS_NN_SUCCESS; } /** * @} end of supportLSTM group - */ + */ \ No newline at end of file diff --git a/Tests/UnitTest/CMakeLists.txt b/Tests/UnitTest/CMakeLists.txt index c6c0ef03..d8d8ab7a 100644 --- a/Tests/UnitTest/CMakeLists.txt +++ b/Tests/UnitTest/CMakeLists.txt @@ -97,7 +97,7 @@ add_subdirectory(TestCases/test_arm_fully_connected_s16) add_subdirectory(TestCases/test_arm_fully_connected_s8) 
add_subdirectory(TestCases/test_arm_fully_connected_s4) add_subdirectory(TestCases/test_arm_grouped_convolve_s8) -add_subdirectory(TestCases/test_arm_lstm_unidirectional_s16_s8) +add_subdirectory(TestCases/test_arm_lstm_unidirectional_s8) add_subdirectory(TestCases/test_arm_max_pool_s16) add_subdirectory(TestCases/test_arm_max_pool_s8) add_subdirectory(TestCases/test_arm_softmax_s16) diff --git a/Tests/UnitTest/README.md b/Tests/UnitTest/README.md index 5ccf6ff3..0863bfdb 100644 --- a/Tests/UnitTest/README.md +++ b/Tests/UnitTest/README.md @@ -61,7 +61,7 @@ Python package tflite_runtime can be installed with pip and it can also be built Use the -h flag to get more info on supported interpreters. ##### tflite_micro -This interpreter is partially supported. See this comment for more info: https://github.com/tensorflow/tflite-micro/issues/1484#issuecomment-1677842603. +Python package tflite_micro can be installed with pip and it can also be built locally. See this comment for more info: https://github.com/tensorflow/tflite-micro/issues/1484#issuecomment-1677842603. This interpreter is only partially supported; see *Tests depending on TFLM interpreter*. ## Getting started @@ -126,20 +126,14 @@ When adding a new test data set, new c files should be added or existing c files The steps to add a new unit test are as follows. Add a new test set in the load_all_testdatasets() function. Run the generate script with that new test set as input. Add the new generated header files to an existing or new unit test. -### Tests depending on specific TFL versions, patched TFL version or TFLM interpreter - +### Tests depending on TFLM interpreter #### SVDF INT8 -This tests is depending on tflite_micro for its reference data. This is because the operator is only supported by TFLM. -Note that tflite_micro interpreter is currently only supported for SVDF. +This test depends on tflite_micro for its reference data. This is because the operator is only supported by TFLM. #### LSTM +This test depends on tflite_micro for its reference data. This is because the operator differs between TFLM and TFLite. -The LSTM tests are using the tflite_runtime as interpreter. -See [Using tflite_runtime](https://github.com/ARM-software/CMSIS-NN/blob/main/Tests/UnitTest/README.md#using-tflite_runtime) for more info. -This patch is needed for the tflite_runtime (or tensorflow if using that): -https://github.com/tensorflow/tflite-micro/pull/1253 - Note that this PR is for [TFLM](https://github.com/tensorflow/tflite-micro) so it has to be ported to [TFL](https://github.com/tensorflow/tensorflow) before building the tflite_runtime. -The issue related to this is: https://github.com/tensorflow/tflite-micro/issues/1455 - +Note that the tflite_micro interpreter is currently only supported for SVDF and LSTM. ## Overview of the Folders diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/cell_gate_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/cell_gate_bias_data.h index a9cbb2db..43c4374a 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_1/cell_gate_bias_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_1/cell_gate_bias_data.h @@ -1,6 +1,6 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once #include -const int32_t lstm_1_cell_gate_bias[11] = {6795, 15999, -6130, -12546, 10504, -8988, -30870, -19325, 70, 25310, 4688}; +const int32_t lstm_1_cell_gate_bias[11] = {-16190, 6797, 24062, 29971, -22780, 17656, 14698, 1849, 4054, 14590, -20709}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/cell_norm_coeff_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/cell_norm_coeff_data.h index aec53e24..062bb680 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_1/cell_norm_coeff_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_1/cell_norm_coeff_data.h @@ -1,5 +1,5 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/cell_state_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/cell_state_data.h index 1cbf293a..43589949 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_1/cell_state_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_1/cell_state_data.h @@ -1,5 +1,5 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/cell_to_forget_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/cell_to_forget_data.h index ef4bc138..c7df28f9 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_1/cell_to_forget_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_1/cell_to_forget_data.h @@ -1,5 +1,5 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/cell_to_input_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/cell_to_input_data.h index 5a6c3d29..a27ae43e 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_1/cell_to_input_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_1/cell_to_input_data.h @@ -1,5 +1,5 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/cell_to_output_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/cell_to_output_data.h index 5012f469..30aaf78d 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_1/cell_to_output_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_1/cell_to_output_data.h @@ -1,5 +1,5 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). 
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/config_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/config_data.h index 36400519..28263946 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_1/config_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_1/config_data.h @@ -1,5 +1,5 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #define LSTM_1_BUFFER_SIZE 11 #define LSTM_1_INPUT_BATCHES 1 @@ -10,24 +10,29 @@ #define LSTM_1_TIME_MAJOR 1 #define LSTM_1_IN_ACTIVATION_MIN -32768 #define LSTM_1_IN_ACTIVATION_MAX 32767 -#define LSTM_1_IN_TO_INPUT_MULTIPLIER 1075971584 +#define LSTM_1_IN_TO_INPUT_MULTIPLIER 1075906048 #define LSTM_1_IN_TO_INPUT_SHIFT -2 -#define LSTM_1_IN_TO_FORGET_MULTIPLIER 1083903104 +#define LSTM_1_IN_TO_FORGET_MULTIPLIER 1085883136 #define LSTM_1_IN_TO_FORGET_SHIFT -2 -#define LSTM_1_IN_TO_CELL_MULTIPLIER 1082296832 +#define LSTM_1_IN_TO_CELL_MULTIPLIER 1084231552 #define LSTM_1_IN_TO_CELL_SHIFT -2 -#define LSTM_1_IN_TO_OUTPUT_MULTIPLIER 1080149504 +#define LSTM_1_IN_TO_OUTPUT_MULTIPLIER 1085274240 #define LSTM_1_IN_TO_OUTPUT_SHIFT -2 -#define LSTM_1_RECURRENT_TO_INPUT_MULTIPLIER 1164713600 +#define LSTM_1_RECURRENT_TO_INPUT_MULTIPLIER 1523696256 #define LSTM_1_RECURRENT_TO_INPUT_SHIFT -2 -#define LSTM_1_RECURRENT_TO_FORGET_MULTIPLIER 1155545344 +#define LSTM_1_RECURRENT_TO_FORGET_MULTIPLIER 1511291392 #define LSTM_1_RECURRENT_TO_FORGET_SHIFT -2 -#define LSTM_1_RECURRENT_TO_CELL_MULTIPLIER 1154073088 +#define LSTM_1_RECURRENT_TO_CELL_MULTIPLIER 1523716992 #define LSTM_1_RECURRENT_TO_CELL_SHIFT -2 -#define LSTM_1_RECURRENT_TO_OUTPUT_MULTIPLIER 1082469760 +#define LSTM_1_RECURRENT_TO_OUTPUT_MULTIPLIER 1525092864 #define LSTM_1_RECURRENT_TO_OUTPUT_SHIFT -2 -#define LSTM_1_HIDDEN_MULTIPLIER 1993146898 +#define LSTM_1_FORGET_MULTIPLIER 1073741824 +#define LSTM_1_FORGET_SHIFT -14 +#define LSTM_1_INPUT_MULTIPLIER 1073741824 +#define LSTM_1_INPUT_SHIFT -17 +#define LSTM_1_HIDDEN_MULTIPLIER 1522160019 #define LSTM_1_HIDDEN_SHIFT -22 -#define LSTM_1_HIDDEN_OFFSET 66 -#define LSTM_1_OUTPUT_STATE_OFFSET 66 +#define LSTM_1_HIDDEN_OFFSET -23 +#define LSTM_1_DATA_OFFSET 128 +#define LSTM_1_OUTPUT_STATE_OFFSET -23 #define LSTM_1_CELL_STATE_SHIFT -12 diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/forget_gate_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/forget_gate_bias_data.h index 2edce470..f7adf487 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_1/forget_gate_bias_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_1/forget_gate_bias_data.h @@ -1,7 +1,7 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. 
#pragma once #include const int32_t lstm_1_forget_gate_bias[11] = - {-25627, -3119, 1396, -28731, 20399, -16645, -15073, 19313, 8401, -21968, -27812}; + {-23170, -13466, -6110, 22504, -22652, 25549, -26211, -32267, 15774, -29318, 6943}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/forget_norm_coeff_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/forget_norm_coeff_data.h index 6b56429e..e70544d5 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_1/forget_norm_coeff_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_1/forget_norm_coeff_data.h @@ -1,5 +1,5 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/input_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/input_data.h index 9a71b026..ad595f7e 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_1/input_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_1/input_data.h @@ -1,17 +1,17 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include const int8_t lstm_1_input[220] = { - 72, 126, -20, 19, 79, 71, 8, -72, -127, 118, 91, -27, 52, 121, 74, -4, 64, -14, -15, -105, - -56, 36, -8, 49, -63, -119, -4, -31, -28, -8, -97, 82, -85, 68, -28, 98, -1, -65, 3, -99, - -41, -11, -65, -55, 88, 66, -56, -2, -33, 108, -109, -105, 91, -101, -96, 119, -2, -109, -5, -29, - 24, 43, -38, -128, -85, -114, -116, 12, -41, -117, 109, 104, 34, -49, -54, 96, 30, -28, -72, 3, - 101, -84, -71, 1, 105, -65, -33, -49, 55, -49, 110, -92, 77, -93, 102, 28, 101, -113, -120, -88, - 119, -124, -47, 98, -114, -122, 124, 57, -36, 18, 70, -72, -84, -75, -36, -110, -118, -1, 88, 67, - -99, -115, 70, -16, -85, 49, 89, -104, 42, 86, -103, 20, -50, 113, -88, -81, 57, 45, 0, -33, - 48, -27, 34, 73, 84, -64, 80, -115, 75, 58, -9, 32, 111, -64, -12, -40, 117, -97, -16, -33, - 44, -51, 12, 15, -66, 82, -7, -38, 77, 116, 60, -39, 19, -29, 47, 84, -72, -109, -29, -86, - -30, -69, -78, 90, -40, 104, -36, -110, 26, -93, -24, -4, -18, 61, 65, 77, 6, -40, 89, -92, - 116, -89, 112, -8, -68, 28, 98, -56, 108, 63, 109, -72, -101, -19, -58, 86, -73, -46, -70, -89}; + 121, -117, 50, -102, -68, -103, 9, -120, 48, 53, 25, 69, -77, 15, -34, 32, -70, 29, -118, -58, + -92, -17, -109, 96, -113, -5, 87, 45, -116, 98, 98, -41, 31, 40, 27, 65, -115, 35, 47, -96, + 20, 30, 101, 70, -16, 102, 70, -117, -22, -45, 49, 29, -38, 111, 35, 76, -91, 71, 1, 23, + -69, -32, -85, 13, 39, -32, -123, 69, 24, 110, -1, 22, 112, -79, 73, 125, 68, 53, -16, 100, + 24, 20, -61, 79, 88, 94, 8, -61, 26, 13, 21, 99, -119, -12, -100, -46, 79, 10, 26, 96, + -78, 75, -18, -42, 0, 65, -124, 43, 91, 42, -62, -95, 80, 19, -85, -42, -128, -106, -21, -90, + -75, 66, 94, -10, -122, -55, -8, -126, 51, 8, -21, 71, -93, 67, -23, 88, 3, -50, 16, 124, + -9, -69, 114, 73, 116, 113, 83, -3, -101, 34, -113, -18, -101, 13, -48, 48, 63, 56, -114, 80, + 8, -20, -23, 63, 115, -83, -59, -72, -21, -57, -97, -6, 100, 82, 
118, 83, 126, -111, 36, -114, + 89, -38, 54, 113, 36, -1, -119, -122, -23, -71, -82, 87, 44, 101, -13, -3, -102, -97, 42, 44, + 61, 31, -89, -6, -33, -57, 16, -16, -27, -20, -52, -104, 120, 30, -125, -9, 96, 67, 77, 125}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/input_gate_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/input_gate_bias_data.h index 0f945278..ed45a535 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_1/input_gate_bias_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_1/input_gate_bias_data.h @@ -1,7 +1,6 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include -const int32_t lstm_1_input_gate_bias[11] = - {13255, 22061, -13996, 8755, -14507, 4146, 10874, -20199, -24289, -30146, 28463}; +const int32_t lstm_1_input_gate_bias[11] = {-32410, -104, 21567, -21097, 12535, 259, 8401, 10604, 24974, 30367, -9986}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/input_norm_coeff_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/input_norm_coeff_data.h index f47f1480..0b9ac91c 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_1/input_norm_coeff_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_1/input_norm_coeff_data.h @@ -1,5 +1,5 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_cell_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_cell_eff_bias_data.h index f0f82cbe..a8c39211 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_cell_eff_bias_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_cell_eff_bias_data.h @@ -1,7 +1,7 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include const int32_t lstm_1_input_to_cell_eff_bias[11] = - {-23285, 35199, 3342, -20866, 1928, -43932, 15850, -46717, -20794, 54878, 11600}; + {-70974, 15629, 50174, -12269, -26876, 14328, 56938, -90951, 37206, 49150, 37019}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_cell_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_cell_w_data.h index b2d99b5c..bce86cc3 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_cell_w_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_cell_w_data.h @@ -1,19 +1,18 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). 
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include const int8_t lstm_1_input_to_cell_w[242] = { - 88, 6, 55, 61, -92, 45, 32, -22, -111, 74, -20, 84, 127, -83, 65, -70, -58, -73, -49, - -114, -115, -65, -64, 119, 121, -35, -69, -25, 5, -38, 46, 27, -80, 41, 83, 47, 4, 32, - -95, -56, 72, 49, -100, 66, 68, 98, 45, -124, -126, 32, 45, -51, 105, -16, 32, 60, 23, - 109, -104, -50, 47, -98, -8, 71, -100, 16, 80, 33, -10, -47, -48, 52, -35, 85, 16, 22, - 109, 72, -78, -76, 49, -82, 7, -70, 0, -105, -39, 0, 38, 79, -47, 116, -23, -53, -33, - -113, -98, -48, -3, 55, 16, -62, -3, 41, -60, 39, 19, 88, -110, 95, 103, 80, -4, -16, - -92, -51, -78, -80, 112, 102, -116, -2, -91, -76, 89, 24, 99, -17, -118, 32, -46, -127, -7, - 33, 52, 41, -51, 19, 73, 78, 13, 111, 99, -50, -78, -23, 99, -27, 90, -59, 1, 1, - 27, -77, 76, -100, 14, -12, 112, -70, -56, -33, -14, -79, 46, -81, -41, 37, 97, -41, -98, - 38, -120, 25, -3, 89, -111, -101, 123, -80, -94, 0, 97, -120, 62, 116, -8, 18, -37, -104, - 57, -14, 1, -121, 89, -75, 113, 26, 95, -37, 45, -97, 35, 41, 62, 23, -70, 124, -94, - 81, 125, 127, -51, -76, 17, 49, -80, -78, 76, -86, 103, 34, -69, -26, 50, -55, -66, 56, - -52, -63, 48, -93, 120, 15, -77, 80, 27, 89, -37, -87, -23, 80}; + -112, 13, 117, -90, -36, -7, -41, 19, -90, 44, 121, -52, -58, -77, -18, 126, -63, -105, -45, -53, 88, + -109, -99, -26, -36, 64, -3, -127, -109, 114, 59, 108, -56, 26, 102, -100, -48, -11, -90, 81, 49, 94, + -30, 107, 63, 65, 66, -74, -24, -1, -57, -109, -26, -75, 48, 93, -51, 8, 76, 94, 66, -76, -58, + -15, 98, 93, -95, 45, -23, 104, 60, -13, -35, -73, 10, -8, -51, -43, -100, 55, 101, -53, -85, 12, + -93, 0, 14, -59, -49, -3, 3, -74, -107, -124, -85, 72, -30, 104, -57, 18, 114, -93, -72, 22, 65, + 120, -34, -5, 120, 63, 3, -93, 69, 27, -124, -13, -93, 84, 32, -5, 88, -49, 93, 2, -72, -56, + 101, -30, 21, -1, -64, 54, 108, -86, -12, -88, 102, 39, 80, -12, 39, -1, 108, 10, -28, 31, -62, + 5, -107, 102, 56, -26, 3, 69, 115, -103, -73, -52, -96, 27, 102, -127, -118, -22, 45, 20, -60, -53, + -20, -47, -52, 43, -66, -114, -74, 0, -66, 80, 103, 28, 84, -17, -85, -81, 10, 109, -5, -2, -76, + 17, 85, -17, -51, 92, 76, -21, 7, -11, -87, 48, -27, -9, 124, 15, 120, -123, 64, -71, -59, 32, + -35, 90, 112, 67, 13, 13, 6, -1, 67, -89, 7, 93, 18, 125, 121, 48, -13, -91, 41, 30, 31, + -91, 65, 114, -98, -88, -6, -48, 18, 50, 115, 10}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_forget_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_forget_eff_bias_data.h index 5c84d6b1..aba8cb81 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_forget_eff_bias_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_forget_eff_bias_data.h @@ -1,7 +1,7 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. 
#pragma once #include const int32_t lstm_1_input_to_forget_eff_bias[11] = - {-58523, -69423, 59764, -63035, -12625, -52357, 2463, 52593, -6703, -113488, 114652}; + {19454, 742, -39262, 13544, -61308, 48333, 43165, -115211, -17250, -34822, 33823}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_forget_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_forget_w_data.h index c89123b8..fb6dab68 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_forget_w_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_forget_w_data.h @@ -1,19 +1,19 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include const int8_t lstm_1_input_to_forget_w[242] = { - -98, -98, -102, 30, -76, 42, -40, 98, -37, 29, -43, 60, -37, 26, 114, -37, 100, -85, -62, - 86, -71, -56, -83, -35, 105, 60, 17, -21, 109, -77, -122, -64, -113, -53, 49, -102, -84, -89, - 101, 61, -99, -47, -64, 33, 120, -10, 64, 122, 71, 77, -15, 106, -9, -12, -50, 46, 49, - 98, -126, -11, 81, -121, 52, -119, -5, 48, 17, -62, 20, -125, 39, 63, -126, -109, -102, -81, - -31, 49, 84, 64, -45, 100, -1, -76, 103, -84, 19, 16, -110, 18, 71, -125, 66, -46, 45, - 111, -4, 54, -106, -83, -42, 24, -26, -62, -46, -20, 70, -124, -1, 78, 8, -59, -87, 104, - -54, 53, -38, 11, 104, -5, -113, 36, -62, -15, -117, -39, 48, -7, -44, 61, 22, -86, 50, - 105, 72, -91, 29, -19, 19, 64, 17, -3, -110, -5, 122, -79, 79, 43, -15, -23, -69, -13, - 11, -47, 19, -104, 82, 40, -6, 115, 83, 101, 112, -80, 18, -49, -70, 83, 98, -34, 75, - -80, 91, -42, -113, -79, -103, 56, 25, -49, -106, -65, 123, -90, -127, 121, 27, -40, 82, -40, - -13, -106, 100, 18, -28, -5, 86, 16, 64, 50, -71, -118, 2, -32, -2, 94, -125, -98, 4, - -49, -26, -127, -101, 17, 121, -91, -11, -28, -114, -74, 118, 79, -27, 50, 120, 126, 91, -108, - 102, 15, 88, 37, -7, 41, -90, 118, 84, -1, 68, 22, 88, 99}; + 63, 63, 113, 40, -74, -51, -21, 79, -91, 95, 110, -57, -95, -80, -26, -40, 0, 87, -99, + 102, 93, 122, -96, 62, 14, 40, -90, 14, 125, 3, -2, 50, 113, -37, -109, 24, -22, 38, + 29, -80, 21, -80, 71, 23, -109, 107, 118, -109, -5, -62, 112, -24, 41, -24, -75, 57, -2, + 118, 111, -112, -125, -3, -22, -17, -123, -111, 114, -96, 66, -95, -123, 126, 20, -22, -47, 72, + -118, 50, -122, 18, 99, 70, -126, 3, 96, -74, -60, 79, -40, -115, -84, -83, 90, -77, -113, + 60, 10, -111, 40, 69, -89, -83, -34, 14, 3, -97, 97, 71, 72, 98, 73, 20, -31, -5, + 27, -79, -95, -44, -48, 123, 12, -95, -92, 17, 119, 25, 77, 127, 56, -4, 18, -23, 74, + 73, 12, 82, 69, 105, -106, 6, 51, 21, 32, 28, 116, 123, -118, -124, -47, -12, 66, 65, + 69, -43, -122, -67, 80, -112, -46, -24, -47, -1, -28, -73, -113, 116, -18, -81, -33, -18, 105, + -116, -71, -68, 63, 26, -98, -96, -46, 68, 64, 82, -10, 55, -95, -89, 62, -13, -84, 46, + -23, -75, 89, -41, 53, -21, -51, -35, 56, -116, 13, -74, 73, -75, 32, -55, -81, 74, 91, + 110, -78, 67, 22, 23, -92, 3, 66, -98, 87, -91, 89, -93, 37, 118, 124, 34, -61, 78, + 85, 18, -17, -14, -21, 3, -77, -63, -105, 126, 83, -78, 26, -82}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_input_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_input_eff_bias_data.h index 5d88fdfa..6773e43c 100644 --- 
a/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_input_eff_bias_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_input_eff_bias_data.h @@ -1,7 +1,7 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include const int32_t lstm_1_input_to_input_eff_bias[11] = - {73287, 57517, -29484, -49741, 12885, -64462, 41210, -61287, 48415, -27970, 67375}; + {-123034, -15592, 20415, -22377, 65527, -19069, 31697, -19220, 96526, 111007, 17534}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_input_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_input_w_data.h index acf45cc6..3431c7c9 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_input_w_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_input_w_data.h @@ -1,19 +1,19 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include const int8_t lstm_1_input_to_input_w[242] = { - 69, 103, 18, 41, 114, 100, 107, 9, -38, 70, 15, 41, 54, 51, -106, 37, -25, -71, -101, - 102, -85, -36, 106, -24, 120, 37, -86, 53, 17, 50, -94, -22, 42, 6, 109, -82, -27, -28, - -66, 92, 47, 14, 127, -114, -83, -125, 93, -114, 7, 90, 57, -54, -110, -98, 103, 116, 70, - 41, -57, -29, -80, 89, -102, -115, 56, 124, -52, -19, 63, 107, 101, 26, -26, -88, -6, -109, - -21, 107, -21, 8, -40, -118, 30, -66, -118, 13, -113, -115, -55, -15, -71, 12, 120, 86, 107, - -55, 17, -3, -99, 19, -85, -34, 25, 58, -83, 95, 24, 14, 67, 70, -12, -99, -99, -104, - 40, 38, -28, -123, 44, -90, -95, 25, 117, 36, -44, 87, -28, 39, -102, -106, 90, -122, 39, - -120, -2, 50, 84, -126, 115, 31, -58, 19, 85, 46, 88, -123, 55, 77, -9, -21, 94, 27, - -90, -24, -67, -64, -22, 119, 46, 1, -15, 50, -19, -18, 21, -81, -71, -118, -113, -73, 44, - -26, 48, 22, 78, -63, 107, -81, -79, -34, 70, 82, 94, -23, -46, 104, 38, -3, 36, 103, - -4, 92, 45, -77, 70, 8, -34, 100, 15, -36, 83, -14, -17, 29, 106, 24, 91, 101, -33, - -109, -53, 32, -41, -20, 72, 31, -118, -50, -79, 3, 91, -99, 124, 117, -16, -35, 65, -18, - -115, 71, 33, -43, 74, -98, -37, 26, 124, -73, -19, -44, 66, 110}; + -83, -99, -16, -47, -67, 80, -5, -85, -11, -108, 1, -44, -66, -100, 127, -123, 56, -65, 13, + 93, -124, -35, -93, -67, 36, 71, -50, -75, -48, 2, -97, 5, 45, 91, 75, 2, -83, -112, + 85, -65, 107, 62, 25, -37, -111, -74, -80, -26, -18, 76, -15, 50, -120, 75, 104, -76, 21, + 31, 65, -30, -116, 117, 123, 105, -42, -68, 31, -41, -58, 114, -28, 46, -103, -80, 34, 89, + 5, -90, 11, 86, -9, -102, 90, -54, 27, -29, 15, 36, 26, 16, 27, 64, 3, 51, 69, + 61, -123, 118, 23, -23, 60, 125, -87, -52, -108, 83, 94, 4, 93, -110, -35, -96, 124, 73, + -24, -121, -88, -62, 7, 87, 111, 38, 123, -88, 59, -27, -106, -60, -9, -112, 70, -15, -31, + -5, 88, 124, -44, 78, -81, 39, 78, -67, -113, -17, 81, 10, 8, 2, -8, 6, 40, 5, + -14, 3, 108, 69, 39, 64, 85, 64, 21, 52, -115, -99, -20, -61, -102, -122, 55, 18, -98, + -36, 0, -95, -78, 18, -81, -4, 33, 78, -23, 
30, 124, 41, 71, 106, 101, 59, -81, -46, + -26, 92, 49, -6, 97, -37, -111, 93, 67, 126, 63, -36, 12, -29, 100, 44, 16, 77, -47, + 105, -54, 115, 17, 19, 99, 37, 44, -66, 29, -108, -53, 87, -30, -77, -44, 123, -111, 71, + -42, 34, 14, 91, 12, 118, 74, 64, -106, 106, -1, 6, -102, -19}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_output_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_output_eff_bias_data.h index e3c21161..9ba660c8 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_output_eff_bias_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_output_eff_bias_data.h @@ -1,7 +1,7 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include const int32_t lstm_1_input_to_output_eff_bias[11] = - {49470, -17340, 19595, 64019, -58643, 48766, -43591, -26418, 70460, 25909, -56678}; + {-37894, 15877, 85451, -15991, 37318, 38600, -31336, -20295, -76464, 56669, -2404}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_output_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_output_w_data.h index 3672d460..ef8e04f1 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_output_w_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_output_w_data.h @@ -1,19 +1,19 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. 
#pragma once #include const int8_t lstm_1_input_to_output_w[242] = { - 25, -27, -110, 8, 91, -32, -98, 74, 81, -32, 106, -29, -57, -41, -17, 95, 25, 39, 124, - -45, 61, -41, -37, -23, -125, 116, -110, 49, 28, -27, -34, 79, 9, 108, 27, -14, 82, -113, - -82, -62, 55, -17, -30, -28, 22, 50, 97, 5, -72, 27, -36, 29, 20, -69, -79, -13, 46, - 89, 56, 93, 42, -35, -45, 37, -58, -113, -94, 63, 76, -119, 58, 24, 127, 103, -88, -82, - -37, 98, 104, 67, 18, 113, -65, -26, 39, 43, 127, -42, -88, 65, -62, -107, 64, -22, 23, - -40, 93, 107, -3, -92, -10, -22, -69, 122, -51, 10, 0, 19, -102, -106, -93, -17, 23, 41, - 59, 29, 124, 79, 74, 63, 114, 96, -85, 42, 112, -78, -6, -115, 9, -127, -80, -123, 60, - 95, -42, -42, -114, -113, -100, -112, -28, 8, 48, -33, -48, 111, 55, 3, -80, 4, 42, -19, - 50, 93, 39, 8, 94, 115, 12, -99, -118, -95, -23, -42, 30, 32, -58, -101, 69, -115, -1, - -72, -122, 57, -11, 101, 60, -99, 28, 26, 84, -33, -126, -9, -104, -59, -29, 122, -30, 120, - -33, 117, 27, 59, 104, 55, -21, 66, 67, -64, 52, 78, -120, -29, -88, -74, -6, 44, -101, - -125, 123, -4, 14, 88, -17, 103, -6, 2, 11, 37, -22, -84, -125, 67, -121, 54, 23, -51, - 113, -83, 15, -67, 10, 109, -88, 43, 32, -28, -79, 84, -74, 74}; + -59, 85, -25, 59, -52, -15, 21, -113, -114, -7, -16, -78, -58, -59, -20, -46, 24, 4, -69, + 56, 25, -13, 107, 48, -102, 24, -41, -75, 97, 70, 44, 17, -116, 54, -101, 10, 8, -36, + -66, -44, -111, 114, 22, 70, 62, 48, -16, -6, 67, 121, 55, -72, 120, 58, 80, -105, -98, + 51, 93, -115, 110, 79, -4, -18, -51, 11, 117, -62, -4, 93, 23, 32, -60, 42, -121, 54, + 54, -83, 63, 55, -35, -8, -76, -114, 24, -108, 81, 45, 44, -15, 100, 52, -22, 25, 23, + -48, -38, -30, -127, 69, 18, -4, 84, 97, 16, 61, -52, 20, 70, 39, 69, -78, 78, -43, + -110, 95, 94, -50, 77, -5, -78, 21, 23, -32, 46, 40, -105, 17, 65, 40, 12, 0, 30, + -88, 119, 123, -120, -118, 21, 33, -58, 73, -117, -52, -10, -93, 72, -16, 49, -114, -94, 10, + 15, -78, -120, 10, -90, -9, -94, -28, -76, 80, -91, -63, -10, 78, -24, 104, 71, -17, 27, + 122, -5, 46, 27, -120, -21, -51, -29, -63, -67, 96, -115, 24, -117, -127, -52, 59, 73, 3, + 49, 33, -47, -20, -12, -16, 28, -4, -79, -94, -96, 123, -120, -57, 13, -18, 28, -48, 68, + 102, 24, 114, -118, 124, 60, -41, 97, 78, 11, 40, 38, 120, -91, -86, 9, -94, 120, -60, + 98, -119, -50, 49, -43, -82, -6, -38, 113, -98, -25, 113, -27, -101}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/output_gate_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/output_gate_bias_data.h index f0cc061f..324e2dbc 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_1/output_gate_bias_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_1/output_gate_bias_data.h @@ -1,7 +1,7 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. 
#pragma once #include const int32_t lstm_1_output_gate_bias[11] = - {23870, 1732, 7691, -877, -23955, 30718, -22855, 11982, 28860, 27829, -31334}; + {22266, 16773, 25291, -17527, -11578, 16072, 21528, 3001, -28336, 29661, 30876}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/output_norm_coeff_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/output_norm_coeff_data.h index 4fd42944..58db3fa0 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_1/output_norm_coeff_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_1/output_norm_coeff_data.h @@ -1,5 +1,5 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/output_ref_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/output_ref_data.h index a01813a8..d0b3ac83 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_1/output_ref_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_1/output_ref_data.h @@ -1,11 +1,11 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include const int8_t lstm_1_output_ref[110] = { - 127, 124, 81, 62, 72, 51, 115, 63, -83, 119, 79, 115, 127, 94, 106, 54, 68, 104, 63, -77, 116, 85, - 127, 127, 94, 127, 67, 108, 73, 57, -35, 120, 103, 48, 127, 44, 97, 16, 64, 89, 62, 85, 101, 74, - -50, 88, 23, 60, 40, 47, 69, 55, 104, 99, 90, -101, 95, 33, 59, 71, 59, 76, 41, -54, 121, 127, - -106, 109, 21, 41, 50, 52, 92, 40, -34, 78, 88, -84, 115, -7, 41, 58, 55, 117, 39, 109, 104, 89, - -103, 115, 1, 56, 36, 55, 90, 44, 62, 107, 87, -68, 87, -44, 29, 47, 25, 76, 29, -41, 108, 89}; + -24, -26, 42, -30, -91, 13, 9, -29, -16, 74, -38, -24, 37, 46, -38, -25, -10, -9, -37, -12, 110, 47, + -25, 52, 113, -35, -15, -22, 36, -41, -16, 101, 81, -26, 78, 119, -27, -83, -54, 14, -31, -19, 127, 79, + -27, 98, 118, -25, 66, -48, 61, -36, -9, 127, 85, -29, 100, 104, -71, 36, -24, 40, -28, 21, 121, 60, + -26, 113, 127, -39, -50, -33, 15, -47, -3, 120, 52, -31, 118, 127, -59, 34, -65, 42, -42, -11, 100, 53, + -26, 84, 67, -47, 36, -65, 29, -40, -16, 58, 101, -29, 85, 116, -73, -12, -115, 32, -41, -1, 104, 56}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/output_state_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/output_state_data.h index 3855269c..8c1d1a91 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_1/output_state_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_1/output_state_data.h @@ -1,6 +1,6 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. 
#pragma once #include -int8_t lstm_1_output_state[11] = {66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66}; +int8_t lstm_1_output_state[11] = {-23, -23, -23, -23, -23, -23, -23, -23, -23, -23, -23}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/projection_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/projection_bias_data.h index 15335fa2..2cdf2f40 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_1/projection_bias_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_1/projection_bias_data.h @@ -1,5 +1,5 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/projection_weights_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/projection_weights_data.h index c8b1358f..daa15c33 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_1/projection_weights_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_1/projection_weights_data.h @@ -1,5 +1,5 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_input_to_cell_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_input_to_cell_w_data.h index a50291e8..4ad4495f 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_input_to_cell_w_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_input_to_cell_w_data.h @@ -1,13 +1,12 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. 
#pragma once #include const int8_t lstm_1_recurrent_input_to_cell_w[121] = { - 111, 31, -22, 113, 102, -33, -15, 117, 58, 88, 6, -12, -19, 102, 37, 80, -109, -9, - -83, 93, 110, -45, 57, -122, 0, -120, 57, 61, -102, -47, -7, -101, -126, 93, -13, 117, - -51, -72, -33, 76, 92, -105, 0, -92, -56, -81, 37, 21, -124, 34, -78, -120, 30, -70, - 53, 113, 57, -12, -11, 95, 117, 56, 47, -68, 110, 41, -72, -32, -85, 122, -17, -49, - 111, -8, -55, -107, 40, -111, -94, 40, 111, 72, -16, 39, -123, -119, 60, 118, 62, -55, - -100, -60, 6, -74, 113, -80, -72, -88, -18, -27, 11, 62, -99, 127, -54, 12, 41, 13, - -76, -19, 7, 62, -29, -11, -94, -74, -53, -110, -25, -32, 74}; + 119, -7, -15, -96, -39, 87, -17, 17, 77, -113, 50, -26, 91, 53, 14, 56, 84, 73, -61, 118, -9, + -50, -72, 30, -72, -18, 41, 13, 14, 41, 26, 5, -40, -15, 94, -16, 80, -11, -66, -30, 25, -115, + -51, -93, -5, 41, 96, 97, -52, -101, 54, 16, -73, 96, -55, -63, -127, -9, -83, -114, 7, 32, 70, + -17, -70, -40, 116, -10, -9, 20, -36, -14, 40, 25, 127, -63, -18, 74, -51, -125, -53, 119, -83, 75, + 0, -77, -60, -10, 74, -54, -37, 92, -7, -105, -38, 52, 23, 20, 52, 39, -74, 50, -4, 62, -87, + 64, -62, 72, 37, -68, -14, 92, -111, -80, 24, 18, -86, -59, -32, 125, 15}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_input_to_forget_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_input_to_forget_w_data.h index 05bcd09b..f6ad4f00 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_input_to_forget_w_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_input_to_forget_w_data.h @@ -1,12 +1,12 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. 
#pragma once #include const int8_t lstm_1_recurrent_input_to_forget_w[121] = { - -68, 67, 106, -51, 107, -86, -60, -25, -67, -44, -127, -1, 60, 73, -127, 91, -20, 95, -80, -108, 55, - -40, -31, -73, 41, -101, -14, 85, 113, 0, 15, -71, -5, -64, 111, -98, -66, 109, 6, 44, -103, 6, - 4, 104, 109, -29, 118, -55, 22, -35, 83, 6, -4, -65, -14, 102, 91, 121, 57, -120, -55, -60, -33, - -79, 15, -103, 53, -122, -78, -54, -8, 115, -67, 27, -1, 74, -114, -22, -82, 28, -18, 32, -21, 84, - -66, 14, -82, 53, -25, -26, -126, -73, -46, -86, 16, -118, 22, -42, 57, -90, 121, -80, 109, -55, -75, - -94, -85, 80, 121, 45, -33, 9, 58, -95, 29, 73, -65, 33, 10, 66, 36}; + 110, -108, -7, -84, 96, 93, -39, -46, -107, 73, 66, -74, 107, 49, 14, 27, 37, -74, -97, -69, 22, + 78, 47, 124, -103, -74, -120, -107, 71, 3, 81, 127, 86, 69, -68, 7, -119, -67, -9, 8, 60, 78, + 33, -59, 26, 64, 13, -51, 7, 0, 118, 37, 123, -104, -71, -18, 113, -107, -40, 33, 23, -42, -40, + 74, -57, 50, -22, -45, -120, 68, -11, -6, 36, 103, -73, 51, -3, -41, -22, -92, 92, -20, 122, -127, + -54, 73, 6, -47, -75, -74, 118, -19, -12, -97, 42, 31, -59, -45, -36, -108, 41, 91, -49, -1, -9, + 79, 68, 36, 13, -66, 102, 126, -3, -17, 58, -87, 120, 6, 68, 95, 65}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_input_to_input_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_input_to_input_w_data.h index 86a2a835..ba796abf 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_input_to_input_w_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_input_to_input_w_data.h @@ -1,12 +1,12 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. 
#pragma once #include const int8_t lstm_1_recurrent_input_to_input_w[121] = { - -75, -104, 8, 76, -10, -67, -120, -91, -15, 110, 48, 103, 85, -29, 34, -111, 98, 70, -67, 24, 72, - -101, -87, 26, -61, 122, 14, 52, 99, 113, 127, 77, 6, 118, -88, -18, -94, 14, -38, 91, -2, -72, - -86, 81, -122, 109, -19, -61, -11, -14, 116, 78, 4, 125, 73, 56, -66, -49, -23, 100, -102, 27, 75, - 118, -124, -113, -16, 60, -30, -96, 47, 119, -63, 104, -123, -91, -97, -39, -127, -25, 1, 44, 8, -21, - -25, 4, -52, -82, 107, -60, -125, -82, 64, 101, 24, -67, -27, -52, 119, 97, 26, -69, 92, 77, 16, - -15, -11, -34, -8, -94, 6, -32, -60, -94, -50, 125, -2, -73, -118, -33, 96}; + 98, -54, 29, 95, -124, -26, 17, 6, 17, 65, -105, 24, -101, -115, 49, 15, -32, 100, 51, -54, 125, + -43, -47, 39, -76, -114, 61, -106, 38, -115, 7, 38, 111, 1, -28, 117, 93, 121, -103, -37, -88, 121, + 95, 65, 50, 121, -66, 28, -125, -30, -19, -15, -70, -59, -119, -107, -43, 122, -68, -123, -101, -22, 127, + 81, -63, -126, -96, 74, -111, -102, -127, 61, -12, -60, 91, -109, 24, 3, -86, 6, -19, -51, -48, -113, + -50, 16, -5, -31, -76, -14, -24, 47, 112, -6, -19, 54, 51, 21, 122, 117, 118, -9, 11, 121, -84, + -74, 86, -90, 66, 51, -19, 10, 94, -60, 75, 57, -84, -30, -22, -3, 55}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_input_to_output_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_input_to_output_w_data.h index d0a3c658..62b117c4 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_input_to_output_w_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_input_to_output_w_data.h @@ -1,13 +1,12 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. 
#pragma once #include const int8_t lstm_1_recurrent_input_to_output_w[121] = { - -110, 10, -9, -112, -13, 8, 110, 122, 49, 124, -29, -20, -8, -66, -66, -60, -71, 10, - 25, -107, -63, -43, -74, -41, 27, 127, 42, 9, -6, 97, -33, 71, -86, 115, 88, 32, - -75, -117, 70, -108, 54, 42, 92, -73, -79, -123, -7, 58, -44, 77, -11, -84, -102, -8, - 77, 42, -21, -100, 64, 43, 115, -42, 117, 66, 1, 100, -21, -122, -45, 57, 9, -36, - -28, -94, -94, -90, -118, -123, -30, 42, 21, -56, -7, -125, -79, -16, 114, -68, 90, 69, - -99, -43, 63, 31, -38, -22, 43, -52, -33, 37, -24, -81, 4, 88, 124, 57, 76, -73, - 80, 22, -10, 10, 3, 32, -37, 43, -104, -113, -75, 81, -66}; + 9, -100, 91, -34, 118, -95, 69, 47, 8, 79, 67, 6, 66, 122, -52, -3, 11, 57, -93, -70, 6, + 48, 7, 90, 44, -18, 74, -114, 57, -5, -110, 76, -117, -121, 59, -113, 68, 94, -64, 75, -81, 17, + -85, -65, -85, -13, -40, -5, -121, 26, -29, 54, -113, 37, -6, 127, -67, -12, 127, -23, -85, -32, 38, + 23, -123, 58, 108, -80, 68, -120, -101, 114, -107, 4, 84, 40, 85, -3, 123, 104, 125, -55, 91, 34, + -120, -80, -106, 9, -110, 48, 49, -36, 53, -94, 23, -23, -65, -26, -58, 82, -21, -20, 94, -31, 86, + 12, -89, 114, -62, 126, 27, 8, 45, -54, -122, 95, 7, -83, 19, 24, 44}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_to_cell_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_to_cell_eff_bias_data.h index 8dcc2b5e..06b79d01 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_to_cell_eff_bias_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_to_cell_eff_bias_data.h @@ -1,7 +1,7 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include const int32_t lstm_1_recurrent_to_cell_eff_bias[11] = - {-36696, -9570, 29700, -792, 23364, -35970, 10032, 1518, 24156, 594, 18810}; + {1449, 7889, -736, -4554, 2622, -9522, 4094, -4393, 1656, 667, -2484}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_to_forget_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_to_forget_eff_bias_data.h index 26a79590..599b9096 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_to_forget_eff_bias_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_to_forget_eff_bias_data.h @@ -1,7 +1,7 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. 
#pragma once #include const int32_t lstm_1_recurrent_to_forget_eff_bias[11] = - {16368, 132, 2706, -3498, -8976, 4224, 11550, 5280, 29502, 198, -7986}; + {1081, 460, 3105, -1541, 3726, -253, -506, -2530, -5198, 2185, 12259}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_to_input_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_to_input_eff_bias_data.h index 0d542ac8..3de075dc 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_to_input_eff_bias_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_to_input_eff_bias_data.h @@ -1,7 +1,7 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include const int32_t lstm_1_recurrent_to_input_eff_bias[11] = - {15840, -11748, -32208, 6204, -18348, 6666, 12276, 20724, -132, -5082, 15510}; + {414, 437, -3772, 8211, -6992, -7429, -8441, -8694, 6164, 7199, 1679}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_to_output_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_to_output_eff_bias_data.h index 2e395eff..2f4d8378 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_to_output_eff_bias_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_to_output_eff_bias_data.h @@ -1,7 +1,7 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include const int32_t lstm_1_recurrent_to_output_eff_bias[11] = - {-9900, 30954, -8778, -7920, 16236, -25410, 38412, 21582, -594, -20460, 15576}; + {5957, 2254, -368, -4968, -6785, 713, 2185, 2806, -5497, 6693, 230}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/test_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/test_data.h index ecd25010..f1ab6ed2 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_1/test_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_1/test_data.h @@ -1,5 +1,5 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #include "cell_gate_bias_data.h" #include "cell_norm_coeff_data.h" #include "cell_state_data.h" diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/cell_gate_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/cell_gate_bias_data.h index 16a9bd84..7c52e627 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_2/cell_gate_bias_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_2/cell_gate_bias_data.h @@ -1,6 +1,6 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). 
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include -const int32_t lstm_2_cell_gate_bias[7] = {-18223, -7792, 2026, 24820, 5477, -24022, -32276}; +const int32_t lstm_2_cell_gate_bias[7] = {16402, 13081, 21684, 28843, 19606, 13836, -23310}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/cell_norm_coeff_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/cell_norm_coeff_data.h index 84fb2981..711d66c2 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_2/cell_norm_coeff_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_2/cell_norm_coeff_data.h @@ -1,5 +1,5 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/cell_state_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/cell_state_data.h index 25d429c2..1f6e2166 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_2/cell_state_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_2/cell_state_data.h @@ -1,6 +1,6 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include -int16_t lstm_2_cell_state[14] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; +int16_t lstm_2_cell_state[7] = {0, 0, 0, 0, 0, 0, 0}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/cell_to_forget_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/cell_to_forget_data.h index 62963262..6341b7a2 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_2/cell_to_forget_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_2/cell_to_forget_data.h @@ -1,5 +1,5 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/cell_to_input_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/cell_to_input_data.h index 9abfc468..0500eabb 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_2/cell_to_input_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_2/cell_to_input_data.h @@ -1,5 +1,5 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. 
#pragma once #include diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/cell_to_output_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/cell_to_output_data.h index b5d67ed2..2f7d5c88 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_2/cell_to_output_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_2/cell_to_output_data.h @@ -1,5 +1,5 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/config_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/config_data.h index 2d125a88..da8a0fb4 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_2/config_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_2/config_data.h @@ -1,33 +1,38 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once -#define LSTM_2_BUFFER_SIZE 14 -#define LSTM_2_INPUT_BATCHES 2 -#define LSTM_2_DST_SIZE 126 +#define LSTM_2_BUFFER_SIZE 7 +#define LSTM_2_INPUT_BATCHES 1 +#define LSTM_2_DST_SIZE 63 #define LSTM_2_TIME_STEPS 9 #define LSTM_2_NUMBER_UNITS 7 #define LSTM_2_NUMBER_INPUTS 6 #define LSTM_2_TIME_MAJOR 0 #define LSTM_2_IN_ACTIVATION_MIN -32768 #define LSTM_2_IN_ACTIVATION_MAX 32767 -#define LSTM_2_IN_TO_INPUT_MULTIPLIER 2143970816 +#define LSTM_2_IN_TO_INPUT_MULTIPLIER 2146016128 #define LSTM_2_IN_TO_INPUT_SHIFT -3 -#define LSTM_2_IN_TO_FORGET_MULTIPLIER 2037819904 +#define LSTM_2_IN_TO_FORGET_MULTIPLIER 2131870080 #define LSTM_2_IN_TO_FORGET_SHIFT -3 -#define LSTM_2_IN_TO_CELL_MULTIPLIER 2115009024 -#define LSTM_2_IN_TO_CELL_SHIFT -3 -#define LSTM_2_IN_TO_OUTPUT_MULTIPLIER 2088505856 -#define LSTM_2_IN_TO_OUTPUT_SHIFT -3 -#define LSTM_2_RECURRENT_TO_INPUT_MULTIPLIER 1470197120 +#define LSTM_2_IN_TO_CELL_MULTIPLIER 1079288192 +#define LSTM_2_IN_TO_CELL_SHIFT -2 +#define LSTM_2_IN_TO_OUTPUT_MULTIPLIER 1077017984 +#define LSTM_2_IN_TO_OUTPUT_SHIFT -2 +#define LSTM_2_RECURRENT_TO_INPUT_MULTIPLIER 1351661952 #define LSTM_2_RECURRENT_TO_INPUT_SHIFT -3 -#define LSTM_2_RECURRENT_TO_FORGET_MULTIPLIER 1504444160 +#define LSTM_2_RECURRENT_TO_FORGET_MULTIPLIER 1367683200 #define LSTM_2_RECURRENT_TO_FORGET_SHIFT -3 -#define LSTM_2_RECURRENT_TO_CELL_MULTIPLIER 1454177792 +#define LSTM_2_RECURRENT_TO_CELL_MULTIPLIER 1374072192 #define LSTM_2_RECURRENT_TO_CELL_SHIFT -3 -#define LSTM_2_RECURRENT_TO_OUTPUT_MULTIPLIER 1455750656 +#define LSTM_2_RECURRENT_TO_OUTPUT_MULTIPLIER 1379062144 #define LSTM_2_RECURRENT_TO_OUTPUT_SHIFT -3 -#define LSTM_2_HIDDEN_MULTIPLIER 1537938273 +#define LSTM_2_FORGET_MULTIPLIER 1073741824 +#define LSTM_2_FORGET_SHIFT -14 +#define LSTM_2_INPUT_MULTIPLIER 1073741824 +#define LSTM_2_INPUT_SHIFT -17 +#define LSTM_2_HIDDEN_MULTIPLIER 1667670651 #define LSTM_2_HIDDEN_SHIFT -21 -#define LSTM_2_HIDDEN_OFFSET 74 -#define LSTM_2_OUTPUT_STATE_OFFSET 74 -#define LSTM_2_CELL_STATE_SHIFT -13 +#define LSTM_2_HIDDEN_OFFSET -24 +#define LSTM_2_DATA_OFFSET 128 +#define LSTM_2_OUTPUT_STATE_OFFSET -24 +#define 
LSTM_2_CELL_STATE_SHIFT -12 diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/forget_gate_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/forget_gate_bias_data.h index 020999f7..69bd9b74 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_2/forget_gate_bias_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_2/forget_gate_bias_data.h @@ -1,6 +1,6 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include -const int32_t lstm_2_forget_gate_bias[7] = {-15296, 3060, 33125, -4714, -14467, 22732, -4683}; +const int32_t lstm_2_forget_gate_bias[7] = {-25576, -22402, 15452, -12818, 31310, 27442, 4985}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/forget_norm_coeff_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/forget_norm_coeff_data.h index 794b3962..558c2b81 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_2/forget_norm_coeff_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_2/forget_norm_coeff_data.h @@ -1,5 +1,5 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/input_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/input_data.h index 5c1bb385..8ab020fa 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_2/input_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_2/input_data.h @@ -1,12 +1,9 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. 
#pragma once #include -const int8_t lstm_2_input[108] = {-123, -88, 71, -16, -31, -9, 3, 71, 41, -43, 69, -51, -74, 31, -127, -3, - 63, 32, -23, -75, -95, 80, -96, -113, -4, -13, -6, -94, 6, 19, -80, 3, - 85, 73, -43, -17, 96, 67, 7, 31, -128, 117, -113, 24, -90, 113, -36, -13, - 108, -54, 115, -106, -3, -58, 122, 68, -52, -2, -46, -2, 9, 3, -122, 71, - 40, 111, 103, 2, 120, -53, 26, 125, 41, -94, 107, 51, 74, 7, 15, -27, - 9, 33, -117, -41, 50, 106, 66, -62, 117, -93, -107, -47, -96, 124, -6, -90, - 7, -3, 13, -18, 99, -56, -37, 104, 66, 99, -1, -32}; +const int8_t lstm_2_input[54] = {-77, 86, 115, -22, 114, -93, -95, 61, 100, 107, 33, -79, -11, -63, + -73, -90, -57, 40, 8, 23, 126, 82, 123, 105, 2, -102, -62, -78, + 111, 1, 49, 16, -77, 84, -4, 64, 27, 103, 102, 62, -101, -93, + -115, -103, -4, -35, -26, 36, 17, -78, 23, -106, 74, 4}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/input_gate_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/input_gate_bias_data.h index 9b4b6854..72a5db25 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_2/input_gate_bias_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_2/input_gate_bias_data.h @@ -1,6 +1,6 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include -const int32_t lstm_2_input_gate_bias[7] = {26212, 18263, 849, 17710, 22563, -17022, 4727}; +const int32_t lstm_2_input_gate_bias[7] = {-475, 24405, 15813, -18608, 27695, -32747, 5436}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/input_norm_coeff_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/input_norm_coeff_data.h index c61ba1fb..868049cb 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_2/input_norm_coeff_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_2/input_norm_coeff_data.h @@ -1,5 +1,5 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_cell_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_cell_eff_bias_data.h index 11ad38e5..abc8cbc9 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_cell_eff_bias_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_cell_eff_bias_data.h @@ -1,6 +1,6 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. 
#pragma once #include -const int32_t lstm_2_input_to_cell_eff_bias[7] = {-24495, 6800, -29846, 4212, -18843, -18518, -34196}; +const int32_t lstm_2_input_to_cell_eff_bias[7] = {57106, -25319, -15180, 23851, 27030, 48140, -10254}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_cell_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_cell_w_data.h index 466f6884..7b8dd0b6 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_cell_w_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_cell_w_data.h @@ -1,8 +1,8 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include -const int8_t lstm_2_input_to_cell_w[42] = {-37, -62, 88, 15, -84, 31, 14, -46, 97, 33, 86, -70, -112, -75, - -45, -13, -22, 18, 26, -10, -117, -82, -45, 67, 106, 60, -127, -33, - -118, -78, -90, 51, 10, -23, 43, 52, 23, -64, 79, 75, -121, -7}; +const int8_t lstm_2_input_to_cell_w[42] = {68, 111, 95, 35, -39, 48, -62, -74, -91, -28, -32, -13, -14, -92, + -54, -121, -9, 2, -70, 15, 66, -6, 26, -70, -52, -50, 67, 33, + 3, 57, 76, 127, 110, -70, 74, -49, 35, -42, 56, -22, 13, 62}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_forget_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_forget_eff_bias_data.h index b17afd3b..f4e40fad 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_forget_eff_bias_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_forget_eff_bias_data.h @@ -1,6 +1,6 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include -const int32_t lstm_2_input_to_forget_eff_bias[7] = {-19136, 14836, 19173, 16790, 12029, -6452, 14645}; +const int32_t lstm_2_input_to_forget_eff_bias[7] = {-34024, -60418, 47452, -16274, 15950, 23218, 30457}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_forget_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_forget_w_data.h index 91613c16..a072ca22 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_forget_w_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_forget_w_data.h @@ -1,8 +1,8 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. 
#pragma once #include -const int8_t lstm_2_input_to_forget_w[42] = {127, -100, -44, 35, -24, -24, -26, 27, -21, -44, 116, 40, -43, 70, - -101, -126, 78, 13, -31, -40, 121, -67, 108, 77, -56, 93, 10, 74, - 56, 30, -104, 47, -14, 60, -120, -97, 19, -63, 84, 111, -53, 53}; +const int8_t lstm_2_input_to_forget_w[42] = { + -88, -23, 72, -119, -3, 95, -112, 13, -29, -78, 28, -119, 37, 45, 10, 67, 127, -36, 23, -116, 61, + -51, -54, 110, 55, 28, -116, 111, -109, -89, 106, 10, 2, -119, 89, -121, -33, 70, -105, 123, 45, 99}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_input_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_input_eff_bias_data.h index f4d218a2..c0d06764 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_input_eff_bias_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_input_eff_bias_data.h @@ -1,6 +1,6 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include -const int32_t lstm_2_input_to_input_eff_bias[7] = {50916, 22359, -33583, -2258, 24355, -7934, 11255}; +const int32_t lstm_2_input_to_input_eff_bias[7] = {-219, 25045, 20293, -41904, 43567, -33259, 29756}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_input_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_input_w_data.h index 2062c9ee..f0de95cf 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_input_w_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_input_w_data.h @@ -1,8 +1,8 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include -const int8_t lstm_2_input_to_input_w[42] = {92, 7, -36, 85, 13, 32, -115, 84, -64, 124, 86, -83, 6, 127, - -106, -116, -80, -100, -124, -2, 46, 54, -56, -74, -20, 83, -113, -59, - 36, 87, -86, -31, -30, 5, 122, 91, -73, 96, 11, -7, 101, -77}; +const int8_t lstm_2_input_to_input_w[42] = {-105, -10, -45, 110, -23, 75, 17, -104, -35, 70, -66, 123, -113, 126, + 39, -127, -4, 114, 110, 0, -94, -24, -63, -111, -59, -13, 102, 116, + 7, -29, 40, 73, -113, -5, -85, 86, 124, 86, 125, 51, -92, -104}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_output_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_output_eff_bias_data.h index d6401012..442871c2 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_output_eff_bias_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_output_eff_bias_data.h @@ -1,6 +1,6 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. 
#pragma once #include -const int32_t lstm_2_input_to_output_eff_bias[7] = {-23260, -64953, 21616, 44956, -55885, -8659, -2372}; +const int32_t lstm_2_input_to_output_eff_bias[7] = {15991, -35409, -4890, -55904, -28433, -10231, -19317}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_output_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_output_w_data.h index aa7a60a4..bda0d883 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_output_w_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_output_w_data.h @@ -1,8 +1,8 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include -const int8_t lstm_2_input_to_output_w[42] = {-13, -107, 95, -90, 57, -71, -122, -38, -121, -44, -68, -69, 109, -62, - 126, -53, 43, -26, -101, 123, 50, 82, -46, 39, -81, 41, -58, -52, - -101, 40, 13, -45, -15, -90, 19, -127, -6, -21, -38, 14, 25, -44}; +const int8_t lstm_2_input_to_output_w[42] = {-6, 11, 12, -127, -29, 19, -90, 49, 106, 71, -81, -124, 16, 1, + -121, -66, -79, -7, -102, -86, -40, 124, -88, 0, 82, -108, -7, 4, + 9, -75, 50, -120, -28, -63, 127, -56, -14, 41, 35, -98, -31, -105}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/output_gate_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/output_gate_bias_data.h index 2c9a5e45..e9794eb7 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_2/output_gate_bias_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_2/output_gate_bias_data.h @@ -1,6 +1,6 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include -const int32_t lstm_2_output_gate_bias[7] = {-6748, -5817, 4080, 26140, -28877, 22701, 6588}; +const int32_t lstm_2_output_gate_bias[7] = {31351, -26577, 27878, -31328, -16273, 1289, 2699}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/output_norm_coeff_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/output_norm_coeff_data.h index 84c4b276..7ad90609 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_2/output_norm_coeff_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_2/output_norm_coeff_data.h @@ -1,5 +1,5 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/output_ref_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/output_ref_data.h index 9184badf..8b6351db 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_2/output_ref_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_2/output_ref_data.h @@ -1,12 +1,9 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). 
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include -const int8_t lstm_2_output_ref[126] = { - 51, 90, 58, 50, 43, 44, 16, -6, 83, 7, 78, 50, 40, -40, 6, 50, 37, 127, 38, 49, -77, - 19, 45, 36, 127, 56, 31, -51, -1, 60, 19, 127, 46, 33, -55, 32, 69, 41, 127, 48, 45, -44, - 47, 65, 20, 127, 50, 53, -3, 32, 48, 46, 127, 39, 43, -51, -6, 75, 6, 127, 59, 27, -33, - 23, 72, -6, 117, 87, 40, 7, 16, 58, 28, 127, 52, 50, -52, 4, 72, 25, 127, 54, 46, -33, - -15, 81, 46, 102, 61, 34, -44, 20, 66, 34, 127, 56, 39, -16, -28, 81, -11, 105, 59, 35, -73, - -1, 56, 48, 127, 49, 36, -95, -19, 75, 27, 127, 56, 26, -90, 18, 71, 30, 121, 51, 50, -77}; +const int8_t lstm_2_output_ref[63] = { + 71, -80, -81, -19, 29, 0, -82, 101, -122, -107, -10, 65, 12, -83, 127, -39, -128, -7, 83, 33, -48, + 127, -92, -94, -13, 53, 40, -7, 115, -49, -128, -10, 86, 81, 6, 127, -76, -128, -6, 55, 44, 1, + 120, -128, -128, -3, 64, 30, -9, 127, -53, -128, 6, 80, 23, -2, 127, -63, -128, -10, 74, 61, 15}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/output_state_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/output_state_data.h index b09faec6..c83812e9 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_2/output_state_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_2/output_state_data.h @@ -1,6 +1,6 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include -int8_t lstm_2_output_state[14] = {74, 74, 74, 74, 74, 74, 74, 74, 74, 74, 74, 74, 74, 74}; +int8_t lstm_2_output_state[7] = {-24, -24, -24, -24, -24, -24, -24}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/projection_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/projection_bias_data.h index 17722140..c6d005b4 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_2/projection_bias_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_2/projection_bias_data.h @@ -1,5 +1,5 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/projection_weights_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/projection_weights_data.h index 6bc42d53..b01c0ffe 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_2/projection_weights_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_2/projection_weights_data.h @@ -1,5 +1,5 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. 
#pragma once #include diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_input_to_cell_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_input_to_cell_w_data.h index 48195152..9b9330bf 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_input_to_cell_w_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_input_to_cell_w_data.h @@ -1,9 +1,9 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include -const int8_t lstm_2_recurrent_input_to_cell_w[49] = {111, -117, 121, -81, -85, -44, -107, 118, 83, 51, 69, 16, 110, - 60, -36, 114, -55, -87, -60, -115, -8, -98, -91, -45, -94, -10, - 49, -127, 17, 112, 126, -21, 86, -80, 82, 44, 90, -20, -38, - 55, -67, 39, 125, 104, 80, 102, -110, -4, -5}; +const int8_t lstm_2_recurrent_input_to_cell_w[49] = { + -113, -40, -36, -21, -5, 30, 42, -66, -108, -52, -106, 39, -87, -24, -57, -52, 55, + 83, 16, 8, -92, 53, -119, 83, -60, 127, -90, 39, -87, 45, 91, -6, -33, 79, + -26, -82, 102, 97, -99, -97, 84, -55, 31, -42, -63, -94, 62, 83, 19}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_input_to_forget_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_input_to_forget_w_data.h index 8a54759a..f309730f 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_input_to_forget_w_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_input_to_forget_w_data.h @@ -1,9 +1,9 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include const int8_t lstm_2_recurrent_input_to_forget_w[49] = { - 46, 72, -126, -19, -98, -127, -83, 6, 74, 110, 109, 69, -82, 71, 100, 84, -37, - -88, 60, -86, -111, 14, -64, -95, 41, 87, -27, -67, -110, 59, -66, 88, -62, 61, - 121, 60, -52, -17, 113, -102, 53, -103, 127, -56, -57, 45, 37, 125, 32}; + -97, 94, -76, 43, 27, 4, 5, -65, 116, -99, -15, 38, 55, -124, -30, 52, 41, + -72, -31, 107, -109, 125, 112, -88, 36, 44, -121, 34, 21, -37, -20, -49, 68, -38, + 127, 1, -85, 60, -9, -117, -104, 34, -84, -48, 126, -33, 6, -105, 105}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_input_to_input_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_input_to_input_w_data.h index f5ff6047..a1426c08 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_input_to_input_w_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_input_to_input_w_data.h @@ -1,9 +1,9 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. 
#pragma once #include const int8_t lstm_2_recurrent_input_to_input_w[49] = { - -41, 46, -126, -116, -34, -18, -107, -105, -6, -68, 56, 36, -2, 113, 76, -65, 67, - -123, -125, 74, 92, 31, 116, -108, 114, -123, -60, -59, 29, -23, 4, -116, 114, -94, - -79, 79, -13, -70, -99, -17, 67, 120, 127, 41, -71, 59, 80, 53, -111}; + -5, -98, 108, 28, -6, 114, 11, -61, 6, -75, -115, -115, -21, -95, 43, 36, 72, + 33, 91, 10, 52, 58, 19, 47, -20, 17, -68, -25, 25, -87, -5, 125, 68, -42, + 75, -16, 28, -13, 23, 50, 111, -114, 58, -23, 127, 101, 43, -107, -28}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_input_to_output_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_input_to_output_w_data.h index 22ad7793..5d405188 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_input_to_output_w_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_input_to_output_w_data.h @@ -1,9 +1,9 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include const int8_t lstm_2_recurrent_input_to_output_w[49] = { - -38, 81, -44, -57, 126, -127, 50, -3, 54, 71, 74, -51, -9, -3, -122, 51, -26, - 92, -18, -124, -91, 73, -68, 109, -33, 24, -55, -48, 124, -83, 12, -113, 68, 23, - 32, 88, -115, 113, -109, 28, 61, -37, -31, -62, -83, 74, 73, -23, -123}; + 67, -45, -7, 83, 73, 93, 116, 86, 40, 19, 114, 121, -34, 106, -76, -75, -99, + 24, -8, -71, 109, -4, -1, -125, -21, -125, 96, 33, 29, -113, -76, -51, -107, -28, + -120, -10, -10, -91, 40, -5, 117, -70, -94, -122, 127, 10, 102, 82, 29}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_to_cell_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_to_cell_eff_bias_data.h index 08190320..22be2cd4 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_to_cell_eff_bias_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_to_cell_eff_bias_data.h @@ -1,6 +1,6 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include -const int32_t lstm_2_recurrent_to_cell_eff_bias[7] = {14948, -37518, 18278, 30784, -23828, -7622, -21608}; +const int32_t lstm_2_recurrent_to_cell_eff_bias[7] = {-3432, -9696, -936, 792, 1512, -1200, -96}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_to_forget_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_to_forget_eff_bias_data.h index 5398098e..1c9768d0 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_to_forget_eff_bias_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_to_forget_eff_bias_data.h @@ -1,6 +1,6 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). 
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include -const int32_t lstm_2_recurrent_to_forget_eff_bias[7] = {24790, -26418, 5772, 8214, -6734, 3552, -18722}; +const int32_t lstm_2_recurrent_to_forget_eff_bias[7] = {0, -2256, -1008, 3408, 1728, -5280, -792}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_to_input_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_to_input_eff_bias_data.h index 56a21435..7d6bc02b 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_to_input_eff_bias_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_to_input_eff_bias_data.h @@ -1,6 +1,6 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include -const int32_t lstm_2_recurrent_to_input_eff_bias[7] = {29304, -1776, 296, 6586, 12210, -4958, -13172}; +const int32_t lstm_2_recurrent_to_input_eff_bias[7] = {3648, -11424, 8088, 672, 3816, 1656, 4104}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_to_output_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_to_output_eff_bias_data.h index 483dcb45..01aefc8d 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_to_output_eff_bias_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_to_output_eff_bias_data.h @@ -1,6 +1,6 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include -const int32_t lstm_2_recurrent_to_output_eff_bias[7] = {666, -9842, 17612, -148, -4662, -2146, 12950}; +const int32_t lstm_2_recurrent_to_output_eff_bias[7] = {9120, 10848, -4704, -3528, -11184, -696, 3216}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/test_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/test_data.h index ecd25010..f1ab6ed2 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_2/test_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_2/test_data.h @@ -1,5 +1,5 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #include "cell_gate_bias_data.h" #include "cell_norm_coeff_data.h" #include "cell_state_data.h" diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_gate_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_gate_bias_data.h index cd215406..6b100d83 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_gate_bias_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_gate_bias_data.h @@ -1,6 +1,6 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). 
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include -const int32_t lstm_one_time_step_cell_gate_bias[3] = {-15651, -31126, 21566}; +const int32_t lstm_one_time_step_cell_gate_bias[3] = {322, -28994, -4077}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_norm_coeff_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_norm_coeff_data.h index 7a966059..1803b8ce 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_norm_coeff_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_norm_coeff_data.h @@ -1,5 +1,5 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_state_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_state_data.h index e81e6523..928b5312 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_state_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_state_data.h @@ -1,5 +1,5 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_to_forget_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_to_forget_data.h index 53071eec..784f00bf 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_to_forget_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_to_forget_data.h @@ -1,5 +1,5 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_to_input_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_to_input_data.h index e24a3ddf..e63ba379 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_to_input_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_to_input_data.h @@ -1,5 +1,5 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. 
 #pragma once
 #include

diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_to_output_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_to_output_data.h
index 05729680..bf8c40aa 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_to_output_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_to_output_data.h
@@ -1,5 +1,5 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
 #pragma once
 #include

diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/config_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/config_data.h
index f18c7e93..d719d33b 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/config_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/config_data.h
@@ -1,5 +1,5 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
 #pragma once
 #define LSTM_ONE_TIME_STEP_BUFFER_SIZE 9
 #define LSTM_ONE_TIME_STEP_INPUT_BATCHES 3
@@ -10,24 +10,29 @@
 #define LSTM_ONE_TIME_STEP_TIME_MAJOR 0
 #define LSTM_ONE_TIME_STEP_IN_ACTIVATION_MIN -32768
 #define LSTM_ONE_TIME_STEP_IN_ACTIVATION_MAX 32767
-#define LSTM_ONE_TIME_STEP_IN_TO_INPUT_MULTIPLIER 1081402240
-#define LSTM_ONE_TIME_STEP_IN_TO_INPUT_SHIFT -2
-#define LSTM_ONE_TIME_STEP_IN_TO_FORGET_MULTIPLIER 2077292928
+#define LSTM_ONE_TIME_STEP_IN_TO_INPUT_MULTIPLIER 2118812544
+#define LSTM_ONE_TIME_STEP_IN_TO_INPUT_SHIFT -3
+#define LSTM_ONE_TIME_STEP_IN_TO_FORGET_MULTIPLIER 2143159040
 #define LSTM_ONE_TIME_STEP_IN_TO_FORGET_SHIFT -3
-#define LSTM_ONE_TIME_STEP_IN_TO_CELL_MULTIPLIER 2082823168
-#define LSTM_ONE_TIME_STEP_IN_TO_CELL_SHIFT -3
-#define LSTM_ONE_TIME_STEP_IN_TO_OUTPUT_MULTIPLIER 2119604096
+#define LSTM_ONE_TIME_STEP_IN_TO_CELL_MULTIPLIER 1079340672
+#define LSTM_ONE_TIME_STEP_IN_TO_CELL_SHIFT -2
+#define LSTM_ONE_TIME_STEP_IN_TO_OUTPUT_MULTIPLIER 2144143488
 #define LSTM_ONE_TIME_STEP_IN_TO_OUTPUT_SHIFT -3
-#define LSTM_ONE_TIME_STEP_RECURRENT_TO_INPUT_MULTIPLIER 1242172928
+#define LSTM_ONE_TIME_STEP_RECURRENT_TO_INPUT_MULTIPLIER 1158760960
 #define LSTM_ONE_TIME_STEP_RECURRENT_TO_INPUT_SHIFT -3
-#define LSTM_ONE_TIME_STEP_RECURRENT_TO_FORGET_MULTIPLIER 1241717376
+#define LSTM_ONE_TIME_STEP_RECURRENT_TO_FORGET_MULTIPLIER 1197046528
 #define LSTM_ONE_TIME_STEP_RECURRENT_TO_FORGET_SHIFT -3
-#define LSTM_ONE_TIME_STEP_RECURRENT_TO_CELL_MULTIPLIER 1075447936
+#define LSTM_ONE_TIME_STEP_RECURRENT_TO_CELL_MULTIPLIER 1244768512
 #define LSTM_ONE_TIME_STEP_RECURRENT_TO_CELL_SHIFT -3
-#define LSTM_ONE_TIME_STEP_RECURRENT_TO_OUTPUT_MULTIPLIER 1153133824
+#define LSTM_ONE_TIME_STEP_RECURRENT_TO_OUTPUT_MULTIPLIER 1216274304
 #define LSTM_ONE_TIME_STEP_RECURRENT_TO_OUTPUT_SHIFT -3
-#define LSTM_ONE_TIME_STEP_HIDDEN_MULTIPLIER 1833641157
+#define LSTM_ONE_TIME_STEP_FORGET_MULTIPLIER 1073741824
+#define LSTM_ONE_TIME_STEP_FORGET_SHIFT -14
+#define LSTM_ONE_TIME_STEP_INPUT_MULTIPLIER 1073741824
+#define LSTM_ONE_TIME_STEP_INPUT_SHIFT -13
+#define LSTM_ONE_TIME_STEP_HIDDEN_MULTIPLIER 1828318324
 #define LSTM_ONE_TIME_STEP_HIDDEN_SHIFT -21
-#define LSTM_ONE_TIME_STEP_HIDDEN_OFFSET 106
-#define LSTM_ONE_TIME_STEP_OUTPUT_STATE_OFFSET 106
-#define LSTM_ONE_TIME_STEP_CELL_STATE_SHIFT -15
+#define LSTM_ONE_TIME_STEP_HIDDEN_OFFSET -16
+#define LSTM_ONE_TIME_STEP_DATA_OFFSET 128
+#define LSTM_ONE_TIME_STEP_OUTPUT_STATE_OFFSET -16
+#define LSTM_ONE_TIME_STEP_CELL_STATE_SHIFT -16
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/forget_gate_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/forget_gate_bias_data.h
index 99818b75..3ec7b9ee 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/forget_gate_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/forget_gate_bias_data.h
@@ -1,6 +1,6 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
 #pragma once
 #include

-const int32_t lstm_one_time_step_forget_gate_bias[3] = {25787, -14479, 16570};
+const int32_t lstm_one_time_step_forget_gate_bias[3] = {-19492, -713, -23531};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/forget_norm_coeff_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/forget_norm_coeff_data.h
index 9bfd0b9a..27e6baf2 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/forget_norm_coeff_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/forget_norm_coeff_data.h
@@ -1,5 +1,5 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
 #pragma once
 #include

diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_data.h
index 1b349352..165ed208 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_data.h
@@ -1,9 +1,10 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once #include const int8_t lstm_one_time_step_input[66] = { - -75, -122, 117, 85, -108, -63, 57, -6, -128, -27, 51, -48, 35, -103, 6, -66, -35, 64, 82, 24, 29, -39, - -53, 47, -74, 87, 124, 102, -93, -8, -128, 54, 45, -26, -107, -21, -60, -40, -45, 96, -50, 75, 36, -73, - -93, -88, -21, -32, -112, -113, 73, -88, -101, 93, -70, 6, -36, 67, 86, 42, 66, -70, 10, 8, -63, -38}; + 83, 52, 7, 91, -102, 29, 52, -108, -93, 73, -46, 27, 120, -117, -57, -86, -124, + -112, -36, 26, -127, -39, -97, -59, -111, 98, 8, 121, 30, 51, 18, 15, -92, -41, + -71, 42, -96, 66, -18, -127, -91, 59, -29, 71, 96, -44, 83, 63, 51, 55, -8, + -73, -63, -53, -66, -14, -113, 104, -118, -116, 50, -94, 94, -23, 88, -119}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_gate_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_gate_bias_data.h index a931d2c1..0623f3b0 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_gate_bias_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_gate_bias_data.h @@ -1,6 +1,6 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include -const int32_t lstm_one_time_step_input_gate_bias[3] = {1420, -28495, -28358}; +const int32_t lstm_one_time_step_input_gate_bias[3] = {-33100, 3078, 31928}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_norm_coeff_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_norm_coeff_data.h index d84a6460..f5bb6ea9 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_norm_coeff_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_norm_coeff_data.h @@ -1,5 +1,5 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_cell_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_cell_eff_bias_data.h index 6a3057d8..70557c6c 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_cell_eff_bias_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_cell_eff_bias_data.h @@ -1,6 +1,6 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. 
#pragma once #include -const int32_t lstm_one_time_step_input_to_cell_eff_bias[3] = {-30243, 17130, 70846}; +const int32_t lstm_one_time_step_input_to_cell_eff_bias[3] = {16450, 17086, -36717}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_cell_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_cell_w_data.h index 39b5df84..fa886aee 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_cell_w_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_cell_w_data.h @@ -1,9 +1,9 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include const int8_t lstm_one_time_step_input_to_cell_w[66] = { - 1, 110, 28, -87, 44, -118, -44, 88, -45, 85, -24, -33, 47, -85, 106, 10, -86, -121, -4, 86, -26, -46, - -78, -109, 50, 91, 69, 110, -90, 86, 22, 102, -117, 79, 111, -114, 106, -127, 47, 106, 99, 30, -116, 20, - 25, -19, 125, -91, 12, 82, -41, 24, 84, 108, -108, -102, -26, -117, 19, 59, 117, 127, -101, 22, 93, 93}; + 42, -24, 36, 70, 39, 63, 95, -98, -113, 13, 43, 47, -107, -30, -117, -81, 116, 5, 122, -47, -37, 89, + -7, 74, -4, 114, -115, 119, 55, -24, -73, -97, 52, -26, 127, 32, 69, 76, 105, 14, -116, 27, -87, 45, + -104, -61, -110, -30, 44, -102, -93, 103, 17, 100, 14, -27, -100, 9, 13, 55, -50, 28, -1, -91, 19, 112}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_forget_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_forget_eff_bias_data.h index f1297854..d6d22ab5 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_forget_eff_bias_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_forget_eff_bias_data.h @@ -1,6 +1,6 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include -const int32_t lstm_one_time_step_input_to_forget_eff_bias[3] = {-2117, -71439, 29626}; +const int32_t lstm_one_time_step_input_to_forget_eff_bias[3] = {-49060, -20937, 76309}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_forget_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_forget_w_data.h index 8924c8c7..33fb5aec 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_forget_w_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_forget_w_data.h @@ -1,9 +1,9 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. 
#pragma once #include const int8_t lstm_one_time_step_input_to_forget_w[66] = { - -10, -66, -18, 87, -5, -12, 23, -75, -72, -13, 37, 18, 30, -28, -63, -20, 13, 117, 55, -103, -16, -97, - -40, -124, 47, -78, -55, 104, -115, 86, -65, 39, -49, 56, 97, -62, -1, -11, 94, 67, -104, -110, -108, -113, - 6, -71, 85, -84, 9, 48, 39, -65, -37, 9, -47, 123, -84, 127, -114, 90, 69, -4, -48, 0, -21, 72}; + 34, -119, 10, 10, -70, -40, 37, 24, 124, -124, -64, 37, -107, 56, 84, -57, -33, 8, 10, 22, -98, 25, + -9, -47, -95, -21, 38, -124, -102, 12, 69, -48, -127, 96, -34, 39, 21, -104, 85, 58, 89, -66, 53, 59, + -34, 82, -62, 4, 44, 31, -22, -23, 34, 69, 0, 4, 114, 83, -72, 110, 91, 103, 94, -41, 117, 54}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_input_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_input_eff_bias_data.h index 9ffc4240..5a9e26f3 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_input_eff_bias_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_input_eff_bias_data.h @@ -1,6 +1,6 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include -const int32_t lstm_one_time_step_input_to_input_eff_bias[3] = {23436, -59215, -13894}; +const int32_t lstm_one_time_step_input_to_input_eff_bias[3] = {-12620, -12154, 952}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_input_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_input_w_data.h index 5182e6c3..6e827f5c 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_input_w_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_input_w_data.h @@ -1,9 +1,9 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. 
#pragma once #include const int8_t lstm_one_time_step_input_to_input_w[66] = { - 127, 42, 67, 72, -34, 103, 27, 64, -65, -23, -122, -105, -49, -63, 71, 30, -9, -5, 103, -16, 42, -85, - 40, 37, -115, -19, 45, -49, 72, -26, 37, -122, 99, 16, 31, -94, -97, 92, -76, -111, -6, -82, -9, 97, - 8, -23, 42, 108, 13, 13, 98, 50, 40, -52, -124, 7, 17, 117, 66, -84, -88, -97, -70, -4, 22, 54}; + -51, -22, 74, 11, -47, -57, 65, -89, 98, 7, 33, -95, -94, 116, -74, 11, 126, -1, 107, 29, -33, 46, + 72, -127, 99, -53, 51, -15, 5, -66, 14, 14, -76, -105, -48, -114, 116, 38, -1, 96, 43, -94, -91, 123, + -103, -77, 64, -65, -94, -18, -88, -75, 25, 90, -125, 74, 52, -29, -7, 7, 56, 54, -83, -66, 75, 91}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_output_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_output_eff_bias_data.h index cd6d40dc..39cf6ba3 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_output_eff_bias_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_output_eff_bias_data.h @@ -1,6 +1,6 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include -const int32_t lstm_one_time_step_input_to_output_eff_bias[3] = {48450, 36258, -48261}; +const int32_t lstm_one_time_step_input_to_output_eff_bias[3] = {28514, 56537, 50815}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_output_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_output_w_data.h index b21f223b..e77c7456 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_output_w_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_output_w_data.h @@ -1,10 +1,9 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. 
#pragma once #include const int8_t lstm_one_time_step_input_to_output_w[66] = { - 100, -56, 93, -52, 62, 127, 102, -91, 33, -120, -107, 81, -94, 104, -120, 6, -126, - -47, 112, 92, 13, 55, 84, 127, -91, 6, 53, -19, -33, -120, 26, 124, 25, 78, - 72, -122, -79, -62, 102, -6, 55, 125, -96, -53, -39, 1, -42, 107, 44, -56, -40, - 41, -104, 16, 35, -107, -124, 118, 91, -111, -127, 67, 52, -66, -20, 75}; + 74, 127, -51, 71, 113, -78, -47, -24, -11, -61, 74, -44, 65, -70, -50, -88, 102, -6, -78, 97, 78, -22, + -53, 100, -54, 54, 74, 98, 22, -106, -78, 100, 114, 123, -1, 0, 63, 75, 14, -51, 7, 19, -50, 16, + 127, 7, -44, -68, 65, 88, 28, 23, 33, 125, 94, 51, 24, -108, 93, 63, -22, -75, 52, -22, -104, -120}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/output_gate_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/output_gate_bias_data.h index 2e4918e1..e7a83a6c 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/output_gate_bias_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/output_gate_bias_data.h @@ -1,6 +1,6 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include -const int32_t lstm_one_time_step_output_gate_bias[3] = {27074, 11170, -24069}; +const int32_t lstm_one_time_step_output_gate_bias[3] = {6626, -5671, 11135}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/output_norm_coeff_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/output_norm_coeff_data.h index 2a938b16..62ab428a 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/output_norm_coeff_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/output_norm_coeff_data.h @@ -1,5 +1,5 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/output_ref_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/output_ref_data.h index 93ed5692..a2ae0559 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/output_ref_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/output_ref_data.h @@ -1,6 +1,6 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. 
#pragma once #include -const int8_t lstm_one_time_step_output_ref[9] = {-70, 122, 127, -63, 125, 127, 25, 110, 127}; +const int8_t lstm_one_time_step_output_ref[9] = {29, 98, -128, 38, 54, -70, 127, -63, -111}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/output_state_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/output_state_data.h index 8f4631e7..3d66fa1e 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/output_state_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/output_state_data.h @@ -1,6 +1,6 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include -int8_t lstm_one_time_step_output_state[9] = {106, 106, 106, 106, 106, 106, 106, 106, 106}; +int8_t lstm_one_time_step_output_state[9] = {-16, -16, -16, -16, -16, -16, -16, -16, -16}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/projection_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/projection_bias_data.h index 1c3140ec..7a90e907 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/projection_bias_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/projection_bias_data.h @@ -1,5 +1,5 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/projection_weights_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/projection_weights_data.h index 62246b71..4266770c 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/projection_weights_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/projection_weights_data.h @@ -1,5 +1,5 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_input_to_cell_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_input_to_cell_w_data.h index 9b4a48e5..92e79347 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_input_to_cell_w_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_input_to_cell_w_data.h @@ -1,6 +1,6 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. 
#pragma once #include -const int8_t lstm_one_time_step_recurrent_input_to_cell_w[9] = {-42, -127, -4, -88, -56, 107, -109, -112, 11}; +const int8_t lstm_one_time_step_recurrent_input_to_cell_w[9] = {-91, -61, 124, 5, -127, -65, -103, -113, -53}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_input_to_forget_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_input_to_forget_w_data.h index 41d12737..7b8c939a 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_input_to_forget_w_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_input_to_forget_w_data.h @@ -1,6 +1,6 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include -const int8_t lstm_one_time_step_recurrent_input_to_forget_w[9] = {-120, -120, -31, -90, 20, -54, 58, 127, -17}; +const int8_t lstm_one_time_step_recurrent_input_to_forget_w[9] = {-102, -67, -12, 7, -127, -96, 96, 16, -57}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_input_to_input_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_input_to_input_w_data.h index cfae7c8d..2623531a 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_input_to_input_w_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_input_to_input_w_data.h @@ -1,6 +1,6 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include -const int8_t lstm_one_time_step_recurrent_input_to_input_w[9] = {107, 126, 64, 47, 16, -71, 123, -127, -118}; +const int8_t lstm_one_time_step_recurrent_input_to_input_w[9] = {55, 103, 98, -127, -114, -93, 31, -36, 122}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_input_to_output_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_input_to_output_w_data.h index f1c194e9..72364ff6 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_input_to_output_w_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_input_to_output_w_data.h @@ -1,6 +1,6 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. 
#pragma once #include -const int8_t lstm_one_time_step_recurrent_input_to_output_w[9] = {-90, -45, 111, 100, -20, 66, 27, -127, 67}; +const int8_t lstm_one_time_step_recurrent_input_to_output_w[9] = {-127, -81, 94, -112, -25, -70, -68, 1, -16}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_to_cell_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_to_cell_eff_bias_data.h index 707b1408..0e17501e 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_to_cell_eff_bias_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_to_cell_eff_bias_data.h @@ -1,6 +1,6 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include -const int32_t lstm_one_time_step_recurrent_to_cell_eff_bias[3] = {18338, 3922, 22260}; +const int32_t lstm_one_time_step_recurrent_to_cell_eff_bias[3] = {-448, -2992, -4304}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_to_forget_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_to_forget_eff_bias_data.h index 181eb477..76b468d1 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_to_forget_eff_bias_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_to_forget_eff_bias_data.h @@ -1,6 +1,6 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. #pragma once #include -const int32_t lstm_one_time_step_recurrent_to_forget_eff_bias[3] = {28726, 13144, -17808}; +const int32_t lstm_one_time_step_recurrent_to_forget_eff_bias[3] = {-2896, -3456, 880}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_to_input_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_to_input_eff_bias_data.h index 66278394..8121c87e 100644 --- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_to_input_eff_bias_data.h +++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_to_input_eff_bias_data.h @@ -1,6 +1,6 @@ -// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0). -// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4. +// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0). +// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None. 
 #pragma once
 #include

-const int32_t lstm_one_time_step_recurrent_to_input_eff_bias[3] = {-31482, 848, 12932};
+const int32_t lstm_one_time_step_recurrent_to_input_eff_bias[3] = {4096, -5344, 1872};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_to_output_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_to_output_eff_bias_data.h
index 491b3bed..a0d73385 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_to_output_eff_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_to_output_eff_bias_data.h
@@ -1,6 +1,6 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
 #pragma once
 #include

-const int32_t lstm_one_time_step_recurrent_to_output_eff_bias[3] = {2544, -15476, 3498};
+const int32_t lstm_one_time_step_recurrent_to_output_eff_bias[3] = {-1824, -3312, -1328};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/test_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/test_data.h
index ecd25010..f1ab6ed2 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/test_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/test_data.h
@@ -1,5 +1,5 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
 #include "cell_gate_bias_data.h"
 #include "cell_norm_coeff_data.h"
 #include "cell_state_data.h"
diff --git a/Tests/UnitTest/TestCases/test_arm_ds_cnn_l_s8/test_arm_ds_cnn_l_s8.c b/Tests/UnitTest/TestCases/test_arm_ds_cnn_l_s8/test_arm_ds_cnn_l_s8.c
index f17a7f0e..ca21593e 100644
--- a/Tests/UnitTest/TestCases/test_arm_ds_cnn_l_s8/test_arm_ds_cnn_l_s8.c
+++ b/Tests/UnitTest/TestCases/test_arm_ds_cnn_l_s8/test_arm_ds_cnn_l_s8.c
@@ -473,7 +473,7 @@ void ds_cnn_l_s8_inference(void)
     bias_dims.c = in_out_dim_1.c;

 #if defined(ARM_MATH_MVEI)
-    arm_vector_sum_s8(ctx.buf, conv_filter_dims.n, in_out_dim_1.c, ds_cnn_l_layer_14_fully_connected_weights);
+    arm_vector_sum_s8(ctx.buf, conv_filter_dims.n, in_out_dim_1.c, ds_cnn_l_layer_14_fully_connected_weights, 1, NULL);
 #endif

     status |= arm_fully_connected_s8(&ctx,
diff --git a/Tests/UnitTest/TestCases/test_arm_fully_connected_s8/test_arm_fully_connected_s8.c b/Tests/UnitTest/TestCases/test_arm_fully_connected_s8/test_arm_fully_connected_s8.c
index c5e0713c..be5be41d 100644
--- a/Tests/UnitTest/TestCases/test_arm_fully_connected_s8/test_arm_fully_connected_s8.c
+++ b/Tests/UnitTest/TestCases/test_arm_fully_connected_s8/test_arm_fully_connected_s8.c
@@ -1,5 +1,5 @@
 /*
- * SPDX-FileCopyrightText: Copyright 2010-2023 Arm Limited and/or its affiliates
+ * SPDX-FileCopyrightText: Copyright 2010-2024 Arm Limited and/or its affiliates
  *
  * SPDX-License-Identifier: Apache-2.0
  *
@@ -70,7 +70,7 @@ void fully_connected_arm_fully_connected_s8(void)

 #if defined(ARM_MATH_MVEI)
     int32_t *buf = ctx.buf;
-    TEST_ASSERT_EQUAL(expected, arm_vector_sum_s8(buf, filter_dims.n, output_dims.c, kernel_data));
+    TEST_ASSERT_EQUAL(expected, arm_vector_sum_s8(buf, filter_dims.n, output_dims.c, kernel_data, 1, NULL));
 #endif

     arm_cmsis_nn_status result = arm_fully_connected_s8(&ctx,
@@ -133,7 +133,7 @@ void fully_connected_mve_0_arm_fully_connected_s8(void)

 #if defined(ARM_MATH_MVEI)
     int32_t *buf = ctx.buf;
-    TEST_ASSERT_EQUAL(expected, arm_vector_sum_s8(buf, filter_dims.n, output_dims.c, kernel_data));
+    TEST_ASSERT_EQUAL(expected, arm_vector_sum_s8(buf, filter_dims.n, output_dims.c, kernel_data, 1, NULL));
 #endif

     arm_cmsis_nn_status result = arm_fully_connected_s8(&ctx,
@@ -195,7 +195,7 @@ void fully_connected_mve_1_arm_fully_connected_s8(void)

 #if defined(ARM_MATH_MVEI)
     int32_t *buf = ctx.buf;
-    TEST_ASSERT_EQUAL(expected, arm_vector_sum_s8(buf, filter_dims.n, output_dims.c, kernel_data));
+    TEST_ASSERT_EQUAL(expected, arm_vector_sum_s8(buf, filter_dims.n, output_dims.c, kernel_data, 1, NULL));
 #endif

     arm_cmsis_nn_status result = arm_fully_connected_s8(&ctx,
@@ -268,7 +268,7 @@ void fully_connected_null_bias_0_arm_fully_connected_s8(void)

 #if defined(ARM_MATH_MVEI)
     int32_t *buf = ctx.buf;
-    TEST_ASSERT_EQUAL(expected, arm_vector_sum_s8(buf, filter_dims.n, output_dims.c, kernel_data));
+    TEST_ASSERT_EQUAL(expected, arm_vector_sum_s8(buf, filter_dims.n, output_dims.c, kernel_data, 1, NULL));
 #endif

     arm_cmsis_nn_status result = arm_fully_connected_s8(&ctx,
@@ -330,7 +330,7 @@ void fully_connected_out_activation_arm_fully_connected_s8(void)

 #if defined(ARM_MATH_MVEI)
     int32_t *buf = ctx.buf;
-    TEST_ASSERT_EQUAL(expected, arm_vector_sum_s8(buf, filter_dims.n, output_dims.c, kernel_data));
+    TEST_ASSERT_EQUAL(expected, arm_vector_sum_s8(buf, filter_dims.n, output_dims.c, kernel_data, 1, NULL));
 #endif

     arm_cmsis_nn_status result = arm_fully_connected_s8(&ctx,
diff --git a/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s16_s8/test_arm_lstm_unidirectional_s16_s8.c
b/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s16_s8/test_arm_lstm_unidirectional_s16_s8.c deleted file mode 100644 index d90da214..00000000 --- a/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s16_s8/test_arm_lstm_unidirectional_s16_s8.c +++ /dev/null @@ -1,328 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright 2022-2023 Arm Limited and/or its affiliates - * - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the License); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include -#include - -#include "../TestData/lstm_1/test_data.h" -#include "../TestData/lstm_2/test_data.h" -#include "../TestData/lstm_one_time_step/test_data.h" -#include "../Utils/validate.h" - -#if (LSTM_2_BUFFER_SIZE < LSTM_1_BUFFER_SIZE) || (LSTM_2_BUFFER_SIZE < LSTM_ONE_TIME_STEP_BUFFER_SIZE) - #error "Test buffers too small." -#endif - -// Update the buffer size if adding a unit test with larger buffer. -#define LARGEST_BUFFER_SIZE LSTM_2_BUFFER_SIZE - -int16_t buffer0[LARGEST_BUFFER_SIZE]; -int16_t buffer1[LARGEST_BUFFER_SIZE]; -int16_t buffer2[LARGEST_BUFFER_SIZE]; -int16_t buffer3[LARGEST_BUFFER_SIZE]; - -void lstm_1_arm_lstm_unidirectional_s16_s8(void) -{ - const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS; - const bool time_major = (bool)LSTM_1_TIME_MAJOR; - - int8_t output[LSTM_1_DST_SIZE] = {0}; - - cmsis_nn_lstm_context scratch_buffers = {}; - cmsis_nn_lstm_dims lstm_dims = {}; - cmsis_nn_lstm_params lstm = {}; - - scratch_buffers.input_gate = buffer0; - scratch_buffers.forget_gate = buffer1; - scratch_buffers.cell_gate = buffer2; - scratch_buffers.output_gate = buffer3; - - lstm_dims.num_batches = LSTM_1_INPUT_BATCHES; - lstm_dims.num_inputs = LSTM_1_NUMBER_INPUTS; - lstm_dims.max_time = LSTM_1_TIME_STEPS; - lstm_dims.num_outputs = LSTM_1_NUMBER_UNITS; - - lstm.time_major = time_major; - lstm.input_to_input_scaling.multiplier = LSTM_1_IN_TO_INPUT_MULTIPLIER; - lstm.input_to_input_scaling.shift = LSTM_1_IN_TO_INPUT_SHIFT; - lstm.input_to_forget_scaling.multiplier = LSTM_1_IN_TO_FORGET_MULTIPLIER; - lstm.input_to_forget_scaling.shift = LSTM_1_IN_TO_FORGET_SHIFT; - lstm.input_to_cell_scaling.multiplier = LSTM_1_IN_TO_CELL_MULTIPLIER; - lstm.input_to_cell_scaling.shift = LSTM_1_IN_TO_CELL_SHIFT; - lstm.input_to_output_scaling.multiplier = LSTM_1_IN_TO_OUTPUT_MULTIPLIER; - lstm.input_to_output_scaling.shift = LSTM_1_IN_TO_OUTPUT_SHIFT; - - lstm.recurrent_to_input_scaling.multiplier = LSTM_1_RECURRENT_TO_INPUT_MULTIPLIER; - lstm.recurrent_to_input_scaling.shift = LSTM_1_RECURRENT_TO_INPUT_SHIFT; - lstm.recurrent_to_cell_scaling.multiplier = LSTM_1_RECURRENT_TO_CELL_MULTIPLIER; - lstm.recurrent_to_cell_scaling.shift = LSTM_1_RECURRENT_TO_CELL_SHIFT; - lstm.recurrent_to_forget_scaling.multiplier = LSTM_1_RECURRENT_TO_FORGET_MULTIPLIER; - lstm.recurrent_to_forget_scaling.shift = LSTM_1_RECURRENT_TO_FORGET_SHIFT; - lstm.recurrent_to_output_scaling.multiplier = LSTM_1_RECURRENT_TO_OUTPUT_MULTIPLIER; - lstm.recurrent_to_output_scaling.shift = LSTM_1_RECURRENT_TO_OUTPUT_SHIFT; - - 
lstm.i2i_effective_bias = lstm_1_input_to_input_eff_bias; - lstm.i2f_effective_bias = lstm_1_input_to_forget_eff_bias; - lstm.i2c_effective_bias = lstm_1_input_to_cell_eff_bias; - lstm.i2o_effective_bias = lstm_1_input_to_output_eff_bias; - - lstm.r2i_effective_bias = lstm_1_recurrent_to_input_eff_bias; - lstm.r2f_effective_bias = lstm_1_recurrent_to_forget_eff_bias; - lstm.r2c_effective_bias = lstm_1_recurrent_to_cell_eff_bias; - lstm.r2o_effective_bias = lstm_1_recurrent_to_output_eff_bias; - - lstm.input_gate_bias = lstm_1_input_gate_bias; - lstm.forget_gate_bias = lstm_1_forget_gate_bias; - lstm.cell_gate_bias = lstm_1_cell_gate_bias; - lstm.output_gate_bias = lstm_1_output_gate_bias; - - lstm.activation.min = LSTM_1_IN_ACTIVATION_MIN; - lstm.activation.max = LSTM_1_IN_ACTIVATION_MAX; - - lstm.hidden_scaling.multiplier = LSTM_1_HIDDEN_MULTIPLIER; - lstm.hidden_scaling.shift = LSTM_1_HIDDEN_SHIFT; - - lstm.hidden_offset = LSTM_1_HIDDEN_OFFSET; - - lstm.cell_state_shift = LSTM_1_CELL_STATE_SHIFT; - lstm.output_state_offset = LSTM_1_OUTPUT_STATE_OFFSET; - - const int8_t *input_data = lstm_1_input; - const int8_t *output_ref = lstm_1_output_ref; - const int32_t output_ref_size = LSTM_1_DST_SIZE; - - arm_cmsis_nn_status result = arm_lstm_unidirectional_s16_s8(&scratch_buffers, - input_data, - &lstm_dims, - lstm_1_input_to_input_w, - lstm_1_input_to_forget_w, - lstm_1_input_to_cell_w, - lstm_1_input_to_output_w, - lstm_1_recurrent_input_to_input_w, - lstm_1_recurrent_input_to_forget_w, - lstm_1_recurrent_input_to_cell_w, - lstm_1_recurrent_input_to_output_w, - lstm_1_cell_to_input, - lstm_1_cell_to_forget, - lstm_1_cell_to_output, - lstm_1_projection_weights, - &lstm, - lstm_1_output_state, - lstm_1_cell_state, - output); - - TEST_ASSERT_EQUAL(expected, result); - TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size)); -} - -void lstm_2_arm_lstm_unidirectional_s16_s8(void) -{ - const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS; - const bool time_major = (bool)LSTM_2_TIME_MAJOR; - - int8_t output[LSTM_2_DST_SIZE] = {0}; - - cmsis_nn_lstm_context scratch_buffers = {}; - cmsis_nn_lstm_dims lstm_dims = {}; - cmsis_nn_lstm_params lstm = {}; - - scratch_buffers.input_gate = buffer0; - scratch_buffers.forget_gate = buffer1; - scratch_buffers.cell_gate = buffer2; - scratch_buffers.output_gate = buffer3; - - lstm_dims.num_batches = LSTM_2_INPUT_BATCHES; - lstm_dims.num_inputs = LSTM_2_NUMBER_INPUTS; - lstm_dims.max_time = LSTM_2_TIME_STEPS; - lstm_dims.num_outputs = LSTM_2_NUMBER_UNITS; - - lstm.time_major = time_major; - lstm.input_to_input_scaling.multiplier = LSTM_2_IN_TO_INPUT_MULTIPLIER; - lstm.input_to_input_scaling.shift = LSTM_2_IN_TO_INPUT_SHIFT; - lstm.input_to_forget_scaling.multiplier = LSTM_2_IN_TO_FORGET_MULTIPLIER; - lstm.input_to_forget_scaling.shift = LSTM_2_IN_TO_FORGET_SHIFT; - lstm.input_to_cell_scaling.multiplier = LSTM_2_IN_TO_CELL_MULTIPLIER; - lstm.input_to_cell_scaling.shift = LSTM_2_IN_TO_CELL_SHIFT; - lstm.input_to_output_scaling.multiplier = LSTM_2_IN_TO_OUTPUT_MULTIPLIER; - lstm.input_to_output_scaling.shift = LSTM_2_IN_TO_OUTPUT_SHIFT; - - lstm.recurrent_to_input_scaling.multiplier = LSTM_2_RECURRENT_TO_INPUT_MULTIPLIER; - lstm.recurrent_to_input_scaling.shift = LSTM_2_RECURRENT_TO_INPUT_SHIFT; - lstm.recurrent_to_cell_scaling.multiplier = LSTM_2_RECURRENT_TO_CELL_MULTIPLIER; - lstm.recurrent_to_cell_scaling.shift = LSTM_2_RECURRENT_TO_CELL_SHIFT; - lstm.recurrent_to_forget_scaling.multiplier = LSTM_2_RECURRENT_TO_FORGET_MULTIPLIER; - 
lstm.recurrent_to_forget_scaling.shift = LSTM_2_RECURRENT_TO_FORGET_SHIFT; - lstm.recurrent_to_output_scaling.multiplier = LSTM_2_RECURRENT_TO_OUTPUT_MULTIPLIER; - lstm.recurrent_to_output_scaling.shift = LSTM_2_RECURRENT_TO_OUTPUT_SHIFT; - - lstm.i2i_effective_bias = lstm_2_input_to_input_eff_bias; - lstm.i2f_effective_bias = lstm_2_input_to_forget_eff_bias; - lstm.i2c_effective_bias = lstm_2_input_to_cell_eff_bias; - lstm.i2o_effective_bias = lstm_2_input_to_output_eff_bias; - - lstm.r2i_effective_bias = lstm_2_recurrent_to_input_eff_bias; - lstm.r2f_effective_bias = lstm_2_recurrent_to_forget_eff_bias; - lstm.r2c_effective_bias = lstm_2_recurrent_to_cell_eff_bias; - lstm.r2o_effective_bias = lstm_2_recurrent_to_output_eff_bias; - - lstm.input_gate_bias = lstm_2_input_gate_bias; - lstm.forget_gate_bias = lstm_2_forget_gate_bias; - lstm.cell_gate_bias = lstm_2_cell_gate_bias; - lstm.output_gate_bias = lstm_2_output_gate_bias; - - lstm.activation.min = LSTM_2_IN_ACTIVATION_MIN; - lstm.activation.max = LSTM_2_IN_ACTIVATION_MAX; - - lstm.hidden_scaling.multiplier = LSTM_2_HIDDEN_MULTIPLIER; - lstm.hidden_scaling.shift = LSTM_2_HIDDEN_SHIFT; - - lstm.hidden_offset = LSTM_2_HIDDEN_OFFSET; - - lstm.cell_state_shift = LSTM_2_CELL_STATE_SHIFT; - lstm.output_state_offset = LSTM_2_OUTPUT_STATE_OFFSET; - - const int8_t *input_data = lstm_2_input; - const int8_t *output_ref = lstm_2_output_ref; - const int32_t output_ref_size = LSTM_2_DST_SIZE; - - arm_cmsis_nn_status result = arm_lstm_unidirectional_s16_s8(&scratch_buffers, - input_data, - &lstm_dims, - lstm_2_input_to_input_w, - lstm_2_input_to_forget_w, - lstm_2_input_to_cell_w, - lstm_2_input_to_output_w, - lstm_2_recurrent_input_to_input_w, - lstm_2_recurrent_input_to_forget_w, - lstm_2_recurrent_input_to_cell_w, - lstm_2_recurrent_input_to_output_w, - lstm_2_cell_to_input, - lstm_2_cell_to_forget, - lstm_2_cell_to_output, - lstm_2_projection_weights, - &lstm, - lstm_2_output_state, - lstm_2_cell_state, - output); - - TEST_ASSERT_EQUAL(expected, result); - TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size)); -} - -void lstm_one_time_step_arm_lstm_unidirectional_s16_s8(void) -{ - const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS; - const bool time_major = (bool)LSTM_ONE_TIME_STEP_TIME_MAJOR; - - int8_t output[LSTM_ONE_TIME_STEP_DST_SIZE] = {0}; - - int16_t cell_state[LSTM_ONE_TIME_STEP_DST_SIZE]; - int8_t output_state[LSTM_ONE_TIME_STEP_DST_SIZE]; - - memcpy(output_state, lstm_one_time_step_output_state, LSTM_ONE_TIME_STEP_DST_SIZE * sizeof(int8_t)); - memcpy(cell_state, lstm_one_time_step_cell_state, LSTM_ONE_TIME_STEP_DST_SIZE * sizeof(int16_t)); - - cmsis_nn_lstm_context scratch_buffers = {}; - cmsis_nn_lstm_dims lstm_dims = {}; - cmsis_nn_lstm_params lstm = {}; - - scratch_buffers.input_gate = buffer0; - scratch_buffers.forget_gate = buffer1; - scratch_buffers.cell_gate = buffer2; - scratch_buffers.output_gate = buffer3; - - lstm_dims.num_batches = LSTM_ONE_TIME_STEP_INPUT_BATCHES; - lstm_dims.num_inputs = LSTM_ONE_TIME_STEP_NUMBER_INPUTS; - lstm_dims.max_time = LSTM_ONE_TIME_STEP_TIME_STEPS; - lstm_dims.num_outputs = LSTM_ONE_TIME_STEP_NUMBER_UNITS; - - lstm.time_major = time_major; - lstm.input_to_input_scaling.multiplier = LSTM_ONE_TIME_STEP_IN_TO_INPUT_MULTIPLIER; - lstm.input_to_input_scaling.shift = LSTM_ONE_TIME_STEP_IN_TO_INPUT_SHIFT; - lstm.input_to_forget_scaling.multiplier = LSTM_ONE_TIME_STEP_IN_TO_FORGET_MULTIPLIER; - lstm.input_to_forget_scaling.shift = LSTM_ONE_TIME_STEP_IN_TO_FORGET_SHIFT; - 
lstm.input_to_cell_scaling.multiplier = LSTM_ONE_TIME_STEP_IN_TO_CELL_MULTIPLIER; - lstm.input_to_cell_scaling.shift = LSTM_ONE_TIME_STEP_IN_TO_CELL_SHIFT; - lstm.input_to_output_scaling.multiplier = LSTM_ONE_TIME_STEP_IN_TO_OUTPUT_MULTIPLIER; - lstm.input_to_output_scaling.shift = LSTM_ONE_TIME_STEP_IN_TO_OUTPUT_SHIFT; - - lstm.recurrent_to_input_scaling.multiplier = LSTM_ONE_TIME_STEP_RECURRENT_TO_INPUT_MULTIPLIER; - lstm.recurrent_to_input_scaling.shift = LSTM_ONE_TIME_STEP_RECURRENT_TO_INPUT_SHIFT; - lstm.recurrent_to_cell_scaling.multiplier = LSTM_ONE_TIME_STEP_RECURRENT_TO_CELL_MULTIPLIER; - lstm.recurrent_to_cell_scaling.shift = LSTM_ONE_TIME_STEP_RECURRENT_TO_CELL_SHIFT; - lstm.recurrent_to_forget_scaling.multiplier = LSTM_ONE_TIME_STEP_RECURRENT_TO_FORGET_MULTIPLIER; - lstm.recurrent_to_forget_scaling.shift = LSTM_ONE_TIME_STEP_RECURRENT_TO_FORGET_SHIFT; - lstm.recurrent_to_output_scaling.multiplier = LSTM_ONE_TIME_STEP_RECURRENT_TO_OUTPUT_MULTIPLIER; - lstm.recurrent_to_output_scaling.shift = LSTM_ONE_TIME_STEP_RECURRENT_TO_OUTPUT_SHIFT; - - lstm.i2i_effective_bias = lstm_one_time_step_input_to_input_eff_bias; - lstm.i2f_effective_bias = lstm_one_time_step_input_to_forget_eff_bias; - lstm.i2c_effective_bias = lstm_one_time_step_input_to_cell_eff_bias; - lstm.i2o_effective_bias = lstm_one_time_step_input_to_output_eff_bias; - - lstm.r2i_effective_bias = lstm_one_time_step_recurrent_to_input_eff_bias; - lstm.r2f_effective_bias = lstm_one_time_step_recurrent_to_forget_eff_bias; - lstm.r2c_effective_bias = lstm_one_time_step_recurrent_to_cell_eff_bias; - lstm.r2o_effective_bias = lstm_one_time_step_recurrent_to_output_eff_bias; - - lstm.input_gate_bias = lstm_one_time_step_input_gate_bias; - lstm.forget_gate_bias = lstm_one_time_step_forget_gate_bias; - lstm.cell_gate_bias = lstm_one_time_step_cell_gate_bias; - lstm.output_gate_bias = lstm_one_time_step_output_gate_bias; - - lstm.activation.min = LSTM_ONE_TIME_STEP_IN_ACTIVATION_MIN; - lstm.activation.max = LSTM_ONE_TIME_STEP_IN_ACTIVATION_MAX; - - lstm.hidden_scaling.multiplier = LSTM_ONE_TIME_STEP_HIDDEN_MULTIPLIER; - lstm.hidden_scaling.shift = LSTM_ONE_TIME_STEP_HIDDEN_SHIFT; - - lstm.hidden_offset = LSTM_ONE_TIME_STEP_HIDDEN_OFFSET; - - lstm.cell_state_shift = LSTM_ONE_TIME_STEP_CELL_STATE_SHIFT; - lstm.output_state_offset = LSTM_ONE_TIME_STEP_OUTPUT_STATE_OFFSET; - - const int8_t *input_data = lstm_one_time_step_input; - const int8_t *output_ref = lstm_one_time_step_output_ref; - const int32_t output_ref_size = LSTM_ONE_TIME_STEP_DST_SIZE; - - arm_cmsis_nn_status result = arm_lstm_unidirectional_s16_s8(&scratch_buffers, - input_data, - &lstm_dims, - lstm_one_time_step_input_to_input_w, - lstm_one_time_step_input_to_forget_w, - lstm_one_time_step_input_to_cell_w, - lstm_one_time_step_input_to_output_w, - lstm_one_time_step_recurrent_input_to_input_w, - lstm_one_time_step_recurrent_input_to_forget_w, - lstm_one_time_step_recurrent_input_to_cell_w, - lstm_one_time_step_recurrent_input_to_output_w, - lstm_one_time_step_cell_to_input, - lstm_one_time_step_cell_to_forget, - lstm_one_time_step_cell_to_output, - lstm_one_time_step_projection_weights, - &lstm, - lstm_one_time_step_output_state, - lstm_one_time_step_cell_state, - output); - - TEST_ASSERT_EQUAL(expected, result); - TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size)); -} diff --git a/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s16_s8/CMakeLists.txt b/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s8/CMakeLists.txt similarity index 67% rename 
from Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s16_s8/CMakeLists.txt rename to Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s8/CMakeLists.txt index e51e4e9f..ea73e4ff 100644 --- a/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s16_s8/CMakeLists.txt +++ b/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s8/CMakeLists.txt @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright 2010-2022 Arm Limited and/or its affiliates +# SPDX-FileCopyrightText: Copyright 2010-2022, 2024 Arm Limited and/or its affiliates # # SPDX-License-Identifier: Apache-2.0 # @@ -16,8 +16,8 @@ # limitations under the License. # -add_cmsis_nn_unit_test_executable(test_arm_lstm_unidirectional_s16_s8) +add_cmsis_nn_unit_test_executable(test_arm_lstm_unidirectional_s8) -target_sources(test_arm_lstm_unidirectional_s16_s8 PRIVATE - Unity/unity_test_arm_lstm_unidirectional_s16_s8.c - Unity/TestRunner/unity_test_arm_lstm_unidirectional_s16_s8_runner.c) +target_sources(test_arm_lstm_unidirectional_s8 PRIVATE + Unity/unity_test_arm_lstm_unidirectional_s8.c + Unity/TestRunner/unity_test_arm_lstm_unidirectional_s8_runner.c) diff --git a/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s16_s8/Unity/unity_test_arm_lstm_unidirectional_s16_s8.c b/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s8/Unity/unity_test_arm_lstm_unidirectional_s8.c similarity index 69% rename from Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s16_s8/Unity/unity_test_arm_lstm_unidirectional_s16_s8.c rename to Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s8/Unity/unity_test_arm_lstm_unidirectional_s8.c index 05276f38..28dadf9b 100644 --- a/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s16_s8/Unity/unity_test_arm_lstm_unidirectional_s16_s8.c +++ b/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s8/Unity/unity_test_arm_lstm_unidirectional_s8.c @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright 2022 Arm Limited and/or its affiliates + * SPDX-FileCopyrightText: Copyright 2022, 2024 Arm Limited and/or its affiliates * * SPDX-License-Identifier: Apache-2.0 * @@ -22,7 +22,7 @@ #include #include -#include "../test_arm_lstm_unidirectional_s16_s8.c" +#include "../test_arm_lstm_unidirectional_s8.c" #include "unity.h" #ifdef USING_FVP_CORSTONE_300 @@ -44,9 +44,6 @@ void setUp(void) */ void tearDown(void) {} -void test_lstm_1_arm_lstm_unidirectional_s16_s8(void) { lstm_1_arm_lstm_unidirectional_s16_s8(); } -void test_lstm_2_arm_lstm_unidirectional_s16_s8(void) { lstm_2_arm_lstm_unidirectional_s16_s8(); } -void test_lstm_one_time_step_arm_lstm_unidirectional_s16_s8(void) -{ - lstm_one_time_step_arm_lstm_unidirectional_s16_s8(); -} +void test_lstm_1_arm_lstm_unidirectional_s8(void) { lstm_1_arm_lstm_unidirectional_s8(); } +void test_lstm_2_arm_lstm_unidirectional_s8(void) { lstm_2_arm_lstm_unidirectional_s8(); } +void test_lstm_one_time_step_arm_lstm_unidirectional_s8(void) { lstm_one_time_step_arm_lstm_unidirectional_s8(); } diff --git a/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s8/test_arm_lstm_unidirectional_s8.c b/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s8/test_arm_lstm_unidirectional_s8.c new file mode 100644 index 00000000..1d132686 --- /dev/null +++ b/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s8/test_arm_lstm_unidirectional_s8.c @@ -0,0 +1,495 @@ +/* + * SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates + * + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the License); you may + * not 
use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include + +#include "../TestData/lstm_1/test_data.h" +#include "../TestData/lstm_2/test_data.h" +#include "../TestData/lstm_one_time_step/test_data.h" +#include "../Utils/validate.h" + +#if (LSTM_2_BUFFER_SIZE > LSTM_1_BUFFER_SIZE) || (LSTM_1_BUFFER_SIZE < LSTM_ONE_TIME_STEP_BUFFER_SIZE) + #error "Test buffers too small." +#endif + +// Update the buffer size if adding a unit test with larger buffer. +#define LARGEST_BUFFER_SIZE LSTM_1_BUFFER_SIZE + +int16_t buffer1[LARGEST_BUFFER_SIZE]; +int16_t buffer2[LARGEST_BUFFER_SIZE]; +int16_t buffer3[LARGEST_BUFFER_SIZE]; + +void lstm_1_arm_lstm_unidirectional_s8(void) +{ + int8_t output[LSTM_1_DST_SIZE] = {0}; + const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS; + const int8_t *output_ref = lstm_1_output_ref; + const int32_t output_ref_size = LSTM_1_DST_SIZE; + + // Calculate kernel sums if using MVE-extension + int32_t input_data_kernel_sum[LSTM_1_NUMBER_UNITS]; + int32_t forget_data_kernel_sum[LSTM_1_NUMBER_UNITS]; + int32_t cell_data_kernel_sum[LSTM_1_NUMBER_UNITS]; + int32_t output_data_kernel_sum[LSTM_1_NUMBER_UNITS]; + + int32_t input_hidden_kernel_sum[LSTM_1_NUMBER_UNITS]; + int32_t forget_hidden_kernel_sum[LSTM_1_NUMBER_UNITS]; + int32_t cell_hidden_kernel_sum[LSTM_1_NUMBER_UNITS]; + int32_t output_hidden_kernel_sum[LSTM_1_NUMBER_UNITS]; + + int32_t size_data = LSTM_1_NUMBER_INPUTS; + int32_t size_hidden = LSTM_1_NUMBER_UNITS; + + arm_vector_sum_s8(&input_data_kernel_sum[0], + size_data, + size_hidden, + &lstm_1_input_to_input_w[0], + LSTM_1_DATA_OFFSET, + &lstm_1_input_gate_bias[0]); + arm_vector_sum_s8(&forget_data_kernel_sum[0], + size_data, + size_hidden, + &lstm_1_input_to_forget_w[0], + LSTM_1_DATA_OFFSET, + &lstm_1_forget_gate_bias[0]); + arm_vector_sum_s8(&cell_data_kernel_sum[0], + size_data, + size_hidden, + &lstm_1_input_to_cell_w[0], + LSTM_1_DATA_OFFSET, + &lstm_1_cell_gate_bias[0]); + arm_vector_sum_s8(&output_data_kernel_sum[0], + size_data, + size_hidden, + &lstm_1_input_to_output_w[0], + LSTM_1_DATA_OFFSET, + &lstm_1_output_gate_bias[0]); + + arm_vector_sum_s8(&input_hidden_kernel_sum[0], + size_hidden, + size_hidden, + &lstm_1_recurrent_input_to_input_w[0], + -LSTM_1_HIDDEN_OFFSET, + NULL); + arm_vector_sum_s8(&forget_hidden_kernel_sum[0], + size_hidden, + size_hidden, + &lstm_1_recurrent_input_to_forget_w[0], + -LSTM_1_HIDDEN_OFFSET, + NULL); + arm_vector_sum_s8(&cell_hidden_kernel_sum[0], + size_hidden, + size_hidden, + &lstm_1_recurrent_input_to_cell_w[0], + -LSTM_1_HIDDEN_OFFSET, + NULL); + arm_vector_sum_s8(&output_hidden_kernel_sum[0], + size_hidden, + size_hidden, + &lstm_1_recurrent_input_to_output_w[0], + -LSTM_1_HIDDEN_OFFSET, + NULL); + + // INPUT GATE + const cmsis_nn_lstm_gate gate_input = {LSTM_1_IN_TO_INPUT_MULTIPLIER, + LSTM_1_IN_TO_INPUT_SHIFT, + &lstm_1_input_to_input_w[0], + &input_data_kernel_sum[0], + LSTM_1_RECURRENT_TO_INPUT_MULTIPLIER, + LSTM_1_RECURRENT_TO_INPUT_SHIFT, + &lstm_1_recurrent_input_to_input_w[0], + &input_hidden_kernel_sum[0], + &lstm_1_input_gate_bias[0], + 
ARM_SIGMOID};
+
+    // FORGET GATE
+    const cmsis_nn_lstm_gate gate_forget = {LSTM_1_IN_TO_FORGET_MULTIPLIER,
+            LSTM_1_IN_TO_FORGET_SHIFT,
+            &lstm_1_input_to_forget_w[0],
+            &forget_data_kernel_sum[0],
+            LSTM_1_RECURRENT_TO_FORGET_MULTIPLIER,
+            LSTM_1_RECURRENT_TO_FORGET_SHIFT,
+            &lstm_1_recurrent_input_to_forget_w[0],
+            &forget_hidden_kernel_sum[0],
+            &lstm_1_forget_gate_bias[0],
+            ARM_SIGMOID};
+
+    // CELL GATE
+    const cmsis_nn_lstm_gate gate_cell = {LSTM_1_IN_TO_CELL_MULTIPLIER,
+            LSTM_1_IN_TO_CELL_SHIFT,
+            &lstm_1_input_to_cell_w[0],
+            &cell_data_kernel_sum[0],
+            LSTM_1_RECURRENT_TO_CELL_MULTIPLIER,
+            LSTM_1_RECURRENT_TO_CELL_SHIFT,
+            &lstm_1_recurrent_input_to_cell_w[0],
+            &cell_hidden_kernel_sum[0],
+            &lstm_1_cell_gate_bias[0],
+            ARM_TANH};
+
+    // OUTPUT GATE
+    const cmsis_nn_lstm_gate gate_output = {LSTM_1_IN_TO_OUTPUT_MULTIPLIER,
+            LSTM_1_IN_TO_OUTPUT_SHIFT,
+            &lstm_1_input_to_output_w[0],
+            &output_data_kernel_sum[0],
+            LSTM_1_RECURRENT_TO_OUTPUT_MULTIPLIER,
+            LSTM_1_RECURRENT_TO_OUTPUT_SHIFT,
+            &lstm_1_recurrent_input_to_output_w[0],
+            &output_hidden_kernel_sum[0],
+            &lstm_1_output_gate_bias[0],
+            ARM_SIGMOID};
+
+    // LSTM DATA
+    const cmsis_nn_lstm_params params = {LSTM_1_TIME_MAJOR,
+            LSTM_1_INPUT_BATCHES,
+            LSTM_1_TIME_STEPS,
+            LSTM_1_NUMBER_INPUTS,
+            LSTM_1_NUMBER_UNITS,
+            LSTM_1_DATA_OFFSET,
+            LSTM_1_FORGET_MULTIPLIER,
+            LSTM_1_FORGET_SHIFT,
+            LSTM_1_INPUT_MULTIPLIER,
+            LSTM_1_INPUT_SHIFT,
+            LSTM_1_IN_ACTIVATION_MAX,
+            LSTM_1_CELL_STATE_SHIFT,
+            LSTM_1_HIDDEN_MULTIPLIER,
+            LSTM_1_HIDDEN_SHIFT,
+            LSTM_1_HIDDEN_OFFSET,
+            gate_forget,
+            gate_input,
+            gate_cell,
+            gate_output};
+
+    // BUFFERS
+    cmsis_nn_lstm_context buffers;
+    buffers.temp1 = buffer1;
+    buffers.temp2 = buffer2;
+    buffers.cell_state = buffer3;
+
+    arm_cmsis_nn_status result = arm_lstm_unidirectional_s8(lstm_1_input, output, &params, &buffers);
+
+    TEST_ASSERT_EQUAL(expected, result);
+    TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
+}
+
+void lstm_2_arm_lstm_unidirectional_s8(void)
+{
+    const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS;
+    const int8_t *output_ref = lstm_2_output_ref;
+    const int32_t output_ref_size = LSTM_2_DST_SIZE;
+
+    // Calculate kernel sums if using MVE-extension
+    int32_t input_data_kernel_sum[LSTM_2_NUMBER_UNITS];
+    int32_t forget_data_kernel_sum[LSTM_2_NUMBER_UNITS];
+    int32_t cell_data_kernel_sum[LSTM_2_NUMBER_UNITS];
+    int32_t output_data_kernel_sum[LSTM_2_NUMBER_UNITS];
+
+    int32_t input_hidden_kernel_sum[LSTM_2_NUMBER_UNITS];
+    int32_t forget_hidden_kernel_sum[LSTM_2_NUMBER_UNITS];
+    int32_t cell_hidden_kernel_sum[LSTM_2_NUMBER_UNITS];
+    int32_t output_hidden_kernel_sum[LSTM_2_NUMBER_UNITS];
+
+    int32_t size_data = LSTM_2_NUMBER_INPUTS;
+    int32_t size_hidden = LSTM_2_NUMBER_UNITS;
+
+    arm_vector_sum_s8(&input_data_kernel_sum[0],
+            size_data,
+            size_hidden,
+            &lstm_2_input_to_input_w[0],
+            LSTM_2_DATA_OFFSET,
+            &lstm_2_input_gate_bias[0]);
+    arm_vector_sum_s8(&forget_data_kernel_sum[0],
+            size_data,
+            size_hidden,
+            &lstm_2_input_to_forget_w[0],
+            LSTM_2_DATA_OFFSET,
+            &lstm_2_forget_gate_bias[0]);
+    arm_vector_sum_s8(&cell_data_kernel_sum[0],
+            size_data,
+            size_hidden,
+            &lstm_2_input_to_cell_w[0],
+            LSTM_2_DATA_OFFSET,
+            &lstm_2_cell_gate_bias[0]);
+    arm_vector_sum_s8(&output_data_kernel_sum[0],
+            size_data,
+            size_hidden,
+            &lstm_2_input_to_output_w[0],
+            LSTM_2_DATA_OFFSET,
+            &lstm_2_output_gate_bias[0]);
+
+    arm_vector_sum_s8(&input_hidden_kernel_sum[0],
+            size_hidden,
+            size_hidden,
+            &lstm_2_recurrent_input_to_input_w[0],
+            -LSTM_2_HIDDEN_OFFSET,
+            NULL);
+    arm_vector_sum_s8(&forget_hidden_kernel_sum[0],
+            size_hidden,
+            size_hidden,
+            &lstm_2_recurrent_input_to_forget_w[0],
+            -LSTM_2_HIDDEN_OFFSET,
+            NULL);
+    arm_vector_sum_s8(&cell_hidden_kernel_sum[0],
+            size_hidden,
+            size_hidden,
+            &lstm_2_recurrent_input_to_cell_w[0],
+            -LSTM_2_HIDDEN_OFFSET,
+            NULL);
+    arm_vector_sum_s8(&output_hidden_kernel_sum[0],
+            size_hidden,
+            size_hidden,
+            &lstm_2_recurrent_input_to_output_w[0],
+            -LSTM_2_HIDDEN_OFFSET,
+            NULL);
+
+    // INPUT GATE
+    const cmsis_nn_lstm_gate gate_input = {LSTM_2_IN_TO_INPUT_MULTIPLIER,
+            LSTM_2_IN_TO_INPUT_SHIFT,
+            &lstm_2_input_to_input_w[0],
+            &input_data_kernel_sum[0],
+            LSTM_2_RECURRENT_TO_INPUT_MULTIPLIER,
+            LSTM_2_RECURRENT_TO_INPUT_SHIFT,
+            &lstm_2_recurrent_input_to_input_w[0],
+            &input_hidden_kernel_sum[0],
+            &lstm_2_input_gate_bias[0],
+            ARM_SIGMOID};
+
+    // FORGET GATE
+    const cmsis_nn_lstm_gate gate_forget = {LSTM_2_IN_TO_FORGET_MULTIPLIER,
+            LSTM_2_IN_TO_FORGET_SHIFT,
+            &lstm_2_input_to_forget_w[0],
+            &forget_data_kernel_sum[0],
+            LSTM_2_RECURRENT_TO_FORGET_MULTIPLIER,
+            LSTM_2_RECURRENT_TO_FORGET_SHIFT,
+            &lstm_2_recurrent_input_to_forget_w[0],
+            &forget_hidden_kernel_sum[0],
+            &lstm_2_forget_gate_bias[0],
+            ARM_SIGMOID};
+
+    // CELL GATE
+    const cmsis_nn_lstm_gate gate_cell = {LSTM_2_IN_TO_CELL_MULTIPLIER,
+            LSTM_2_IN_TO_CELL_SHIFT,
+            &lstm_2_input_to_cell_w[0],
+            &cell_data_kernel_sum[0],
+            LSTM_2_RECURRENT_TO_CELL_MULTIPLIER,
+            LSTM_2_RECURRENT_TO_CELL_SHIFT,
+            &lstm_2_recurrent_input_to_cell_w[0],
+            &cell_hidden_kernel_sum[0],
+            &lstm_2_cell_gate_bias[0],
+            ARM_TANH};
+
+    // OUTPUT GATE
+    const cmsis_nn_lstm_gate gate_output = {LSTM_2_IN_TO_OUTPUT_MULTIPLIER,
+            LSTM_2_IN_TO_OUTPUT_SHIFT,
+            &lstm_2_input_to_output_w[0],
+            &output_data_kernel_sum[0],
+            LSTM_2_RECURRENT_TO_OUTPUT_MULTIPLIER,
+            LSTM_2_RECURRENT_TO_OUTPUT_SHIFT,
+            &lstm_2_recurrent_input_to_output_w[0],
+            &output_hidden_kernel_sum[0],
+            &lstm_2_output_gate_bias[0],
+            ARM_SIGMOID};
+
+    // LSTM DATA
+    const cmsis_nn_lstm_params params = {LSTM_2_TIME_MAJOR,
+            LSTM_2_INPUT_BATCHES,
+            LSTM_2_TIME_STEPS,
+            LSTM_2_NUMBER_INPUTS,
+            LSTM_2_NUMBER_UNITS,
+            LSTM_2_DATA_OFFSET,
+            LSTM_2_FORGET_MULTIPLIER,
+            LSTM_2_FORGET_SHIFT,
+            LSTM_2_INPUT_MULTIPLIER,
+            LSTM_2_INPUT_SHIFT,
+            LSTM_2_IN_ACTIVATION_MAX,
+            LSTM_2_CELL_STATE_SHIFT,
+            LSTM_2_HIDDEN_MULTIPLIER,
+            LSTM_2_HIDDEN_SHIFT,
+            LSTM_2_HIDDEN_OFFSET,
+            gate_forget,
+            gate_input,
+            gate_cell,
+            gate_output};
+
+    // BUFFERS
+    cmsis_nn_lstm_context buffers;
+    buffers.temp1 = buffer1;
+    buffers.temp2 = buffer2;
+    buffers.cell_state = buffer3;
+
+    int8_t output[LSTM_2_DST_SIZE] = {0};
+    arm_cmsis_nn_status result = arm_lstm_unidirectional_s8(lstm_2_input, output, &params, &buffers);
+
+    TEST_ASSERT_EQUAL(expected, result);
+    TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
+}
+
+void lstm_one_time_step_arm_lstm_unidirectional_s8(void)
+{
+    int8_t output[LSTM_ONE_TIME_STEP_DST_SIZE] = {0};
+    const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS;
+    const int8_t *output_ref = lstm_one_time_step_output_ref;
+    const int32_t output_ref_size = LSTM_ONE_TIME_STEP_DST_SIZE;
+
+    // Calculate kernel sums if using MVE-extension
+    int32_t input_data_kernel_sum[LSTM_ONE_TIME_STEP_NUMBER_UNITS];
+    int32_t forget_data_kernel_sum[LSTM_ONE_TIME_STEP_NUMBER_UNITS];
+    int32_t cell_data_kernel_sum[LSTM_ONE_TIME_STEP_NUMBER_UNITS];
+    int32_t output_data_kernel_sum[LSTM_ONE_TIME_STEP_NUMBER_UNITS];
+
+    int32_t input_hidden_kernel_sum[LSTM_ONE_TIME_STEP_NUMBER_UNITS];
+    int32_t
forget_hidden_kernel_sum[LSTM_ONE_TIME_STEP_NUMBER_UNITS]; + int32_t cell_hidden_kernel_sum[LSTM_ONE_TIME_STEP_NUMBER_UNITS]; + int32_t output_hidden_kernel_sum[LSTM_ONE_TIME_STEP_NUMBER_UNITS]; + + int32_t size_data = LSTM_ONE_TIME_STEP_NUMBER_INPUTS; + int32_t size_hidden = LSTM_ONE_TIME_STEP_NUMBER_UNITS; + + arm_vector_sum_s8(&input_data_kernel_sum[0], + size_data, + size_hidden, + &lstm_one_time_step_input_to_input_w[0], + LSTM_ONE_TIME_STEP_DATA_OFFSET, + &lstm_one_time_step_input_gate_bias[0]); + arm_vector_sum_s8(&forget_data_kernel_sum[0], + size_data, + size_hidden, + &lstm_one_time_step_input_to_forget_w[0], + LSTM_ONE_TIME_STEP_DATA_OFFSET, + &lstm_one_time_step_forget_gate_bias[0]); + arm_vector_sum_s8(&cell_data_kernel_sum[0], + size_data, + size_hidden, + &lstm_one_time_step_input_to_cell_w[0], + LSTM_ONE_TIME_STEP_DATA_OFFSET, + &lstm_one_time_step_cell_gate_bias[0]); + arm_vector_sum_s8(&output_data_kernel_sum[0], + size_data, + size_hidden, + &lstm_one_time_step_input_to_output_w[0], + LSTM_ONE_TIME_STEP_DATA_OFFSET, + &lstm_one_time_step_output_gate_bias[0]); + + arm_vector_sum_s8(&input_hidden_kernel_sum[0], + size_hidden, + size_hidden, + &lstm_one_time_step_recurrent_input_to_input_w[0], + -LSTM_ONE_TIME_STEP_HIDDEN_OFFSET, + NULL); + arm_vector_sum_s8(&forget_hidden_kernel_sum[0], + size_hidden, + size_hidden, + &lstm_one_time_step_recurrent_input_to_forget_w[0], + -LSTM_ONE_TIME_STEP_HIDDEN_OFFSET, + NULL); + arm_vector_sum_s8(&cell_hidden_kernel_sum[0], + size_hidden, + size_hidden, + &lstm_one_time_step_recurrent_input_to_cell_w[0], + -LSTM_ONE_TIME_STEP_HIDDEN_OFFSET, + NULL); + arm_vector_sum_s8(&output_hidden_kernel_sum[0], + size_hidden, + size_hidden, + &lstm_one_time_step_recurrent_input_to_output_w[0], + -LSTM_ONE_TIME_STEP_HIDDEN_OFFSET, + NULL); + + // INPUT GATE + const cmsis_nn_lstm_gate gate_input = {LSTM_ONE_TIME_STEP_IN_TO_INPUT_MULTIPLIER, + LSTM_ONE_TIME_STEP_IN_TO_INPUT_SHIFT, + &lstm_one_time_step_input_to_input_w[0], + input_data_kernel_sum, + LSTM_ONE_TIME_STEP_RECURRENT_TO_INPUT_MULTIPLIER, + LSTM_ONE_TIME_STEP_RECURRENT_TO_INPUT_SHIFT, + &lstm_one_time_step_recurrent_input_to_input_w[0], + input_hidden_kernel_sum, + &lstm_one_time_step_input_gate_bias[0], + ARM_SIGMOID}; + + // FORGET GATE + const cmsis_nn_lstm_gate gate_forget = {LSTM_ONE_TIME_STEP_IN_TO_FORGET_MULTIPLIER, + LSTM_ONE_TIME_STEP_IN_TO_FORGET_SHIFT, + &lstm_one_time_step_input_to_forget_w[0], + forget_data_kernel_sum, + LSTM_ONE_TIME_STEP_RECURRENT_TO_FORGET_MULTIPLIER, + LSTM_ONE_TIME_STEP_RECURRENT_TO_FORGET_SHIFT, + &lstm_one_time_step_recurrent_input_to_forget_w[0], + forget_hidden_kernel_sum, + &lstm_one_time_step_forget_gate_bias[0], + ARM_SIGMOID}; + + // CELL GATE + const cmsis_nn_lstm_gate gate_cell = {LSTM_ONE_TIME_STEP_IN_TO_CELL_MULTIPLIER, + LSTM_ONE_TIME_STEP_IN_TO_CELL_SHIFT, + &lstm_one_time_step_input_to_cell_w[0], + cell_data_kernel_sum, + LSTM_ONE_TIME_STEP_RECURRENT_TO_CELL_MULTIPLIER, + LSTM_ONE_TIME_STEP_RECURRENT_TO_CELL_SHIFT, + &lstm_one_time_step_recurrent_input_to_cell_w[0], + cell_hidden_kernel_sum, + &lstm_one_time_step_cell_gate_bias[0], + ARM_TANH}; + + // OUTPUT GATE + const cmsis_nn_lstm_gate gate_output = {LSTM_ONE_TIME_STEP_IN_TO_OUTPUT_MULTIPLIER, + LSTM_ONE_TIME_STEP_IN_TO_OUTPUT_SHIFT, + &lstm_one_time_step_input_to_output_w[0], + output_data_kernel_sum, + LSTM_ONE_TIME_STEP_RECURRENT_TO_OUTPUT_MULTIPLIER, + LSTM_ONE_TIME_STEP_RECURRENT_TO_OUTPUT_SHIFT, + &lstm_one_time_step_recurrent_input_to_output_w[0], + output_hidden_kernel_sum, + 
&lstm_one_time_step_output_gate_bias[0],
+            ARM_SIGMOID};
+
+    // LSTM DATA
+    const cmsis_nn_lstm_params params = {LSTM_ONE_TIME_STEP_TIME_MAJOR,
+            LSTM_ONE_TIME_STEP_INPUT_BATCHES,
+            LSTM_ONE_TIME_STEP_TIME_STEPS,
+            LSTM_ONE_TIME_STEP_NUMBER_INPUTS,
+            LSTM_ONE_TIME_STEP_NUMBER_UNITS,
+            LSTM_ONE_TIME_STEP_DATA_OFFSET,
+            LSTM_ONE_TIME_STEP_FORGET_MULTIPLIER,
+            LSTM_ONE_TIME_STEP_FORGET_SHIFT,
+            LSTM_ONE_TIME_STEP_INPUT_MULTIPLIER,
+            LSTM_ONE_TIME_STEP_INPUT_SHIFT,
+            LSTM_ONE_TIME_STEP_IN_ACTIVATION_MAX,
+            LSTM_ONE_TIME_STEP_CELL_STATE_SHIFT,
+            LSTM_ONE_TIME_STEP_HIDDEN_MULTIPLIER,
+            LSTM_ONE_TIME_STEP_HIDDEN_SHIFT,
+            LSTM_ONE_TIME_STEP_HIDDEN_OFFSET,
+            gate_forget,
+            gate_input,
+            gate_cell,
+            gate_output};
+
+    // BUFFERS
+    cmsis_nn_lstm_context buffers;
+    buffers.temp1 = buffer1;
+    buffers.temp2 = buffer2;
+    buffers.cell_state = buffer3;
+
+    arm_cmsis_nn_status result = arm_lstm_unidirectional_s8(lstm_one_time_step_input, output, &params, &buffers);
+
+    TEST_ASSERT_EQUAL(expected, result);
+    TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
+}
diff --git a/Tests/UnitTest/TestCases/test_arm_svdf_s8/test_arm_svdf_s8.c b/Tests/UnitTest/TestCases/test_arm_svdf_s8/test_arm_svdf_s8.c
index 43f7d26e..857e35ca 100644
--- a/Tests/UnitTest/TestCases/test_arm_svdf_s8/test_arm_svdf_s8.c
+++ b/Tests/UnitTest/TestCases/test_arm_svdf_s8/test_arm_svdf_s8.c
@@ -1,5 +1,5 @@
 /*
- * SPDX-FileCopyrightText: Copyright 2023 Arm Limited and/or its affiliates
+ * SPDX-FileCopyrightText: Copyright 2023-2024 Arm Limited and/or its affiliates
 *
 * SPDX-License-Identifier: Apache-2.0
 *
@@ -76,7 +76,7 @@ void svdf_int8_arm_svdf_s8(void)
 #if defined(ARM_MATH_MVEI)
     int32_t *kernel_sum_buf = ctx.buf;
-    arm_vector_sum_s8(kernel_sum_buf, input_dims.h, weights_feature_dims.n, weights_feature_data);
+    arm_vector_sum_s8(kernel_sum_buf, input_dims.h, weights_feature_dims.n, weights_feature_data, 1, NULL);
 #endif
     // + SVDF_INT8_TIME_BATCHES additional bytes to make sure it is not overwritten
@@ -191,7 +191,7 @@ void svdf_int8_2_arm_svdf_s8(void)
 #if defined(ARM_MATH_MVEI)
     int32_t *kernel_sum_buf = ctx.buf;
-    arm_vector_sum_s8(kernel_sum_buf, input_dims.h, weights_feature_dims.n, weights_feature_data);
+    arm_vector_sum_s8(kernel_sum_buf, input_dims.h, weights_feature_dims.n, weights_feature_data, 1, NULL);
 #endif
     const int state_data_size = sizeof(svdf_int8_2_state);
diff --git a/Tests/UnitTest/generate_test_data.py b/Tests/UnitTest/generate_test_data.py
index b82c53d7..c1e6ca6d 100755
--- a/Tests/UnitTest/generate_test_data.py
+++ b/Tests/UnitTest/generate_test_data.py
@@ -721,7 +721,8 @@ def load_testdata_sets(regenerate_input, regenerate_weights, regenerate_biases,
                        dilation_x=3,
                        dilation_y=3,
                        pad=True,
-                       interpreter=interpreter) dataset = 'basic_int4'
+                       interpreter=interpreter)
+    dataset = 'basic_int4'
     testdata_sets[dataset] = ConvSettings(dataset,
                        type_of_test,
                        regenerate_weights,
@@ -2848,7 +2849,7 @@ def load_testdata_sets(regenerate_input, regenerate_weights, regenerate_biases,
                        regenerate_input,
                        regenerate_biases,
                        schema_file,
-                       batches=2,
+                       batches=1,
                        time_steps=9,
                        number_inputs=6,
                        number_units=7,
diff --git a/Tests/UnitTest/lstm_settings.py b/Tests/UnitTest/lstm_settings.py
index 7846a571..4a0ce39a 100644
--- a/Tests/UnitTest/lstm_settings.py
+++ b/Tests/UnitTest/lstm_settings.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2010-2023 Arm Limited and/or its affiliates
+# SPDX-FileCopyrightText: Copyright 2010-2024 Arm Limited and/or its affiliates
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -252,6 +252,7
@@ def generate_data(self, input_data=None, weights=None, hidden_weights=None, bias self.generate_c_array("output_norm_coeff", interpreter.get_tensor(output_norm_coeff['index'])) input_scale = input_data_for_index['quantization_parameters']['scales'][0] + self.data_zp = input_data_for_index['quantization_parameters']['zero_points'][0] cell_scale = cell_state['quantization_parameters']['scales'][0] output_state_scale = output_state['quantization_parameters']['scales'][0] input_zp = input_data_for_index['quantization_parameters']['zero_points'][0] @@ -263,7 +264,7 @@ def generate_data(self, input_data=None, weights=None, hidden_weights=None, bias tmp = math.log(cell_scale) * (1 / math.log(2)) self.cell_state_shift = int(round(tmp)) - self.calc_scales(input_scale, output_state_scale) + self.calc_scales(input_scale, output_state_scale, cell_scale) # Calculate effective biases. input_zp = -input_zp @@ -293,14 +294,21 @@ def generate_data(self, input_data=None, weights=None, hidden_weights=None, bias self.generate_c_array("recurrent_to_output_eff_bias", recurrent_to_output_eff_bias, datatype='int32_t') # Generate reference - interpreter.invoke() - output_data = interpreter.get_tensor(output_details[0]["index"]) + if self.use_tflite_micro_interpreter: + interpreter = self.tflite_micro.runtime.Interpreter.from_file(model_path=str(self.model_path_tflite)) + interpreter.set_input(tf.cast(input_data, tf.int8), input_details[0]["index"]) + interpreter.invoke() + output_data = interpreter.get_output(0) + else: + interpreter.invoke() + output_data = interpreter.get_tensor(output_details[0]["index"]) + self.generate_c_array(self.output_data_file_prefix, output_data, datatype='int8_t') self.write_c_config_header() self.write_c_header_wrapper() - def calc_scales(self, input_scale, output_state_scale): + def calc_scales(self, input_scale, output_state_scale, cell_scale): intermediate_scale = pow(2, -12) if self.time_major: @@ -308,6 +316,9 @@ def calc_scales(self, input_scale, output_state_scale): else: time_major_offset = 0 + + self.effective_forget_scale = pow(2, -15) / cell_scale * cell_scale + self.effective_input_scale = pow(2, -15) / cell_scale * pow(2, -15) self.effective_hidden_scale = pow(2, -15) / output_state_scale * pow(2, -15) self.i2i_effective_scale = input_scale * self.lstm_scales[self.input_to_input_w_index + time_major_offset][0] \ @@ -393,11 +404,21 @@ def write_c_config_header(self) -> None: f.write("#define {}_RECURRENT_TO_OUTPUT_MULTIPLIER {}\n".format(prefix, multiplier)) f.write("#define {}_RECURRENT_TO_OUTPUT_SHIFT {}\n".format(prefix, shift)) + + (multiplier, shift) = self.quantize_scale(self.effective_forget_scale) + f.write("#define {}_FORGET_MULTIPLIER {}\n".format(prefix, multiplier)) + f.write("#define {}_FORGET_SHIFT {}\n".format(prefix, shift)) + + (multiplier, shift) = self.quantize_scale(self.effective_input_scale) + f.write("#define {}_INPUT_MULTIPLIER {}\n".format(prefix, multiplier)) + f.write("#define {}_INPUT_SHIFT {}\n".format(prefix, shift)) + (multiplier, shift) = self.quantize_scale(self.effective_hidden_scale) f.write("#define {}_HIDDEN_MULTIPLIER {}\n".format(prefix, multiplier)) f.write("#define {}_HIDDEN_SHIFT {}\n".format(prefix, shift)) f.write("#define {}_HIDDEN_OFFSET {}\n".format(prefix, self.hidden_zp)) + f.write("#define {}_DATA_OFFSET {}\n".format(prefix, -self.data_zp)) f.write("#define {}_OUTPUT_STATE_OFFSET {}\n".format(prefix, self.output_state_offset)) f.write("#define {}_CELL_STATE_SHIFT {}\n".format(prefix, self.cell_state_shift))
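
Note on the new FORGET/INPUT/HIDDEN scale macros: write_c_config_header() turns the effective scales added above into multiplier/shift pairs via quantize_scale(). The standalone Python sketch below reproduces that arithmetic under the assumption that quantize_scale() follows the usual TFLite convention (scale ~= multiplier * 2**(shift - 31)); the example scale values and the LSTM_X prefix are hypothetical, not taken from a generated test set.

# Sketch: effective LSTM scales -> quantized multiplier/shift pairs (assumed TFLite convention).
import math

def quantize_scale(scale):
    # scale = significand * 2**shift with 0.5 <= significand < 1, so
    # multiplier = significand * 2**31 lies in [2**30, 2**31).
    significand, shift = math.frexp(scale)
    multiplier = int(round(significand * (1 << 31)))
    if multiplier == (1 << 31):  # rounding pushed the significand to 1.0
        multiplier //= 2
        shift += 1
    return multiplier, shift

# Hypothetical per-tensor scales: sigmoid/tanh outputs use a fixed 2**-15 scale.
cell_scale = pow(2, -16)          # would give CELL_STATE_SHIFT = -16
output_state_scale = pow(2, -7)

# Same expressions as in lstm_settings.py above.
effective_forget_scale = pow(2, -15) / cell_scale * cell_scale            # forget gate * cell state
effective_input_scale = pow(2, -15) / cell_scale * pow(2, -15)            # input gate * cell gate
effective_hidden_scale = pow(2, -15) / output_state_scale * pow(2, -15)   # output gate * tanh(cell)

for name, scale in (("FORGET", effective_forget_scale),
                    ("INPUT", effective_input_scale),
                    ("HIDDEN", effective_hidden_scale)):
    multiplier, shift = quantize_scale(scale)
    print(f"#define LSTM_X_{name}_MULTIPLIER {multiplier}")
    print(f"#define LSTM_X_{name}_SHIFT {shift}")

Run on its own, this prints macro lines in the same form as the generated config_data.h headers.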
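The kernel-sum buffers that the unit tests fill with arm_vector_sum_s8 before calling arm_lstm_unidirectional_s8 fold the constant offset contribution of each vector-by-matrix product into a per-unit effective bias, which the MVE path can then consume directly. The NumPy sketch below illustrates that precomputation; it assumes arm_vector_sum_s8 accumulates offset times the row sum of the weight matrix plus the optional bias, and the shapes and values are made up for illustration.

# Sketch: per-unit effective bias as precomputed by the LSTM unit tests (assumed semantics).
import numpy as np

def vector_sum_s8(weights, offset, bias=None):
    # weights laid out [hidden_size, input_size]; one sum per output unit.
    sums = offset * weights.astype(np.int32).sum(axis=1)
    return sums if bias is None else sums + bias.astype(np.int32)

# Hypothetical small test set: 7 units, 6 inputs.
rng = np.random.default_rng(0)
input_to_forget_w = rng.integers(-128, 127, size=(7, 6), dtype=np.int8)
forget_gate_bias = rng.integers(-1000, 1000, size=7, dtype=np.int32)

# Input path uses DATA_OFFSET and the gate bias; the recurrent path uses
# -HIDDEN_OFFSET and no bias, mirroring the arm_vector_sum_s8 calls in the tests.
forget_data_kernel_sum = vector_sum_s8(input_to_forget_w, offset=128, bias=forget_gate_bias)
forget_hidden_kernel_sum = vector_sum_s8(rng.integers(-128, 127, size=(7, 7), dtype=np.int8), offset=-(-21))
print(forget_data_kernel_sum, forget_hidden_kernel_sum)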