From 601d96c63a4e1437d9a90c0f917684405e60d8a1 Mon Sep 17 00:00:00 2001
From: Adrian Lundell <36153706+AdrianLundell@users.noreply.github.com>
Date: Wed, 7 Feb 2024 16:09:56 +0100
Subject: [PATCH] Reimplement arm_lstm_unidirectional_s8 (#102)
- API changes
- Optimized for scalar, DSP and MVE
- Bit exact to TFLM reference kernel
- Less scratch-buffer usage
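
A minimal usage sketch of the new API (illustrative only: the scratch sizes for
temp1/temp2 are assumptions, cell_state sizing follows the memset in
arm_lstm_unidirectional_s8.c, and all cmsis_nn_lstm_params fields are described
in arm_nn_types.h):

    #include "arm_nnfunctions.h"

    #define BATCH   1 /* placeholder sizes for the sketch */
    #define HID_DIM 4

    static int16_t temp1[BATCH * HID_DIM]; /* gate scratch, assumed sizing */
    static int16_t temp2[BATCH * HID_DIM]; /* gate scratch, assumed sizing */
    static int16_t cell[BATCH * HID_DIM];  /* zeroed inside the kernel     */

    void run_lstm(const int8_t *input, int8_t *output, const cmsis_nn_lstm_params *params)
    {
        cmsis_nn_lstm_context buffers = {.temp1 = temp1, .temp2 = temp2, .cell_state = cell};
        arm_cmsis_nn_status status = arm_lstm_unidirectional_s8(input, output, params, &buffers);
        (void)status; /* compare against ARM_CMSIS_NN_SUCCESS in real code */
    }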
---
ARM.CMSIS-NN.pdsc | 9 +-
Include/arm_nn_types.h | 134 ++---
Include/arm_nnfunctions.h | 101 ++--
Include/arm_nnsupportfunctions.h | 215 ++++----
.../arm_nn_activation_s16.c | 38 +-
.../arm_elementwise_mul_acc_s16.c | 167 ++++++
.../arm_elementwise_mul_s16_s8.c | 100 ++--
.../arm_vector_sum_s8.c | 97 ++--
Source/LSTMFunctions/CMakeLists.txt | 6 +-
.../arm_lstm_unidirectional_s8.c | 80 +++
.../arm_lstm_unidirectional_s8_s16.c | 184 -------
.../arm_nn_lstm_calculate_gate_s8_s16.c | 80 ++-
.../NNSupportFunctions/arm_nn_lstm_step_s8.c | 110 ++++
.../arm_nn_lstm_step_s8_s16.c | 154 ------
.../arm_nn_lstm_update_cell_state_s16.c | 124 -----
.../arm_nn_lstm_update_output_s8_s16.c | 81 ---
...=> arm_nn_vec_mat_mul_result_acc_s8_s16.c} | 220 ++++----
Tests/UnitTest/CMakeLists.txt | 2 +-
Tests/UnitTest/README.md | 16 +-
.../TestData/lstm_1/cell_gate_bias_data.h | 6 +-
.../TestData/lstm_1/cell_norm_coeff_data.h | 4 +-
.../TestData/lstm_1/cell_state_data.h | 4 +-
.../TestData/lstm_1/cell_to_forget_data.h | 4 +-
.../TestData/lstm_1/cell_to_input_data.h | 4 +-
.../TestData/lstm_1/cell_to_output_data.h | 4 +-
.../TestCases/TestData/lstm_1/config_data.h | 31 +-
.../TestData/lstm_1/forget_gate_bias_data.h | 6 +-
.../TestData/lstm_1/forget_norm_coeff_data.h | 4 +-
.../TestCases/TestData/lstm_1/input_data.h | 26 +-
.../TestData/lstm_1/input_gate_bias_data.h | 7 +-
.../TestData/lstm_1/input_norm_coeff_data.h | 4 +-
.../lstm_1/input_to_cell_eff_bias_data.h | 6 +-
.../TestData/lstm_1/input_to_cell_w_data.h | 29 +-
.../lstm_1/input_to_forget_eff_bias_data.h | 6 +-
.../TestData/lstm_1/input_to_forget_w_data.h | 30 +-
.../lstm_1/input_to_input_eff_bias_data.h | 6 +-
.../TestData/lstm_1/input_to_input_w_data.h | 30 +-
.../lstm_1/input_to_output_eff_bias_data.h | 6 +-
.../TestData/lstm_1/input_to_output_w_data.h | 30 +-
.../TestData/lstm_1/output_gate_bias_data.h | 6 +-
.../TestData/lstm_1/output_norm_coeff_data.h | 4 +-
.../TestData/lstm_1/output_ref_data.h | 14 +-
.../TestData/lstm_1/output_state_data.h | 6 +-
.../TestData/lstm_1/projection_bias_data.h | 4 +-
.../TestData/lstm_1/projection_weights_data.h | 4 +-
.../lstm_1/recurrent_input_to_cell_w_data.h | 17 +-
.../lstm_1/recurrent_input_to_forget_w_data.h | 16 +-
.../lstm_1/recurrent_input_to_input_w_data.h | 16 +-
.../lstm_1/recurrent_input_to_output_w_data.h | 17 +-
.../lstm_1/recurrent_to_cell_eff_bias_data.h | 6 +-
.../recurrent_to_forget_eff_bias_data.h | 6 +-
.../lstm_1/recurrent_to_input_eff_bias_data.h | 6 +-
.../recurrent_to_output_eff_bias_data.h | 6 +-
.../TestCases/TestData/lstm_1/test_data.h | 4 +-
.../TestData/lstm_2/cell_gate_bias_data.h | 6 +-
.../TestData/lstm_2/cell_norm_coeff_data.h | 4 +-
.../TestData/lstm_2/cell_state_data.h | 6 +-
.../TestData/lstm_2/cell_to_forget_data.h | 4 +-
.../TestData/lstm_2/cell_to_input_data.h | 4 +-
.../TestData/lstm_2/cell_to_output_data.h | 4 +-
.../TestCases/TestData/lstm_2/config_data.h | 43 +-
.../TestData/lstm_2/forget_gate_bias_data.h | 6 +-
.../TestData/lstm_2/forget_norm_coeff_data.h | 4 +-
.../TestCases/TestData/lstm_2/input_data.h | 15 +-
.../TestData/lstm_2/input_gate_bias_data.h | 6 +-
.../TestData/lstm_2/input_norm_coeff_data.h | 4 +-
.../lstm_2/input_to_cell_eff_bias_data.h | 6 +-
.../TestData/lstm_2/input_to_cell_w_data.h | 10 +-
.../lstm_2/input_to_forget_eff_bias_data.h | 6 +-
.../TestData/lstm_2/input_to_forget_w_data.h | 10 +-
.../lstm_2/input_to_input_eff_bias_data.h | 6 +-
.../TestData/lstm_2/input_to_input_w_data.h | 10 +-
.../lstm_2/input_to_output_eff_bias_data.h | 6 +-
.../TestData/lstm_2/input_to_output_w_data.h | 10 +-
.../TestData/lstm_2/output_gate_bias_data.h | 6 +-
.../TestData/lstm_2/output_norm_coeff_data.h | 4 +-
.../TestData/lstm_2/output_ref_data.h | 15 +-
.../TestData/lstm_2/output_state_data.h | 6 +-
.../TestData/lstm_2/projection_bias_data.h | 4 +-
.../TestData/lstm_2/projection_weights_data.h | 4 +-
.../lstm_2/recurrent_input_to_cell_w_data.h | 12 +-
.../lstm_2/recurrent_input_to_forget_w_data.h | 10 +-
.../lstm_2/recurrent_input_to_input_w_data.h | 10 +-
.../lstm_2/recurrent_input_to_output_w_data.h | 10 +-
.../lstm_2/recurrent_to_cell_eff_bias_data.h | 6 +-
.../recurrent_to_forget_eff_bias_data.h | 6 +-
.../lstm_2/recurrent_to_input_eff_bias_data.h | 6 +-
.../recurrent_to_output_eff_bias_data.h | 6 +-
.../TestCases/TestData/lstm_2/test_data.h | 4 +-
.../lstm_one_time_step/cell_gate_bias_data.h | 6 +-
.../lstm_one_time_step/cell_norm_coeff_data.h | 4 +-
.../lstm_one_time_step/cell_state_data.h | 4 +-
.../lstm_one_time_step/cell_to_forget_data.h | 4 +-
.../lstm_one_time_step/cell_to_input_data.h | 4 +-
.../lstm_one_time_step/cell_to_output_data.h | 4 +-
.../TestData/lstm_one_time_step/config_data.h | 37 +-
.../forget_gate_bias_data.h | 6 +-
.../forget_norm_coeff_data.h | 4 +-
.../TestData/lstm_one_time_step/input_data.h | 11 +-
.../lstm_one_time_step/input_gate_bias_data.h | 6 +-
.../input_norm_coeff_data.h | 4 +-
.../input_to_cell_eff_bias_data.h | 6 +-
.../lstm_one_time_step/input_to_cell_w_data.h | 10 +-
.../input_to_forget_eff_bias_data.h | 6 +-
.../input_to_forget_w_data.h | 10 +-
.../input_to_input_eff_bias_data.h | 6 +-
.../input_to_input_w_data.h | 10 +-
.../input_to_output_eff_bias_data.h | 6 +-
.../input_to_output_w_data.h | 11 +-
.../output_gate_bias_data.h | 6 +-
.../output_norm_coeff_data.h | 4 +-
.../lstm_one_time_step/output_ref_data.h | 6 +-
.../lstm_one_time_step/output_state_data.h | 6 +-
.../lstm_one_time_step/projection_bias_data.h | 4 +-
.../projection_weights_data.h | 4 +-
.../recurrent_input_to_cell_w_data.h | 6 +-
.../recurrent_input_to_forget_w_data.h | 6 +-
.../recurrent_input_to_input_w_data.h | 6 +-
.../recurrent_input_to_output_w_data.h | 6 +-
.../recurrent_to_cell_eff_bias_data.h | 6 +-
.../recurrent_to_forget_eff_bias_data.h | 6 +-
.../recurrent_to_input_eff_bias_data.h | 6 +-
.../recurrent_to_output_eff_bias_data.h | 6 +-
.../TestData/lstm_one_time_step/test_data.h | 4 +-
.../test_arm_ds_cnn_l_s8.c | 2 +-
.../test_arm_fully_connected_s8.c | 12 +-
.../test_arm_lstm_unidirectional_s16_s8.c | 328 ------------
.../CMakeLists.txt | 10 +-
.../unity_test_arm_lstm_unidirectional_s8.c} | 13 +-
.../test_arm_lstm_unidirectional_s8.c | 495 ++++++++++++++++++
.../test_arm_svdf_s8/test_arm_svdf_s8.c | 6 +-
Tests/UnitTest/generate_test_data.py | 5 +-
Tests/UnitTest/lstm_settings.py | 31 +-
133 files changed, 1830 insertions(+), 1889 deletions(-)
create mode 100644 Source/BasicMathFunctions/arm_elementwise_mul_acc_s16.c
create mode 100644 Source/LSTMFunctions/arm_lstm_unidirectional_s8.c
delete mode 100644 Source/LSTMFunctions/arm_lstm_unidirectional_s8_s16.c
create mode 100644 Source/NNSupportFunctions/arm_nn_lstm_step_s8.c
delete mode 100644 Source/NNSupportFunctions/arm_nn_lstm_step_s8_s16.c
delete mode 100644 Source/NNSupportFunctions/arm_nn_lstm_update_cell_state_s16.c
delete mode 100644 Source/NNSupportFunctions/arm_nn_lstm_update_output_s8_s16.c
rename Source/NNSupportFunctions/{arm_nn_vec_mat_mul_result_acc_s8.c => arm_nn_vec_mat_mul_result_acc_s8_s16.c} (60%)
delete mode 100644 Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s16_s8/test_arm_lstm_unidirectional_s16_s8.c
rename Tests/UnitTest/TestCases/{test_arm_lstm_unidirectional_s16_s8 => test_arm_lstm_unidirectional_s8}/CMakeLists.txt (67%)
rename Tests/UnitTest/TestCases/{test_arm_lstm_unidirectional_s16_s8/Unity/unity_test_arm_lstm_unidirectional_s16_s8.c => test_arm_lstm_unidirectional_s8/Unity/unity_test_arm_lstm_unidirectional_s8.c} (69%)
create mode 100644 Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s8/test_arm_lstm_unidirectional_s8.c
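
The per-gate effective bias consumed by the new kernels can be produced with the
extended arm_vector_sum_s8(); a minimal sketch, assuming illustrative sizes and
assuming params->input_offset is the lhs_offset to fold in:

    #include "arm_nnfunctions.h"

    #define IN_DIM  8 /* placeholder sizes for the sketch */
    #define HID_DIM 4

    static const int8_t  i2f_weights[HID_DIM * IN_DIM]; /* input-to-forget weights    */
    static const int32_t forget_bias[HID_DIM];          /* per-gate bias, may be NULL */
    static int32_t       i2f_effective_bias[HID_DIM];

    void setup_forget_gate(cmsis_nn_lstm_params *params)
    {
        /* effective bias = forget_bias + lhs_offset * row_sum(i2f_weights) */
        arm_vector_sum_s8(i2f_effective_bias, IN_DIM, HID_DIM, i2f_weights,
                          params->input_offset, forget_bias);

        params->forget_gate.input_weights        = i2f_weights;
        params->forget_gate.input_effective_bias = i2f_effective_bias;
        params->forget_gate.activation_type      = ARM_SIGMOID; /* multipliers/shifts omitted */
    }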
diff --git a/ARM.CMSIS-NN.pdsc b/ARM.CMSIS-NN.pdsc
index f216d8cc..03d3783d 100644
--- a/ARM.CMSIS-NN.pdsc
+++ b/ARM.CMSIS-NN.pdsc
@@ -83,6 +83,7 @@
+
@@ -107,18 +108,16 @@
-
-
-
+
-
+
-
+
diff --git a/Include/arm_nn_types.h b/Include/arm_nn_types.h
index 257442da..c567f0c1 100644
--- a/Include/arm_nn_types.h
+++ b/Include/arm_nn_types.h
@@ -1,5 +1,5 @@
/*
- * SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates
+ * SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates
*
* SPDX-License-Identifier: Apache-2.0
*
@@ -22,8 +22,8 @@
* Description: Public header file to contain the CMSIS-NN structs for the
* TensorFlowLite micro compliant functions
*
- * $Date: 9 January 2024
- * $Revision: V.2.6.2
+ * $Date: 19 January 2024
+ * $Revision: V.3.0.0
*
* Target : Arm(R) M-Profile Architecture
* -------------------------------------------------------------------- */
@@ -40,7 +40,6 @@
* @{
*/
-
/** Enum for specifying activation function types */
typedef enum
{
@@ -180,31 +179,6 @@ typedef struct
const int16_t *one_by_one_lut;
} cmsis_nn_softmax_lut_s16;
-/** LSTM guard parameters */
-typedef struct
-{
- int32_t input_variance;
- int32_t forget_variance;
- int32_t cell_variance;
- int32_t output_variance;
-} cmsis_nn_lstm_guard_params;
-
-/** LSTM scratch buffer container */
-typedef struct
-{
- int16_t *input_gate;
- int16_t *forget_gate;
- int16_t *cell_gate;
- int16_t *output_gate;
-} cmsis_nn_lstm_context;
-
-/** Quantized clip value for cell and projection of LSTM input. Zero value means no clipping. */
-typedef struct
-{
- int16_t cell;
- int8_t projection;
-} cmsis_nn_lstm_clip_params;
-
/** CMSIS-NN object for quantization parameters */
typedef struct
{
@@ -212,70 +186,60 @@ typedef struct
int32_t shift; /**< Shift value */
} cmsis_nn_scaling;
-/** CMSIS-NN norm layer coefficients */
+/** CMSIS-NN object for LSTM gate parameters*/
typedef struct
{
- int16_t *input_weight;
- int16_t *forget_weight;
- int16_t *cell_weight;
- int16_t *output_weight;
-} cmsis_nn_layer_norm;
+ int32_t input_multiplier;
+ int32_t input_shift;
+ const int8_t *input_weights;
+ const int32_t *input_effective_bias; /**< Bias added with precomputed kernel_sum * lhs_offset*/
-/** Parameters for integer LSTM, as defined in TFLM */
+ int32_t hidden_multiplier;
+ int32_t hidden_shift;
+ const int8_t *hidden_weights;
+ const int32_t *hidden_effective_bias; /**< Precomputed kernel_sum * lhs_offset*/
+
+ const int32_t *bias;
+ arm_nn_activation_type activation_type;
+} cmsis_nn_lstm_gate;
+
+/** CMSIS-NN object for LSTM parameters*/
typedef struct
{
- int32_t time_major; /**< Nonzero (true) if first row of data is timestamps for input */
- cmsis_nn_scaling input_to_input_scaling;
- cmsis_nn_scaling input_to_forget_scaling;
- cmsis_nn_scaling input_to_cell_scaling;
- cmsis_nn_scaling input_to_output_scaling;
- cmsis_nn_scaling recurrent_to_input_scaling;
- cmsis_nn_scaling recurrent_to_forget_scaling;
- cmsis_nn_scaling recurrent_to_cell_scaling;
- cmsis_nn_scaling recurrent_to_output_scaling;
- cmsis_nn_scaling cell_to_input_scaling;
- cmsis_nn_scaling cell_to_forget_scaling;
- cmsis_nn_scaling cell_to_output_scaling;
- cmsis_nn_scaling projection_scaling;
- cmsis_nn_scaling hidden_scaling;
- cmsis_nn_scaling layer_norm_input_scaling; /**< layer normalization for input layer */
- cmsis_nn_scaling layer_norm_forget_scaling; /**< layer normalization for forget gate */
- cmsis_nn_scaling layer_norm_cell_scaling; /**< layer normalization for cell */
- cmsis_nn_scaling layer_norm_output_scaling; /**< layer normalization for outpus layer */
-
- int32_t cell_state_shift;
- int32_t hidden_offset;
- int32_t output_state_offset;
-
- cmsis_nn_lstm_clip_params clip;
- cmsis_nn_lstm_guard_params guard;
- cmsis_nn_layer_norm layer_norm;
-
- /* Effective bias is precalculated as bias + zero_point * weight.
- Only applicable to when input/output are s8 and weights are s16 */
- const int32_t *i2i_effective_bias; /**< input to input effective bias */
- const int32_t *i2f_effective_bias; /**< input to forget gate effective bias */
- const int32_t *i2c_effective_bias; /**< input to cell effective bias */
- const int32_t *i2o_effective_bias; /**< input to output effective bias */
-
- const int32_t *r2i_effective_bias; /**< recurrent gate to input effective bias */
- const int32_t *r2f_effective_bias; /**< recurrent gate to forget gate effective bias */
- const int32_t *r2c_effective_bias; /**< recurrent gate to cell effective bias */
- const int32_t *r2o_effective_bias; /**< recurrent gate to output effective bias */
-
- const int32_t *projection_effective_bias;
-
- /* Not precalculated bias */
- const int32_t *input_gate_bias;
- const int32_t *forget_gate_bias;
- const int32_t *cell_gate_bias;
- const int32_t *output_gate_bias;
-
- /* Activation min and max */
- cmsis_nn_activation activation;
+ int32_t time_major; /**< 0 if first dimension is batch, else first dimension is time */
+ int32_t batch_size;
+ int32_t time_steps;
+ int32_t input_size; /**< Size of new data input into the LSTM cell*/
+ int32_t
+ hidden_size; /**< Size of output from the LSTM cell, used as output and recursively into the next time step*/
+
+ int32_t input_offset;
+
+ int32_t forget_to_cell_multiplier;
+ int32_t forget_to_cell_shift;
+ int32_t input_to_cell_multiplier;
+ int32_t input_to_cell_shift;
+ int32_t cell_clip; /**< Min/max value of cell output*/
+ int32_t cell_scale_power;
+ int32_t output_multiplier;
+ int32_t output_shift;
+ int32_t output_offset;
+
+ cmsis_nn_lstm_gate forget_gate;
+ cmsis_nn_lstm_gate input_gate;
+ cmsis_nn_lstm_gate cell_gate;
+ cmsis_nn_lstm_gate output_gate;
} cmsis_nn_lstm_params;
+/** CMSIS-NN object for LSTM scratch buffers*/
+typedef struct
+{
+ void *temp1;
+ void *temp2;
+ void *cell_state;
+} cmsis_nn_lstm_context;
+
/**
* @} // end group genPubTypes
*/
diff --git a/Include/arm_nnfunctions.h b/Include/arm_nnfunctions.h
index ade3417c..70a106d9 100644
--- a/Include/arm_nnfunctions.h
+++ b/Include/arm_nnfunctions.h
@@ -21,8 +21,9 @@
* Title: arm_nnfunctions.h
* Description: Public header file for CMSIS NN Library
*
- * $Date: 11 January 2024
- * $Revision: V.12.6.0
+ * $Date: 19 January 2024
+ * $Revision: V.13.0.0
+
*
* Target : Arm(R) M-Profile Architecture
* -------------------------------------------------------------------- */
@@ -1514,11 +1515,13 @@ arm_cmsis_nn_status arm_fully_connected_s8(const cmsis_nn_context *ctx,
int8_t *output_data);
/**
- * @brief Calculate vector sums that may be required by arm_fully_connected_s8().
+ * @brief Calculate the sum of each row in vector_data, multiply by lhs_offset and optionally add bias_data.
* @param[in, out] vector_sum_buf Buffer for vector sums
* @param[in] vector_cols Number of vector columns
* @param[in] vector_rows Number of vector rows
- * @param[in] vector_data Vector or weigths data
+ * @param[in] vector_data Vector of weights data
+ * @param[in] lhs_offset Constant multiplied with each sum
+ * @param[in] bias_data Vector of bias data, added to each sum.
* @return The function returns
* ARM_CMSIS_NN_SUCCESS
- Successful operation
* ARM_CMSIS_NN_ARG_ERROR
- If not for Arm(R) Helium Architecture case.
@@ -1526,7 +1529,9 @@ arm_cmsis_nn_status arm_fully_connected_s8(const cmsis_nn_context *ctx,
arm_cmsis_nn_status arm_vector_sum_s8(int32_t *vector_sum_buf,
const int32_t vector_cols,
const int32_t vector_rows,
- const int8_t *vector_data);
+ const int8_t *vector_data,
+ const int32_t lhs_offset,
+ const int32_t *bias_data);
/**
* @brief Get size of additional buffer required by arm_fully_connected_s8().
@@ -1802,21 +1807,23 @@ void arm_relu_q15(int16_t *data, uint16_t size);
/**
* @brief s16 neural network activation function using direct table look-up
- * @param[in] input pointer to input data
+ * @param[in] input pointer to input data
* @param[out] output pointer to output
* @param[in] size number of elements
- * @param[in] left_shift bit-width of the integer part, assume to be smaller than 3
+ * @param[in] left_shift bit-width of the integer part, assumed to be smaller than 3.
* @param[in] type type of activation functions
+ * @return The function returns ARM_CMSIS_NN_SUCCESS
+
*
* @details Supported framework: TensorFlow Lite for Microcontrollers.
- * This activation function must be bit precise congruent with the corresponding TFLM tanh and sigmoid actication
+ * This activation function must be bit precise congruent with the corresponding TFLM tanh and sigmoid activation
* functions
*/
-void arm_nn_activation_s16(const int16_t *input,
- int16_t *output,
- const uint16_t size,
- const uint16_t left_shift,
- const arm_nn_activation_type type);
+arm_cmsis_nn_status arm_nn_activation_s16(const int16_t *input,
+ int16_t *output,
+ const int32_t size,
+ const int32_t left_shift,
+ const arm_nn_activation_type type);
/**
* @defgroup Pooling Pooling Functions
@@ -2441,67 +2448,25 @@ arm_cmsis_nn_status arm_svdf_state_s16_s8(const cmsis_nn_context *input_ctx,
*/
/**
- * @brief LSTM unidirectional function with 8 bit input and output and 16 bit gate output
- * Peephole connections, projection, clipping, combined input/forget gate and layer normalization are not supported.
- *
- * @param[in] scratch_buffers Struct containing scratch buffers
- * Expected size for each scratch buffer is
- * lstm_dims->num_batches * lstm_dims->num_outputs.
- * @param[in] input_data Pointer to input data
- * @param[in] lstm_dims LSTM input parameters related to dimensions
- * @param[in] input_to_input_weights Input to input weights
- * @param[in] input_to_forget_weights Input to forget weights
- * @param[in] input_to_cell_weights Input to cell weights
- * @param[in] input_to_output_weights Input to output weights
- * @param[in] recurrent_to_input_weights Recurrent to input weights
- * @param[in] recurrent_to_forget_weights Recurrent to forget weights
- * @param[in] recurrent_to_cell_weights Recurrent to cell weights
- * @param[in] recurrent_to_output_weights Recurrent to output weights
- * @param[in] cell_to_input_weights Cell to input weights. Not used.
- * @param[in] cell_to_forget_weights Cell to forget weights. Not used.
- * @param[in] cell_to_output_weights Cell to output weights. Not used.
- * @param[in] projection_weights Projection weights. Not used.
- * @param[in] lstm LSTM parameters. See struct declaration
- * @param[in] output_state Pointer to (recurrent) output state
- * @param[in] cell_state Pointer to cell state
- * @param[in] output_data Pointer to output state
- *
- * @note Following assumptions are done based on LSTM functionality as supported by
- * Keras version 2.9.0 at the time of development. As stated here,
- * https://github.com/tensorflow/community/blob/master/rfcs/20180920-unify-rnn-interface.md
- * Keras's LSTMCell is equivalent to TensorFlow's BasicLSTMCell,
- * which does not support peephole, clipping or projection.
- * Layer normalization and combined input/forget gate are not supported either.
- *
- * 1 Input to input weight can not be nullptr. Otherwise nullptr for combined input/forgat gate.
- * 2 Cell weights are not used and should be nullptr. Otherwise needed for peephole connections.
- * 3 Projection weight is not used and should be nullpr. Otherwise needed for projection.
+ * @brief LSTM unidirectional function with 8 bit input and output and 16 bit gate output.
+ *
+ * @param[in] input Pointer to input data
+ * @param[out] output Pointer to output data
+ * @param[in] params Struct containing all information about the lstm operator, see arm_nn_types.
+ * @param[in] buffers Struct containing pointers to all temporary scratch buffers needed for the
+ * lstm operator, see arm_nn_types.
+ *
*
* @return The function returns ARM_CMSIS_NN_SUCCESS
*
* @details
- * 1. Supported framework: TensorFlow Lite micro
+ * 1. Supported framework: TensorFlow Lite Micro
*
*/
-arm_cmsis_nn_status arm_lstm_unidirectional_s16_s8(cmsis_nn_lstm_context *scratch_buffers,
- const int8_t *input_data,
- const cmsis_nn_lstm_dims *lstm_dims,
- const int8_t *input_to_input_weights,
- const int8_t *input_to_forget_weights,
- const int8_t *input_to_cell_weights,
- const int8_t *input_to_output_weights,
- const int8_t *recurrent_to_input_weights,
- const int8_t *recurrent_to_forget_weights,
- const int8_t *recurrent_to_cell_weights,
- const int8_t *recurrent_to_output_weights,
- const int16_t *cell_to_input_weights,
- const int16_t *cell_to_forget_weights,
- const int16_t *cell_to_output_weights,
- const int8_t *projection_weights,
- const cmsis_nn_lstm_params *lstm,
- int8_t *output_state,
- int16_t *cell_state,
- int8_t *output_data);
+arm_cmsis_nn_status arm_lstm_unidirectional_s8(const int8_t *input,
+ int8_t *output,
+ const cmsis_nn_lstm_params *params,
+ cmsis_nn_lstm_context *buffers);
/**
* @brief Get size of additional buffer required by arm_svdf_s8().
diff --git a/Include/arm_nnsupportfunctions.h b/Include/arm_nnsupportfunctions.h
index 20cbfd38..f2393db0 100644
--- a/Include/arm_nnsupportfunctions.h
+++ b/Include/arm_nnsupportfunctions.h
@@ -21,8 +21,8 @@
* Title: arm_nnsupportfunctions.h
* Description: Public header file of support functions for CMSIS NN Library
*
- * $Date: 11 January 2024
- * $Revision: V.17.7.0
+ * $Date: 19 January 2024
+ * $Revision: V.18.0.0
*
* Target : Arm(R) M-Profile Architecture
* -------------------------------------------------------------------- */
@@ -1453,142 +1453,77 @@ __STATIC_FORCEINLINE void arm_nn_write_s8x2_ia(int8_t **dst, int16_t src)
/**
* @brief Update LSTM function for an iteration step
*
- * param[in] input Input data
- * param[in] input_to_input_weight Input to input gate weights
- * param[in] input_to_forget_weight Input to forget gate weights
- * param[in] input_to_cell_weight Input to cell gate weights
- * param[in] input_to_output_weight Input to output weights
- * param[in] recurrent_to_input_weight Recurrent signal to input weights
- * param[in] recurrent_to_forget_weight Recurrent signal to forget gate weights
- * param[in] recurrent_to_cell_weight Recurrent signal to cell gate weighst
- * param[in] recurrent_to_output_weight Recurrent signal to output weights
- * param[in] lstm LSTM parameters
- * param[in] n_batch Batch size
- * param[in] n_cell Cell size
- * param[in] n_input Input size
- * param[in] n_output Output size
- * param[out] output_state Output state
- * param[out] cell_state Internal state
- * param[out] output Output signal
- * param[in] *scratch_buffers Struct containing scratch buffers
- */
-arm_cmsis_nn_status arm_nn_lstm_step_s8_s16(const int8_t *input,
- const int8_t *input_to_input_weight,
- const int8_t *input_to_forget_weight,
- const int8_t *input_to_cell_weight,
- const int8_t *input_to_output_weight,
- const int8_t *recurrent_to_input_weight,
- const int8_t *recurrent_to_forget_weight,
- const int8_t *recurrent_to_cell_weight,
- const int8_t *recurrent_to_output_weight,
- const cmsis_nn_lstm_params *lstm,
- const int n_batch,
- const int n_cell,
- const int n_input,
- const int n_output,
- int8_t *output_state,
- int16_t *cell_state,
- int8_t *output,
- cmsis_nn_lstm_context *scratch_buffers);
+ * @param[in] data_in Data input pointer
+ * @param[in] hidden_in Hidden state/ recurrent input pointer
+ * @param[out] hidden_out Hidden state/ recurrent output pointer
+ * @param[in] params Struct containing all information about the lstm operator, see
+ * arm_nn_types.
+ * @param[in] buffers Struct containing pointers to all temporary scratch buffers needed for the
+ * lstm operator, see arm_nn_types.
+ * @param[in] batch_offset Number of timesteps between consecutive batches.
+ * E.g. for params->time_major = true, all batches for t=0 are stored sequentially, so batch offset = 1.
+ * For params->time_major = false, all time steps are stored contiguously before the next batch, so
+ * batch offset = params->time_steps.
+ * @return The function returns ARM_CMSIS_NN_SUCCESS
-/**
- * @brief Updates a LSTM gate for an iteration step of LSTM function, int8x8_16 version.
- *
- * param[in] input Input data
- * param[in] input_to_gate_weights Input to gate weights
- * param[in] input_to_gate_bias Input to gate weights
- * param[in] input_to_gate_scaling Input to gate scaling
- * param[in] activation Actival min and max values
- * param[in] output_state Output state
- * param[in] recurrent_to_gate_weights Recurrent to gate weights
- * param[in] recurrent_to_gate_bias Recurrent to gate bias
- * param[in] recurrent_to_gate_scaling Recurrent to gate scaling
- * param[in] n_batch Batch size
- * param[in] n_input Input size
- * param[out] n_output Output size
- * param[in] activation_type Activation type (sigmoid or tanh)
- * param[out] n_cell Cell size
- */
-void arm_nn_lstm_calculate_gate_s8_s16(const int8_t *input,
- const int8_t *input_to_gate_weights,
- const int32_t *input_to_gate_bias,
- const cmsis_nn_scaling input_to_gate_scaling,
- const int8_t *output_state,
- const int8_t *recurrent_to_gate_weights,
- const int32_t *recurrent_to_gate_bias,
- const cmsis_nn_scaling recurrent_to_gate_scaling,
- const int32_t n_batch,
- const int32_t n_input,
- const int32_t n_output,
- const int32_t n_cell,
- const arm_nn_activation_type activation_type,
- int16_t *gate);
-
-/**
- * @brief Update cell state for a single LSTM iteration step, int8x8_16 version.
- * @param[in] n_block total number of cells for all batches
- * @param[in] cell_state_scale Scaling factor of cell state
- * @param[in] cell_state Input/output vector, size n_batch*n_cell
- * @param[in] input_gate Input vector of size n_block
- * @param[in] forget_gate Input/scratch vector of size n_block, always modified
- * @param[in] cell_gate Input vector of size, n_block
*/
-void arm_nn_lstm_update_cell_state_s16(const int32_t n_block,
- const int32_t cell_state_scale,
- int16_t *cell_state,
- const int16_t *input_gate,
- const int16_t *forget_gate,
- const int16_t *cell_gate);
+arm_cmsis_nn_status arm_nn_lstm_step_s8(const int8_t *data_in,
+ const int8_t *hidden_in,
+ int8_t *hidden_out,
+ const cmsis_nn_lstm_params *params,
+ cmsis_nn_lstm_context *buffers,
+ const int32_t batch_offset);
/**
- * @brief Calculate the output state tensor of an LSTM step, s8 input/output and s16 weight version.
- *
- * @param[in] n_batch The number of distinct vectors in each array
- * @param[in] n_cell Number of cells
- * @param[in,out] cell_state Cell state, size n_batch*n_cell
- * @param[in] cell_state_scale Scaling of cell_state
- * @param[in] output_gate Output gate
- * @param[in] hidden_scale Effective scaling of cell_state .* output_gate
- * @param[in] hidden_offset Zero point for cell_state .* output_gate
- * @param[out] output_state Output state
- * @param[in] cell_gate_scratch Scratch buffer
+ * @brief Updates a LSTM gate for an iteration step of LSTM function, int8x8_16 version.
+ *
+ * @param[in] data_in Data input pointer
+ * @param[in] hidden_in Hidden state/ recurrent input pointer
+ * @param[in] gate_data Struct containing all information about the gate calculation, see
+ * arm_nn_types.
+ * @param[in] params Struct containing all information about the lstm_operation, see
+ * arm_nn_types
+ * @param[out] output Hidden state/ recurrent output pointer
+ * @param[in] batch_offset Number of timesteps between consecutive batches, see
+ * arm_nn_lstm_step_s8.
+ * @return The function returns ARM_CMSIS_NN_SUCCESS
*/
-void arm_nn_lstm_update_output_s8_s16(const int n_batch,
- const int n_cell,
- int16_t *cell_state,
- const int32_t cell_state_scale,
- const int16_t *output_gate,
- const cmsis_nn_scaling hidden_scale,
- const int32_t hidden_offset,
- int8_t *output_state,
- int16_t *cell_gate_scratch);
+arm_cmsis_nn_status arm_nn_lstm_calculate_gate_s8_s16(const int8_t *data_in,
+ const int8_t *hidden_in,
+ const cmsis_nn_lstm_gate *gate_data,
+ const cmsis_nn_lstm_params *params,
+ int16_t *output,
+ const int32_t batch_offset);
/**
* @brief The result of the multiplication is accumulated to the passed result buffer.
* Multiplies a matrix by a "batched" vector (i.e. a matrix with a batch dimension composed by input vectors independent
* from each other).
*
- * @param[in] lhs_in Batched vector
- * @param[in] rhs_in Weights - input matrix (H(Rows)xW(Columns))
- * @param[in] bias Bias vector
+ * @param[in] lhs Batched vector
+ * @param[in] rhs Weights - input matrix (H(Rows)xW(Columns))
+ * @param[in] effective_bias Bias + lhs_offset * kernel_sum term precalculated into a constant vector.
* @param[out] dst Output
- * @param[in] dst_offset Output offset
* @param[in] dst_multiplier Multiplier for quantization
* @param[in] dst_shift Shift for quantization
* @param[in] rhs_cols Vector/matarix column length
* @param[in] rhs_rows Row count of matrix
- * @param[in] batch Batch size
+ * @param[in] batches Batch size
+ * @param[in] batch_offset Number of timesteps between consecutive batches in input, see arm_nn_lstm_step_s8. Note
+ * that the output is always stored with sequential batches.
+ * @return The function returns ARM_CMSIS_NN_SUCCESS
+
*/
-void arm_nn_vec_mat_mul_result_acc_s8(const int8_t *lhs_in,
- const int8_t *rhs_in,
- const int32_t *bias,
- int16_t *dst,
- const int32_t dst_offset,
- const int32_t dst_multiplier,
- const int32_t dst_shift,
- const int32_t rhs_cols,
- const int32_t rhs_rows,
- const int32_t batch);
+arm_cmsis_nn_status arm_nn_vec_mat_mul_result_acc_s8_s16(const int8_t *lhs,
+ const int8_t *rhs,
+ const int32_t *effective_bias,
+ int16_t *dst,
+ const int32_t dst_multiplier,
+ const int32_t dst_shift,
+ const int32_t rhs_cols,
+ const int32_t rhs_rows,
+ const int32_t batches,
+ const int32_t batch_offset);
/**
* @brief s16 elementwise multiplication with s8 output
@@ -1598,7 +1533,10 @@ void arm_nn_vec_mat_mul_result_acc_s8(const int8_t *lhs_in,
* @param[in] out_offset output offset
* @param[in] out_mult output multiplier
* @param[in] out_shift output shift
- * @param[in] block_size number of samples
+ * @param[in] block_size number of samples per batch
+ * @param[in] batch_size number of batches
+ * @param[in] batch_offset Number of timesteps between consecutive batches in output, see
+ * arm_nn_lstm_step_s8. Note that it is assumed that the input is stored with sequential batches.
* @return The function returns ARM_CMSIS_NN_SUCCESS
*
* @details Supported framework: TensorFlow Lite micro
@@ -1609,7 +1547,38 @@ arm_cmsis_nn_status arm_elementwise_mul_s16_s8(const int16_t *input_1_vect,
const int32_t out_offset,
const int32_t out_mult,
const int32_t out_shift,
- const int32_t block_size);
+ const int32_t block_size,
+ const int32_t batch_size,
+ const int32_t batch_offset);
+
+/**
+ * @brief s16 elementwise multiplication. The result of the multiplication is accumulated to the passed result buffer.
+ * @param[in] input_1_vect pointer to input vector 1
+ * @param[in] input_2_vect pointer to input vector 2
+ * @param[in] input_1_offset offset for input 1. Not used.
+ * @param[in] input_2_offset offset for input 2. Not used.
+ * @param[in,out] output pointer to output vector
+ * @param[in] out_offset output offset. Not used.
+ * @param[in] out_mult output multiplier
+ * @param[in] out_shift output shift
+ * @param[in] out_activation_min minimum value to clamp output to. Min: -32768
+ * @param[in] out_activation_max maximum value to clamp output to. Max: 32767
+ * @param[in] block_size number of samples
+ * @return The function returns ARM_CMSIS_NN_SUCCESS
+ *
+ * @details Supported framework: TensorFlow Lite micro
+ */
+arm_cmsis_nn_status arm_elementwise_mul_acc_s16(const int16_t *input_1_vect,
+ const int16_t *input_2_vect,
+ const int32_t input_1_offset,
+ const int32_t input_2_offset,
+ int16_t *output,
+ const int32_t out_offset,
+ const int32_t out_mult,
+ const int32_t out_shift,
+ const int32_t out_activation_min,
+ const int32_t out_activation_max,
+ const int32_t block_size);
#ifdef __cplusplus
}
diff --git a/Source/ActivationFunctions/arm_nn_activation_s16.c b/Source/ActivationFunctions/arm_nn_activation_s16.c
index 99c6dbf6..bb05ae4b 100644
--- a/Source/ActivationFunctions/arm_nn_activation_s16.c
+++ b/Source/ActivationFunctions/arm_nn_activation_s16.c
@@ -1,5 +1,6 @@
/*
- * SPDX-FileCopyrightText: Copyright 2010-2020, 2022 Arm Limited and/or its affiliates
+ * SPDX-FileCopyrightText: Copyright 2010-2020, 2022, 2024 Arm Limited and/or its affiliates
+ *
*
* SPDX-License-Identifier: Apache-2.0
*
@@ -18,11 +19,11 @@
/* ----------------------------------------------------------------------
* Project: CMSIS NN Library
- * Title: arm_nn_activations_q15.c
+ * Title: arm_nn_activation_s16.c
* Description: Q15 neural network activation function using direct table look-up
*
- * $Date: 8 September 2022
- * $Revision: V.1.0.0
+ * $Date: 19 January 2024
+ * $Revision: V.2.0.0
*
* Target Processor: Cortex-M cores
*
@@ -47,11 +48,11 @@
*
*/
-void arm_nn_activation_s16(const int16_t *input,
- int16_t *output,
- const uint16_t size,
- const uint16_t left_shift,
- const arm_nn_activation_type type)
+arm_cmsis_nn_status arm_nn_activation_s16(const int16_t *input,
+ int16_t *output,
+ const int32_t size,
+ const int32_t left_shift,
+ const arm_nn_activation_type type)
{
uint32_t abs_input_shift, max_saturation;
switch (type)
@@ -67,18 +68,17 @@ void arm_nn_activation_s16(const int16_t *input,
break;
}
+ const int32_t input_multiplier = (left_shift < 0) ? 3 : 3 << left_shift;
+ const int32_t abs_left_shift = (left_shift < 0) ? -left_shift : 0;
+ const int32_t rounding = (abs_left_shift > 0) ? 1 << (abs_left_shift - 1) : 0;
// Use the LUT for sigmoid and take into account, that
// tanh(x) = 2*sigmoid(2*x) - 1
- int32_t input_multiplier = ((int32_t)3) << left_shift;
for (int i = 0; i < size; ++i, input++, output++)
{
- int32_t input_data = ((*input) * input_multiplier);
-
- uint32_t abs_input_data = input_data > 0 ? input_data : -input_data;
-
- uint32_t uh = abs_input_data >> abs_input_shift;
-
+ const int32_t input_data = ((*input) * input_multiplier + rounding) >> abs_left_shift;
+ const uint32_t abs_input_data = input_data > 0 ? input_data : -input_data;
+ const uint32_t uh = abs_input_data >> abs_input_shift;
uint32_t result;
if (uh >= 255)
@@ -87,8 +87,8 @@ void arm_nn_activation_s16(const int16_t *input,
}
else
{
- uint32_t ua = sigmoid_table_uint16[uh];
- uint32_t ub = sigmoid_table_uint16[uh + 1];
+ const uint32_t ua = sigmoid_table_uint16[uh];
+ const uint32_t ub = sigmoid_table_uint16[uh + 1];
uint32_t ut;
if (type == ARM_SIGMOID)
{
@@ -112,6 +112,8 @@ void arm_nn_activation_s16(const int16_t *input,
}
*output = (int16_t)result;
}
+
+ return ARM_CMSIS_NN_SUCCESS;
}
/**
diff --git a/Source/BasicMathFunctions/arm_elementwise_mul_acc_s16.c b/Source/BasicMathFunctions/arm_elementwise_mul_acc_s16.c
new file mode 100644
index 00000000..28277922
--- /dev/null
+++ b/Source/BasicMathFunctions/arm_elementwise_mul_acc_s16.c
@@ -0,0 +1,167 @@
+/*
+ * SPDX-FileCopyrightText: Copyright 2022, 2024 Arm Limited and/or its affiliates
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the License); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* ----------------------------------------------------------------------
+ * Project: CMSIS NN Library
+ * Title: arm_elementwise_mul_acc_s16
+ * Description: Accumulative element wise multiplication
+ *
+ * $Date: 19 January 2024
+ * $Revision: V.1.0.0
+ *
+ * Target : Arm(R) M-Profile Architecture
+ *
+ * -------------------------------------------------------------------- */
+
+#include "arm_nnfunctions.h"
+#include "arm_nnsupportfunctions.h"
+
+/**
+ * @ingroup Public
+ */
+
+/**
+ * @addtogroup groupElementwise
+ * @{
+ */
+
+/**
+ * @brief s16 element wise accumulative multiplication of two vectors
+ *
+ * @note Refer header file for details.
+ *
+ */
+arm_cmsis_nn_status arm_elementwise_mul_acc_s16(const int16_t *input_1_vect,
+ const int16_t *input_2_vect,
+ const int32_t input_1_offset,
+ const int32_t input_2_offset,
+ int16_t *output,
+ const int32_t out_offset,
+ const int32_t out_mult,
+ const int32_t out_shift,
+ const int32_t out_activation_min,
+ const int32_t out_activation_max,
+ const int32_t block_size)
+{
+ (void)input_1_offset;
+ (void)input_2_offset;
+ (void)out_offset;
+ int32_t loop_count;
+
+ const int32_t activation_max = (out_activation_max > 0) ? out_activation_max : NN_Q15_MAX;
+ const int32_t activation_min = (out_activation_max > 0) ? out_activation_min : NN_Q15_MIN;
+
+#if defined(ARM_MATH_MVEI)
+
+ loop_count = block_size;
+
+ while (loop_count > 0)
+ {
+ mve_pred16_t pred = vctp32q(loop_count);
+
+ int32x4_t input_1 = vldrhq_z_s32(input_1_vect, pred);
+ int32x4_t input_2 = vldrhq_z_s32(input_2_vect, pred);
+
+ int32x4_t res_0 = vmulq_s32(input_1, input_2);
+
+ res_0 = arm_requantize_mve_32x4(res_0, vdupq_n_s32(out_mult), vdupq_n_s32(out_shift));
+
+ res_0 = vaddq_s32(res_0, vldrhq_z_s32(output, pred));
+
+ res_0 = vmaxq_s32(res_0, vdupq_n_s32(activation_min));
+ res_0 = vminq_s32(res_0, vdupq_n_s32(activation_max));
+
+ vstrhq_p_s32(output, res_0, pred);
+ input_1_vect += 4;
+ input_2_vect += 4;
+
+ output += 4;
+ loop_count -= 4;
+ }
+
+#else
+ int32_t input_1;
+ int32_t input_2;
+ int32_t mul_res;
+ int32_t two_halfword_1, two_halfword_2;
+ int16_t mul_1, mul_2;
+ loop_count = block_size / 2;
+
+ while (loop_count > 0)
+ {
+ two_halfword_1 = arm_nn_read_q15x2_ia(&input_1_vect);
+ two_halfword_2 = arm_nn_read_q15x2_ia(&input_2_vect);
+
+ #if defined(ARM_MATH_DSP)
+ mul_res = SMULBB(two_halfword_1, two_halfword_2);
+ #else
+ input_1 = (int16_t)(two_halfword_1 & 0xFFFF);
+ input_2 = (int16_t)(two_halfword_2 & 0xFFFF);
+ mul_res = input_1 * input_2;
+ #endif
+ mul_res = arm_nn_requantize(mul_res, out_mult, out_shift);
+ mul_res += output[0];
+
+ mul_res = MAX(mul_res, activation_min);
+ mul_res = MIN(mul_res, activation_max);
+ mul_1 = (int16_t)mul_res;
+
+ #if defined(ARM_MATH_DSP)
+ mul_res = SMULTT(two_halfword_1, two_halfword_2);
+ #else
+ input_1 = (int16_t)(two_halfword_1 >> 16);
+ input_2 = (int16_t)(two_halfword_2 >> 16);
+ mul_res = input_1 * input_2;
+ #endif
+ mul_res = arm_nn_requantize(mul_res, out_mult, out_shift);
+ mul_res += output[1];
+ mul_res = MAX(mul_res, activation_min);
+ mul_res = MIN(mul_res, activation_max);
+ mul_2 = (int16_t)mul_res;
+
+ arm_nn_write_q15x2_ia(&output, PACK_Q15x2_32x1(mul_1, mul_2));
+
+ loop_count--;
+ }
+ loop_count = block_size & 0x1;
+
+ while (loop_count > 0)
+ {
+
+ input_1 = *input_1_vect++;
+ input_2 = *input_2_vect++;
+
+ mul_res = input_1 * input_2;
+
+ mul_res = arm_nn_requantize(mul_res, out_mult, out_shift);
+ mul_res += output[0];
+
+ mul_res = MAX(mul_res, activation_min);
+ mul_res = MIN(mul_res, activation_max);
+
+ *output++ = (int16_t)mul_res;
+
+ loop_count--;
+ }
+#endif // #if defined(ARM_MATH_MVEI)
+ return ARM_CMSIS_NN_SUCCESS;
+}
+
+/**
+ * @} end of Doxygen group
+ */
diff --git a/Source/BasicMathFunctions/arm_elementwise_mul_s16_s8.c b/Source/BasicMathFunctions/arm_elementwise_mul_s16_s8.c
index 94f19dfd..f53c2e5b 100644
--- a/Source/BasicMathFunctions/arm_elementwise_mul_s16_s8.c
+++ b/Source/BasicMathFunctions/arm_elementwise_mul_s16_s8.c
@@ -1,5 +1,5 @@
/*
- * SPDX-FileCopyrightText: Copyright 2022-2023 Arm Limited and/or its affiliates
+ * SPDX-FileCopyrightText: Copyright 2022-2024 Arm Limited and/or its affiliates
*
* SPDX-License-Identifier: Apache-2.0
*
@@ -22,7 +22,7 @@
* Description: Elementwise multiplication of 16 bit input with 8 bit output
*
* $Date: 20 January 2023
- * $Revision: V.1.2.0
+ * $Revision: V.2.0.0
*
* Target : Arm(R) M-Profile Architecture
*
@@ -51,70 +51,84 @@ arm_cmsis_nn_status arm_elementwise_mul_s16_s8(const int16_t *input_1_vect,
const int32_t out_offset,
const int32_t out_mult,
const int32_t out_shift,
- const int32_t block_size)
+ const int32_t block_size,
+ const int32_t batch_size,
+ const int32_t batch_offset)
{
- int32_t loop_count = block_size;
+ for (int i = 0; i < batch_size; i++)
+ {
+ int32_t loop_count = block_size;
#if defined(ARM_MATH_MVEI)
- while (loop_count > 0)
- {
- mve_pred16_t pred = vctp32q(loop_count);
+ const int16_t *input_1_ptr = input_1_vect;
+ const int16_t *input_2_ptr = input_2_vect;
+ int8_t *output_ptr = output;
- int32x4_t input_1 = vldrhq_z_s32(input_1_vect, pred);
- int32x4_t input_2 = vldrhq_z_s32(input_2_vect, pred);
+ while (loop_count > 0)
+ {
+ mve_pred16_t pred = vctp32q(loop_count);
- int32x4_t res_0 = vmulq_s32(input_1, input_2);
+ int32x4_t input_1 = vldrhq_z_s32(input_1_ptr, pred);
+ int32x4_t input_2 = vldrhq_z_s32(input_2_ptr, pred);
- res_0 = arm_requantize_mve_32x4(res_0, vdupq_n_s32(out_mult), vdupq_n_s32(out_shift));
- res_0 = vaddq_n_s32(res_0, out_offset);
+ int32x4_t res_0 = vmulq_s32(input_1, input_2);
- res_0 = vmaxq_s32(res_0, vdupq_n_s32(NN_Q7_MIN));
- res_0 = vminq_s32(res_0, vdupq_n_s32(NN_Q7_MAX));
+ res_0 = arm_requantize_mve_32x4(res_0, vdupq_n_s32(out_mult), vdupq_n_s32(out_shift));
+ res_0 = vaddq_n_s32(res_0, out_offset);
- vstrbq_p_s32(output, res_0, pred);
- input_1_vect += 4;
- input_2_vect += 4;
+ res_0 = vmaxq_s32(res_0, vdupq_n_s32(NN_Q7_MIN));
+ res_0 = vminq_s32(res_0, vdupq_n_s32(NN_Q7_MAX));
- output += 4;
- loop_count -= 4;
- }
+ vstrbq_p_s32(output_ptr, res_0, pred);
+ input_1_ptr += 4;
+ input_2_ptr += 4;
+
+ output_ptr += 4;
+ loop_count -= 4;
+ }
+
+ input_1_vect += block_size;
+ input_2_vect += block_size;
+ output += block_size;
#else
#if defined(ARM_MATH_DSP)
- while (loop_count > 1)
- {
- int32_t input_1 = arm_nn_read_q15x2_ia(&input_1_vect);
- int32_t input_2 = arm_nn_read_q15x2_ia(&input_2_vect);
+ while (loop_count > 1)
+ {
+ int32_t input_1 = arm_nn_read_q15x2_ia(&input_1_vect);
+ int32_t input_2 = arm_nn_read_q15x2_ia(&input_2_vect);
- int32_t mul_res = SMULBB(input_1, input_2);
- mul_res = arm_nn_requantize(mul_res, out_mult, out_shift) + out_offset;
- mul_res = CLAMP(mul_res, NN_Q7_MAX, NN_Q7_MIN);
- int32_t mul = (int16_t)(mul_res & 0xFF);
+ int32_t mul_res = SMULBB(input_1, input_2);
+ mul_res = arm_nn_requantize(mul_res, out_mult, out_shift) + out_offset;
+ mul_res = CLAMP(mul_res, NN_Q7_MAX, NN_Q7_MIN);
+ int32_t mul = (int16_t)(mul_res & 0xFF);
- mul_res = SMULTT(input_1, input_2);
- mul_res = arm_nn_requantize(mul_res, out_mult, out_shift) + out_offset;
- mul_res = CLAMP(mul_res, NN_Q7_MAX, NN_Q7_MIN);
- mul |= (int16_t)mul_res << 8;
+ mul_res = SMULTT(input_1, input_2);
+ mul_res = arm_nn_requantize(mul_res, out_mult, out_shift) + out_offset;
+ mul_res = CLAMP(mul_res, NN_Q7_MAX, NN_Q7_MIN);
+ mul |= (int16_t)mul_res << 8;
- arm_nn_write_s8x2_ia(&output, mul);
- loop_count -= 2;
- }
+ arm_nn_write_s8x2_ia(&output, mul);
+ loop_count -= 2;
+ }
#endif
- for (int i = 0; i < loop_count; i++)
- {
- /* C = A * B */
- int32_t mul_res = input_1_vect[i] * input_2_vect[i];
- mul_res = arm_nn_requantize(mul_res, out_mult, out_shift) + out_offset;
+ for (int j = 0; j < loop_count; j++, input_1_vect++, input_2_vect++, output++)
+ {
+ /* C = A * B */
+ int32_t mul_res = (*input_1_vect) * (*input_2_vect);
+ mul_res = arm_nn_requantize(mul_res, out_mult, out_shift) + out_offset;
- mul_res = CLAMP(mul_res, NN_Q7_MAX, NN_Q7_MIN);
+ mul_res = CLAMP(mul_res, NN_Q7_MAX, NN_Q7_MIN);
- output[i] = (int8_t)mul_res;
- }
+ *output = (int8_t)mul_res;
+ }
#endif
+ output += (batch_offset - 1) * block_size;
+ }
return ARM_CMSIS_NN_SUCCESS;
}
/**
diff --git a/Source/FullyConnectedFunctions/arm_vector_sum_s8.c b/Source/FullyConnectedFunctions/arm_vector_sum_s8.c
index 0120ba1d..a2564da5 100644
--- a/Source/FullyConnectedFunctions/arm_vector_sum_s8.c
+++ b/Source/FullyConnectedFunctions/arm_vector_sum_s8.c
@@ -1,5 +1,5 @@
/*
- * SPDX-FileCopyrightText: Copyright 2023 Arm Limited and/or its affiliates
+ * SPDX-FileCopyrightText: Copyright 2023-2024 Arm Limited and/or its affiliates
*
* SPDX-License-Identifier: Apache-2.0
*
@@ -21,8 +21,8 @@
* Title: arm_vector_sum_s8
* Description: Generic function for calculating vector sums
*
- * $Date: 5 September 2023
- * $Revision: V.1.0.0
+ * $Date: 26 January 2024
+ * $Revision: V.2.0.0
*
* Target : Arm(R) M-Profile Architecture
*
@@ -49,93 +49,122 @@
arm_cmsis_nn_status arm_vector_sum_s8(int32_t *vector_sum_buf,
const int32_t vector_cols,
const int32_t vector_rows,
- const int8_t *vector_data)
+ const int8_t *vector_data,
+ const int32_t lhs_offset,
+ const int32_t *bias_data)
{
#if defined(ARM_MATH_MVEI)
- const int32_t row_loop_cnt = vector_rows / 4;
-
+ const int32_t row_loop_cnt = vector_rows / 5;
for (int i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
{
const int32_t col_loop_cnt = (vector_cols + 15) / 16;
-
const int8_t *vector_0 = vector_data;
const int8_t *vector_1 = vector_data + vector_cols;
const int8_t *vector_2 = vector_data + 2 * vector_cols;
const int8_t *vector_3 = vector_data + 3 * vector_cols;
-
+ const int8_t *vector_4 = vector_data + 4 * vector_cols;
int32_t vector_sum_0 = 0;
int32_t vector_sum_1 = 0;
int32_t vector_sum_2 = 0;
int32_t vector_sum_3 = 0;
-
+ int32_t vector_sum_4 = 0;
uint32_t col_cnt = (uint32_t)vector_cols;
-
for (int i = 0; i < col_loop_cnt; i++)
{
mve_pred16_t p = vctp8q(col_cnt);
col_cnt -= 16;
-
const int8x16_t ker_0 = vldrbq_z_s8(vector_0, p);
vector_sum_0 = vaddvaq_s8(vector_sum_0, ker_0);
-
const int8x16_t ker_1 = vldrbq_z_s8(vector_1, p);
vector_sum_1 = vaddvaq_s8(vector_sum_1, ker_1);
-
const int8x16_t ker_2 = vldrbq_z_s8(vector_2, p);
vector_sum_2 = vaddvaq_s8(vector_sum_2, ker_2);
-
const int8x16_t ker_3 = vldrbq_z_s8(vector_3, p);
vector_sum_3 = vaddvaq_s8(vector_sum_3, ker_3);
-
+ const int8x16_t ker_4 = vldrbq_z_s8(vector_4, p);
+ vector_sum_4 = vaddvaq_s8(vector_sum_4, ker_4);
vector_0 += 16;
vector_1 += 16;
vector_2 += 16;
vector_3 += 16;
+ vector_4 += 16;
+ }
+ vector_data += 5 * vector_cols;
+ if (lhs_offset)
+ {
+ vector_sum_0 *= lhs_offset;
+ vector_sum_1 *= lhs_offset;
+ vector_sum_2 *= lhs_offset;
+ vector_sum_3 *= lhs_offset;
+ vector_sum_4 *= lhs_offset;
+ }
+ if (bias_data)
+ {
+ vector_sum_0 += *bias_data++;
+ vector_sum_1 += *bias_data++;
+ vector_sum_2 += *bias_data++;
+ vector_sum_3 += *bias_data++;
+ vector_sum_4 += *bias_data++;
}
- vector_data += 4 * vector_cols;
-
vector_sum_buf[0] = vector_sum_0;
vector_sum_buf[1] = vector_sum_1;
vector_sum_buf[2] = vector_sum_2;
vector_sum_buf[3] = vector_sum_3;
- vector_sum_buf += 4;
+ vector_sum_buf[4] = vector_sum_4;
+ vector_sum_buf += 5;
}
-
- const int32_t loop_cnt = vector_rows % 4;
-
+ const int32_t loop_cnt = vector_rows % 5;
for (int i_row_loop_cnt = 0; i_row_loop_cnt < loop_cnt; i_row_loop_cnt++)
{
const int32_t col_loop_cnt = (vector_cols + 15) / 16;
-
const int8_t *vector_0 = vector_data;
-
int32_t vector_sum_0 = 0;
-
uint32_t col_cnt = (uint32_t)vector_cols;
-
for (int i = 0; i < col_loop_cnt; i++)
{
mve_pred16_t p = vctp8q(col_cnt);
col_cnt -= 16;
-
const int8x16_t ker_0 = vldrbq_z_s8(vector_0, p);
vector_sum_0 = vaddvaq_s8(vector_sum_0, ker_0);
-
vector_0 += 16;
}
vector_data += vector_cols;
-
+ if (lhs_offset)
+ {
+ vector_sum_0 *= lhs_offset;
+ }
+ if (bias_data)
+ {
+ vector_sum_0 += *bias_data++;
+ }
vector_sum_buf[i_row_loop_cnt] = vector_sum_0;
}
-
return (ARM_CMSIS_NN_SUCCESS);
#else
- (void)vector_sum_buf;
- (void)vector_rows;
- (void)vector_cols;
- (void)vector_data;
- return (ARM_CMSIS_NN_NO_IMPL_ERROR);
+ if (bias_data)
+ {
+ memcpy(vector_sum_buf, bias_data, vector_rows * sizeof(int32_t));
+ }
+ else
+ {
+ memset(vector_sum_buf, 0, vector_rows * sizeof(int32_t));
+ }
+
+ if (lhs_offset)
+ {
+ for (int i = 0; i < vector_rows; i++)
+ {
+ int32_t sum = 0;
+ for (int j = 0; j < vector_cols; j++)
+ {
+ sum += *vector_data++;
+ }
+ *vector_sum_buf++ += sum * lhs_offset;
+ }
+ }
+ return (ARM_CMSIS_NN_SUCCESS);
+
#endif
}
diff --git a/Source/LSTMFunctions/CMakeLists.txt b/Source/LSTMFunctions/CMakeLists.txt
index e20d0963..eed27265 100644
--- a/Source/LSTMFunctions/CMakeLists.txt
+++ b/Source/LSTMFunctions/CMakeLists.txt
@@ -1,5 +1,5 @@
#
-# SPDX-FileCopyrightText: Copyright 2019-2022 Arm Limited and/or its affiliates
+# SPDX-FileCopyrightText: Copyright 2019-2022, 2024 Arm Limited and/or its affiliates
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -16,5 +16,5 @@
# limitations under the License.
#
-file(GLOB SRC "./*_s16.c")
-target_sources(cmsis-nn PRIVATE ${SRC})
+file(GLOB SRC_S8 "./*_s8.c")
+target_sources(cmsis-nn PRIVATE ${SRC_S8})
diff --git a/Source/LSTMFunctions/arm_lstm_unidirectional_s8.c b/Source/LSTMFunctions/arm_lstm_unidirectional_s8.c
new file mode 100644
index 00000000..c2868b3a
--- /dev/null
+++ b/Source/LSTMFunctions/arm_lstm_unidirectional_s8.c
@@ -0,0 +1,80 @@
+/*
+ * SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the License); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* ----------------------------------------------------------------------
+ * Project: CMSIS NN Library
+ * Title: arm_lstm_unidirectional_s8.c
+ * Description: S8 LSTM function with S16 gate output
+ *
+ * $Date: 19 January 2024
+ * $Revision: V.1.0.0
+ *
+ * Target Processor: Cortex-M processors
+ *
+ * -------------------------------------------------------------------- */
+
+#include "arm_nnfunctions.h"
+#include "arm_nnsupportfunctions.h"
+/**
+ * @ingroup Public
+ */
+
+/**
+ * @addtogroup LSTM
+ * @{
+ */
+
+/*
+ * S8 LSTM function for TensorFlow Lite with S16 gate output
+ *
+ * Refer to header file for details.
+ *
+ */
+
+arm_cmsis_nn_status arm_lstm_unidirectional_s8(const int8_t *input,
+ int8_t *output,
+ const cmsis_nn_lstm_params *params,
+ cmsis_nn_lstm_context *buffers)
+{
+
+ int8_t *hidden_in = NULL;
+ memset(buffers->cell_state, 0, params->batch_size * params->hidden_size * sizeof(int16_t));
+ const int32_t batch_offset = (params->time_major) ? 1 : params->time_steps;
+
+ for (int t = 0; t < params->time_steps; t++)
+ {
+ const int8_t *data_in = input + (t * params->batch_size * params->input_size);
+ int8_t *hidden_out = output + (t * params->batch_size * params->hidden_size);
+
+ arm_cmsis_nn_status status = arm_nn_lstm_step_s8(data_in, hidden_in, hidden_out, params, buffers, batch_offset);
+
+ if (status != ARM_CMSIS_NN_SUCCESS)
+ {
+ return status;
+ }
+
+ // Output is used as recurrent input/hidden state for the next timestep.
+ hidden_in = &hidden_out[0];
+ }
+
+ return ARM_CMSIS_NN_SUCCESS;
+}
+
+/**
+ * @} end of LSTM group
+ */
diff --git a/Source/LSTMFunctions/arm_lstm_unidirectional_s8_s16.c b/Source/LSTMFunctions/arm_lstm_unidirectional_s8_s16.c
deleted file mode 100644
index 6f658869..00000000
--- a/Source/LSTMFunctions/arm_lstm_unidirectional_s8_s16.c
+++ /dev/null
@@ -1,184 +0,0 @@
-/*
- * SPDX-FileCopyrightText: Copyright 2022 Arm Limited and/or its affiliates
- *
- * SPDX-License-Identifier: Apache-2.0
- *
- * Licensed under the Apache License, Version 2.0 (the License); you may
- * not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/* ----------------------------------------------------------------------
- * Project: CMSIS NN Library
- * Title: arm_lstm_unidirectional_s16_s8.c
- * Description: S8 LSTM function with S16 gate output
- *
- * $Date: 4 November 2022
- * $Revision: V.1.0.0
- *
- * Target Processor: Cortex-M processors
- *
- * -------------------------------------------------------------------- */
-
-#include "arm_nnfunctions.h"
-#include "arm_nnsupportfunctions.h"
-
-/**
- * @ingroup Public
- */
-
-/**
- * @addtogroup LSTM
- * @{
- */
-
-/*
- * S8 LSTM function for TensorFlow Lite with S16 gate output
- *
- * Refer to header file for details.
- *
- */
-
-#include "arm_nnfunctions.h"
-#include "arm_nnsupportfunctions.h"
-
-/*
- * LSTM unidirectional function with 8 bit input and output and 16 bit weights
- *
- * Refer header file for details.
- *
- */
-arm_cmsis_nn_status arm_lstm_unidirectional_s16_s8(cmsis_nn_lstm_context *scratch_buffers,
- const int8_t *input_data,
- const cmsis_nn_lstm_dims *lstm_dims,
- const int8_t *in_to_in_weights,
- const int8_t *in_to_forget_weights,
- const int8_t *in_to_cell_weights,
- const int8_t *in_to_out_weights,
- const int8_t *recurrent_to_in_weights,
- const int8_t *recurrent_to_forget_weights,
- const int8_t *recurrent_to_cell_weights,
- const int8_t *recurrent_to_out_weights,
- const int16_t *cell_to_in_weights,
- const int16_t *cell_to_forget_weights,
- const int16_t *cell_to_out_weights,
- const int8_t *projection_weights,
- const cmsis_nn_lstm_params *lstm,
- int8_t *output_state,
- int16_t *cell_state,
- int8_t *output_data)
-{
- (void)cell_to_in_weights;
- (void)cell_to_forget_weights;
- (void)cell_to_out_weights;
-
- const int32_t num_batch = lstm_dims->num_batches;
- const int32_t num_input = lstm_dims->num_inputs;
- const int32_t max_time = lstm_dims->max_time;
-
- const int32_t num_output = lstm_dims->num_outputs;
- const int32_t out_batch_leading_dim = num_output;
-
- // num_cell = num_output is considered in the code under the assumption that projection is NULL.
- const int32_t num_cell = num_output;
-
- if (projection_weights != NULL)
- {
- return ARM_CMSIS_NN_ARG_ERROR;
- }
-
- if (lstm->i2f_effective_bias == NULL || lstm->i2c_effective_bias == NULL || lstm->i2o_effective_bias == NULL)
- {
- return ARM_CMSIS_NN_ARG_ERROR;
- }
-
- if (lstm->r2f_effective_bias == NULL || lstm->r2c_effective_bias == NULL || lstm->r2o_effective_bias == NULL)
- {
- return ARM_CMSIS_NN_ARG_ERROR;
- }
-
- if (lstm->i2i_effective_bias == NULL || lstm->r2i_effective_bias == NULL)
- {
- return ARM_CMSIS_NN_ARG_ERROR;
- }
-
- if (lstm->time_major)
- {
- const int32_t in_step = num_batch * num_input;
- const int32_t out_step = num_batch * out_batch_leading_dim;
- for (int i_max_time = 0; i_max_time < max_time; i_max_time++)
- {
- arm_cmsis_nn_status status = arm_nn_lstm_step_s8_s16(input_data + i_max_time * in_step,
- in_to_in_weights,
- in_to_forget_weights,
- in_to_cell_weights,
- in_to_out_weights,
- recurrent_to_in_weights,
- recurrent_to_forget_weights,
- recurrent_to_cell_weights,
- recurrent_to_out_weights,
- lstm,
- num_batch,
- num_cell,
- num_input,
- num_output,
- output_state,
- cell_state,
- output_data + i_max_time * out_step,
- scratch_buffers);
- if (status != ARM_CMSIS_NN_SUCCESS)
- {
- return status;
- }
- }
- }
- else
- {
- for (int i_num_batch = 0; i_num_batch < num_batch; i_num_batch++)
- {
- const int32_t in_step = num_input;
- const int32_t out_step = out_batch_leading_dim;
- for (int i_max_time = 0; i_max_time < max_time; i_max_time++)
- {
- const int32_t time_offset = i_num_batch * max_time + i_max_time;
-
- arm_cmsis_nn_status status = arm_nn_lstm_step_s8_s16(input_data + time_offset * in_step,
- in_to_in_weights,
- in_to_forget_weights,
- in_to_cell_weights,
- in_to_out_weights,
- recurrent_to_in_weights,
- recurrent_to_forget_weights,
- recurrent_to_cell_weights,
- recurrent_to_out_weights,
- lstm,
- /*num_batch=*/1,
- num_cell,
- num_input,
- num_output,
- output_state + i_num_batch * out_batch_leading_dim,
- cell_state + i_num_batch * num_cell,
- output_data + time_offset * out_step,
- scratch_buffers);
- if (status != ARM_CMSIS_NN_SUCCESS)
- {
- return status;
- }
- }
- }
- }
-
- return ARM_CMSIS_NN_SUCCESS;
-}
-
-/**
- * @} end of LSTM group
- */
diff --git a/Source/NNSupportFunctions/arm_nn_lstm_calculate_gate_s8_s16.c b/Source/NNSupportFunctions/arm_nn_lstm_calculate_gate_s8_s16.c
index 37c7d6d9..b1bf9550 100644
--- a/Source/NNSupportFunctions/arm_nn_lstm_calculate_gate_s8_s16.c
+++ b/Source/NNSupportFunctions/arm_nn_lstm_calculate_gate_s8_s16.c
@@ -1,5 +1,5 @@
/*
- * SPDX-FileCopyrightText: Copyright 2022 Arm Limited and/or its affiliates
+ * SPDX-FileCopyrightText: Copyright 2022, 2024 Arm Limited and/or its affiliates
*
* SPDX-License-Identifier: Apache-2.0
*
@@ -21,8 +21,8 @@
* Title: arm_nn_lstm_calculate_gate_s8_s16.c
* Description: Update single gate for an incremental step of LSTM function.
*
- * $Date: 8 September 2022
- * $Revision: V.1.0.0
+ * $Date: 19 January 2024
+ * $Revision: V.2.0.0
*
* Target Processor: Cortex-M cores
*
@@ -31,7 +31,6 @@
#include "arm_nn_tables.h"
#include "arm_nnfunctions.h"
#include "arm_nnsupportfunctions.h"
-
/**
* @ingroup groupSupport
*/
@@ -52,48 +51,45 @@
* Calculates a single LSTM gate, int8x8_16 version.
* Refer to header file for details
*/
-void arm_nn_lstm_calculate_gate_s8_s16(const int8_t *input,
- const int8_t *input_to_gate_weights,
- const int32_t *input_to_gate_bias,
- const cmsis_nn_scaling input_to_gate_scaling,
- const int8_t *output_state,
- const int8_t *recurrent_to_gate_weights,
- const int32_t *recurrent_to_gate_bias,
- const cmsis_nn_scaling recurrent_to_gate,
- const int32_t n_batch,
- const int32_t n_input,
- const int32_t n_output,
- const int32_t n_cell,
- const arm_nn_activation_type activation_type,
- int16_t *gate)
+arm_cmsis_nn_status arm_nn_lstm_calculate_gate_s8_s16(const int8_t *data_in,
+ const int8_t *hidden_in,
+ const cmsis_nn_lstm_gate *gate,
+ const cmsis_nn_lstm_params *params,
+ int16_t *output,
+ const int32_t batch_offset)
{
- const int32_t n_block = n_batch * n_cell;
- memset(gate, 0, n_block * sizeof(int16_t));
- arm_nn_vec_mat_mul_result_acc_s8(input,
- input_to_gate_weights,
- input_to_gate_bias,
- gate,
- 0,
- input_to_gate_scaling.multiplier,
- input_to_gate_scaling.shift,
- n_input,
- n_cell,
- n_batch);
+ memset(output, 0, params->hidden_size * params->batch_size * sizeof(int16_t));
+
+ arm_nn_vec_mat_mul_result_acc_s8_s16(data_in,
+ gate->input_weights,
+ gate->input_effective_bias,
+ output,
+ gate->input_multiplier,
+ gate->input_shift,
+ params->input_size,
+ params->hidden_size,
+ params->batch_size,
+ batch_offset);
- arm_nn_vec_mat_mul_result_acc_s8(output_state,
- recurrent_to_gate_weights,
- recurrent_to_gate_bias,
- gate,
- 0,
- recurrent_to_gate.multiplier,
- recurrent_to_gate.shift,
- n_output,
- n_cell,
- n_batch);
+ if (hidden_in)
+ {
+ arm_nn_vec_mat_mul_result_acc_s8_s16(hidden_in,
+ gate->hidden_weights,
+ gate->hidden_effective_bias,
+ output,
+ gate->hidden_multiplier,
+ gate->hidden_shift,
+ params->hidden_size,
+ params->hidden_size,
+ params->batch_size,
+ batch_offset);
+ }
- arm_nn_activation_s16(gate, gate, n_block, 0, activation_type);
+ arm_nn_activation_s16(output, output, params->hidden_size * params->batch_size, 0, gate->activation_type);
+
+ return ARM_CMSIS_NN_SUCCESS;
}
/**
* @} end of supportLSTM group
- */
+ */
\ No newline at end of file
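In summary, for each batch the rewritten gate amounts to gate = activation(requant_in(W_input * x + b_input_eff) + requant_hidden(W_hidden * h + b_hidden_eff)): both vector-by-matrix products are accumulated and saturated into the Q15 gate buffer by arm_nn_vec_mat_mul_result_acc_s8_s16, the hidden term is skipped when hidden_in is NULL, and the activation (sigmoid or tanh, selected by gate->activation_type) is applied by arm_nn_activation_s16.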
diff --git a/Source/NNSupportFunctions/arm_nn_lstm_step_s8.c b/Source/NNSupportFunctions/arm_nn_lstm_step_s8.c
new file mode 100644
index 00000000..4b25081a
--- /dev/null
+++ b/Source/NNSupportFunctions/arm_nn_lstm_step_s8.c
@@ -0,0 +1,110 @@
+/*
+ * SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the License); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* ----------------------------------------------------------------------
+ * Project: CMSIS NN Library
+ * Title: arm_nn_lstm_step_s8.c
+ * Description: Update LSTM function for a single iteration step.
+ *
+ * $Date: 19 January 2024
+ * $Revision: V.1.0.0
+ *
+ * Target : Arm(R) M-Profile Architecture
+ *
+ * -------------------------------------------------------------------- */
+#include "arm_nnfunctions.h"
+#include "arm_nnsupportfunctions.h"
+/**
+ * @ingroup groupSupport
+ */
+
+/**
+ * @addtogroup supportLSTM
+ * @{
+ */
+
+/*
+ * Calculate the output state tensor of an LSTM step, s8 input/output/weights and s16 internal buffers version.
+ * Refer to header file for details.
+ */
+arm_cmsis_nn_status arm_nn_lstm_step_s8(const int8_t *data_in,
+ const int8_t *hidden_in,
+ int8_t *hidden_out,
+ const cmsis_nn_lstm_params *params,
+ cmsis_nn_lstm_context *buffers,
+ const int32_t batch_offset)
+{
+ int16_t *forget_gate = buffers->temp1;
+ int16_t *input_gate = buffers->temp1;
+ int16_t *cell_gate = buffers->temp2;
+ int16_t *output_gate = buffers->temp1;
+ int16_t *hidden_temp = buffers->temp2;
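+    // Note on scratch reuse: temp1 is shared by the forget, input and output gates since each
+    // gate is fully consumed before the next one is computed, while temp2 holds the cell gate
+    // and, later, the tanh of the cell state.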
+
+ int16_t *cell_state = buffers->cell_state;
+
+ arm_nn_lstm_calculate_gate_s8_s16(data_in, hidden_in, ¶ms->forget_gate, params, forget_gate, batch_offset);
+
+    // Calculate the first term of the cell state in place early to maximise reuse of the scratch buffers
+ arm_elementwise_mul_s16(forget_gate,
+ cell_state,
+ 0,
+ 0,
+ cell_state,
+ 0,
+ params->forget_to_cell_multiplier,
+ params->forget_to_cell_shift,
+ NN_Q15_MIN,
+ NN_Q15_MAX,
+ params->hidden_size * params->batch_size);
+
+ arm_nn_lstm_calculate_gate_s8_s16(data_in, hidden_in, ¶ms->input_gate, params, input_gate, batch_offset);
+ arm_nn_lstm_calculate_gate_s8_s16(data_in, hidden_in, ¶ms->cell_gate, params, cell_gate, batch_offset);
+
+    // Remainder of the cell state calculation: multiply the input gate by the cell gate and add to the previous result.
+ arm_elementwise_mul_acc_s16(forget_gate,
+ cell_gate,
+ 0,
+ 0,
+ cell_state,
+ 0,
+ params->input_to_cell_multiplier,
+ params->input_to_cell_shift,
+ -params->cell_clip,
+ params->cell_clip,
+ params->hidden_size * params->batch_size);
+
+ arm_nn_lstm_calculate_gate_s8_s16(data_in, hidden_in, ¶ms->output_gate, params, output_gate, batch_offset);
+
+    // Calculate the hidden state and write it directly into the output buffer.
+ arm_nn_activation_s16(
+ cell_state, hidden_temp, params->hidden_size * params->batch_size, params->cell_scale_power + 12, ARM_TANH);
+ arm_elementwise_mul_s16_s8(output_gate,
+ hidden_temp,
+ hidden_out,
+ params->output_offset,
+ params->output_multiplier,
+ params->output_shift,
+ params->hidden_size,
+ params->batch_size,
+ batch_offset);
+
+ return ARM_CMSIS_NN_SUCCESS;
+}
+/**
+ * @} end of supportLSTM group
+ */
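Functionally, the fixed-point sequence above implements the standard LSTM recurrence c = f*c + i*g followed by h = o*tanh(c). A minimal floating-point reference sketch follows, ignoring quantization, offsets and Q15 saturation; the helper name ref_lstm_step_f32 and its signature are illustrative only and not part of the library:

#include <math.h>
#include <stddef.h>

/* Floating-point reference of one LSTM step for a single sample:
 * c = f * c + i * g (optionally clipped), followed by h = o * tanh(c),
 * where f, i, g, o are the forget, input, cell and output gate activations. */
static void ref_lstm_step_f32(const float *f, const float *i, const float *g, const float *o,
                              float *c, float *h, size_t hidden_size, float cell_clip)
{
    for (size_t k = 0; k < hidden_size; k++)
    {
        float c_new = f[k] * c[k] + i[k] * g[k];
        if (cell_clip > 0.0f)
        {
            c_new = fmaxf(-cell_clip, fminf(cell_clip, c_new));
        }
        c[k] = c_new;
        h[k] = o[k] * tanhf(c_new);
    }
}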
diff --git a/Source/NNSupportFunctions/arm_nn_lstm_step_s8_s16.c b/Source/NNSupportFunctions/arm_nn_lstm_step_s8_s16.c
deleted file mode 100644
index 41e3071f..00000000
--- a/Source/NNSupportFunctions/arm_nn_lstm_step_s8_s16.c
+++ /dev/null
@@ -1,154 +0,0 @@
-/*
- * SPDX-FileCopyrightText: Copyright 2022-2023 Arm Limited and/or its affiliates
- *
- * SPDX-License-Identifier: Apache-2.0
- *
- * Licensed under the Apache License, Version 2.0 (the License); you may
- * not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/* ----------------------------------------------------------------------
- * Project: CMSIS NN Library
- * Title: arm_nn_lstm_step_s8_s16.c
- * Description: Update LSTM function for a single iteration step.
- *
- * $Date: 9 Februari 2023
- * $Revision: V.1.1.0
- *
- * Target : Arm(R) M-Profile Architecture
- *
- * -------------------------------------------------------------------- */
-#include "arm_nnsupportfunctions.h"
-/**
- * @ingroup groupSupport
- */
-
-/**
- * @addtogroup supportLSTM
- * @{
- */
-
-/*
- * Calculate the output state tensor of an LSTM step, s8 input/output and s16 weight version.
- * Refer to header file for details.
- */
-arm_cmsis_nn_status arm_nn_lstm_step_s8_s16(const int8_t *input,
- const int8_t *input_to_input_weight,
- const int8_t *input_to_forget_weight,
- const int8_t *input_to_cell_weight,
- const int8_t *input_to_output_weight,
- const int8_t *recurrent_to_input_weight,
- const int8_t *recurrent_to_forget_weight,
- const int8_t *recurrent_to_cell_weight,
- const int8_t *recurrent_to_output_weight,
- const cmsis_nn_lstm_params *lstm,
- const int n_batch,
- const int n_cell,
- const int n_input,
- const int n_output,
- int8_t *output_state,
- int16_t *cell_state,
- int8_t *output,
- cmsis_nn_lstm_context *scratch_buffers)
-{
- const int32_t n_block = n_batch * n_cell;
-
- // Calculate the input gate
- arm_nn_lstm_calculate_gate_s8_s16(input,
- input_to_input_weight,
- lstm->i2i_effective_bias,
- lstm->input_to_input_scaling,
- output_state,
- recurrent_to_input_weight,
- lstm->r2i_effective_bias,
- lstm->recurrent_to_input_scaling,
- n_batch,
- n_input,
- n_output,
- n_cell,
- ARM_SIGMOID,
- scratch_buffers->input_gate);
-
- // Calculate the forget gate
- arm_nn_lstm_calculate_gate_s8_s16(input,
- input_to_forget_weight,
- lstm->i2f_effective_bias,
- lstm->input_to_forget_scaling,
- output_state,
- recurrent_to_forget_weight,
- lstm->r2f_effective_bias,
- lstm->recurrent_to_forget_scaling,
- n_batch,
- n_input,
- n_output,
- n_cell,
- ARM_SIGMOID,
- scratch_buffers->forget_gate);
-
- // Calculate the cell update gate
- arm_nn_lstm_calculate_gate_s8_s16(input,
- input_to_cell_weight,
- lstm->i2c_effective_bias,
- lstm->input_to_cell_scaling,
- output_state,
- recurrent_to_cell_weight,
- lstm->r2c_effective_bias,
- lstm->recurrent_to_cell_scaling,
- n_batch,
- n_input,
- n_output,
- n_cell,
- ARM_TANH,
- scratch_buffers->cell_gate);
-
- // Update the cell state
- arm_nn_lstm_update_cell_state_s16(n_block,
- lstm->cell_state_shift,
- cell_state,
- scratch_buffers->input_gate,
- scratch_buffers->forget_gate,
- scratch_buffers->cell_gate);
-
- // Calculate the output gate
- arm_nn_lstm_calculate_gate_s8_s16(input,
- input_to_output_weight,
- lstm->i2o_effective_bias,
- lstm->input_to_output_scaling,
- output_state,
- recurrent_to_output_weight,
- lstm->r2o_effective_bias,
- lstm->recurrent_to_output_scaling,
- n_batch,
- n_input,
- n_output,
- n_cell,
- ARM_SIGMOID,
- scratch_buffers->output_gate);
-
- // Update the output state
- arm_nn_lstm_update_output_s8_s16(n_batch,
- n_cell,
- cell_state,
- lstm->cell_state_shift,
- scratch_buffers->output_gate,
- lstm->hidden_scaling,
- lstm->hidden_offset,
- output_state,
- scratch_buffers->input_gate);
-
- arm_memcpy_s8(output, output_state, n_batch * n_output * sizeof(int8_t));
-
- return ARM_CMSIS_NN_SUCCESS;
-}
-/**
- * @} end of supportLSTM group
- */
diff --git a/Source/NNSupportFunctions/arm_nn_lstm_update_cell_state_s16.c b/Source/NNSupportFunctions/arm_nn_lstm_update_cell_state_s16.c
deleted file mode 100644
index b142773e..00000000
--- a/Source/NNSupportFunctions/arm_nn_lstm_update_cell_state_s16.c
+++ /dev/null
@@ -1,124 +0,0 @@
-/*
- * SPDX-FileCopyrightText: Copyright 2022-2023 Arm Limited and/or its affiliates
- *
- * SPDX-License-Identifier: Apache-2.0
- *
- * Licensed under the Apache License, Version 2.0 (the License); you may
- * not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/* ----------------------------------------------------------------------
- * Project: CMSIS NN Library
- * Title: arm_nn_lstm_update_cell_state_s16.c
- * Description: Update cell state for an incremental step of LSTM function.
- *
- * $Date: 20 January 2023
- * $Revision: V.1.2.0
- *
- * Target : Arm(R) M-Profile Architecture
- *
- * -------------------------------------------------------------------- */
-
-#include "arm_nnsupportfunctions.h"
-/**
- * @ingroup groupSupport
- */
-
-/**
- * @addtogroup supportLSTM
- * @{
- */
-
-/*
- * Update cell state for a single LSTM iteration step, int8x8_16 version.
- *
- * Refer to header file for more details
- */
-void arm_nn_lstm_update_cell_state_s16(const int32_t n_block,
- const int32_t cell_state_scale,
- int16_t *cell_state,
- const int16_t *input_gate,
- const int16_t *forget_gate,
- const int16_t *cell_gate)
-{
- const int32_t cell_scale = 30 + cell_state_scale;
- int32_t loop_count = n_block;
-
-#if defined(ARM_MATH_MVEI)
-
- while (loop_count > 0)
- {
- mve_pred16_t p = vctp32q(loop_count);
- loop_count -= 4;
-
- int32x4_t res_1 = vmulq_s32(vldrhq_z_s32(cell_state, p), vldrhq_z_s32(forget_gate, p));
- forget_gate += 4;
- res_1 = arm_divide_by_power_of_two_mve(res_1, 15);
- int32x4_t res_2 = vmulq_s32(vldrhq_z_s32(input_gate, p), vldrhq_z_s32(cell_gate, p));
- input_gate += 4;
- cell_gate += 4;
-
- res_2 = arm_divide_by_power_of_two_mve(res_2, cell_scale);
- res_1 += res_2;
-
- res_1 = vmaxq_s32(res_1, vdupq_n_s32(NN_Q15_MIN));
- res_1 = vminq_s32(res_1, vdupq_n_s32(NN_Q15_MAX));
-
- vstrhq_p_s32(cell_state, res_1, p);
- cell_state += 4;
- }
-#else
- #if defined(ARM_MATH_DSP)
- while (loop_count > 1)
- {
- int32_t cell_state_01 = arm_nn_read_s16x2(cell_state);
- int32_t forget_gate_01 = arm_nn_read_q15x2_ia(&forget_gate);
-
- int32_t value_00 = SMULBB(cell_state_01, forget_gate_01);
- int32_t value_01 = SMULTT(cell_state_01, forget_gate_01);
- value_00 = arm_nn_divide_by_power_of_two(value_00, 15);
- value_01 = arm_nn_divide_by_power_of_two(value_01, 15);
-
- int32_t input_gate_01 = arm_nn_read_q15x2_ia(&input_gate);
- int32_t cell_gate_01 = arm_nn_read_q15x2_ia(&cell_gate);
-
- int32_t value_10 = SMULBB(input_gate_01, cell_gate_01);
- int32_t value_11 = SMULTT(input_gate_01, cell_gate_01);
-
- value_10 = arm_nn_divide_by_power_of_two(value_10, cell_scale);
- value_11 = arm_nn_divide_by_power_of_two(value_11, cell_scale);
-
- value_00 += value_10;
- value_01 += value_11;
-
- value_00 = CLAMP(value_00, NN_Q15_MAX, NN_Q15_MIN);
- value_01 = CLAMP(value_01, NN_Q15_MAX, NN_Q15_MIN);
-
- arm_nn_write_q15x2_ia(&cell_state, PACK_Q15x2_32x1(value_00, value_01));
- loop_count -= 2;
- }
- #endif
- for (int i = 0; i < loop_count; i++)
- {
- int32_t value = cell_state[i] * forget_gate[i];
- int32_t value_1 = input_gate[i] * cell_gate[i];
-
- value = arm_nn_divide_by_power_of_two(value, 15);
- value_1 = arm_nn_divide_by_power_of_two(value_1, cell_scale);
-
- cell_state[i] = CLAMP(value + value_1, NN_Q15_MAX, NN_Q15_MIN);
- }
-#endif // #if defined(ARM_MATH_MVEI)
-}
-/**
- * @} end of supportLSTM group
- */
diff --git a/Source/NNSupportFunctions/arm_nn_lstm_update_output_s8_s16.c b/Source/NNSupportFunctions/arm_nn_lstm_update_output_s8_s16.c
deleted file mode 100644
index c742f354..00000000
--- a/Source/NNSupportFunctions/arm_nn_lstm_update_output_s8_s16.c
+++ /dev/null
@@ -1,81 +0,0 @@
-/*
- * SPDX-FileCopyrightText: Copyright 2022-2023 Arm Limited and/or its affiliates
- *
- * SPDX-License-Identifier: Apache-2.0
- *
- * Licensed under the Apache License, Version 2.0 (the License); you may
- * not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/* ----------------------------------------------------------------------
- * Project: CMSIS NN Library
- * Title: arm_nn_lstm_update_output_s8_s16.c
- * Description: Update output gate for an incremental step of LSTM function.
- *
- * $Date: 13 Februari 2023
- * $Revision: V.2.0.0
- *
- * Target : Arm(R) M-Profile Architecture
- *
- * -------------------------------------------------------------------- */
-
-#include "arm_nnfunctions.h"
-#include "arm_nnsupportfunctions.h"
-
-/**
- * @ingroup groupSupport
- */
-
-/**
- * @addtogroup supportLSTM
- * @{
- */
-
-/*
- * Calculate the output state tensor of an LSTM step, s8 input/output and s16 weight version.
- * Refer to header files for details
- */
-void arm_nn_lstm_update_output_s8_s16(const int n_batch,
- const int n_cell,
- int16_t *cell_state,
- const int32_t cell_state_scale,
- const int16_t *output_gate,
- const cmsis_nn_scaling hidden_scaling,
- const int32_t hidden_offset,
- int8_t *output_state,
- int16_t *cell_gate_scratch)
-{
- const int32_t size = n_batch * n_cell;
-
- int32_t tanh_input_left_shift = (15 + cell_state_scale) - 3;
- if (tanh_input_left_shift < 0)
- {
- tanh_input_left_shift = -tanh_input_left_shift;
- for (int32_t i = 0; i < size; i++)
- {
- cell_state[i] = cell_state[i] >> tanh_input_left_shift;
- }
- tanh_input_left_shift = 0;
- }
- arm_nn_activation_s16(cell_state, cell_gate_scratch, size, tanh_input_left_shift, ARM_TANH);
-
- arm_elementwise_mul_s16_s8(output_gate,
- cell_gate_scratch,
- output_state,
- hidden_offset,
- hidden_scaling.multiplier,
- hidden_scaling.shift,
- size);
-}
-/**
- * @} end of supportLSTM group
- */
diff --git a/Source/NNSupportFunctions/arm_nn_vec_mat_mul_result_acc_s8.c b/Source/NNSupportFunctions/arm_nn_vec_mat_mul_result_acc_s8_s16.c
similarity index 60%
rename from Source/NNSupportFunctions/arm_nn_vec_mat_mul_result_acc_s8.c
rename to Source/NNSupportFunctions/arm_nn_vec_mat_mul_result_acc_s8_s16.c
index d1eba6e2..17b9a63f 100644
--- a/Source/NNSupportFunctions/arm_nn_vec_mat_mul_result_acc_s8.c
+++ b/Source/NNSupportFunctions/arm_nn_vec_mat_mul_result_acc_s8_s16.c
@@ -1,5 +1,5 @@
/*
- * SPDX-FileCopyrightText: Copyright 2022-2023 Arm Limited and/or its affiliates
+ * SPDX-FileCopyrightText: Copyright 2022-2024 Arm Limited and/or its affiliates
*
* SPDX-License-Identifier: Apache-2.0
*
@@ -18,18 +18,16 @@
/* ----------------------------------------------------------------------
* Project: CMSIS NN Library
- * Title: arm_nn_vec_mat_mul_result_acc_s8.c
+ * Title: arm_nn_vec_mat_mul_result_acc_s8_s16.c
* Description: Multiplies a matrix by a vector and accumulate with output.
*
- * $Date: 20 January 2023
- * $Revision: V.1.2.0
+ * $Date: 19 January 2024
+ * $Revision: V.2.0.0
*
* Target : Arm(R) M-Profile Architecture
*
* -------------------------------------------------------------------- */
-
#include "arm_nnsupportfunctions.h"
-
/**
* @ingroup groupSupport
*/
@@ -42,41 +40,47 @@
/*
* Refer to header file for details.
*/
-void arm_nn_vec_mat_mul_result_acc_s8(const int8_t *lhs_in,
- const int8_t *rhs_in,
- const int32_t *bias,
- int16_t *dst,
- const int32_t dst_offset,
- const int32_t dst_multiplier,
- const int32_t dst_shift,
- const int32_t rhs_cols,
- const int32_t rhs_rows,
- const int32_t batch)
+arm_cmsis_nn_status arm_nn_vec_mat_mul_result_acc_s8_s16(const int8_t *lhs,
+ const int8_t *rhs,
+ const int32_t *effective_bias,
+ int16_t *dst,
+ const int32_t dst_multiplier,
+ const int32_t dst_shift,
+ const int32_t rhs_cols,
+ const int32_t rhs_rows,
+ const int32_t batches,
+ const int32_t batch_offset)
{
- for (int i_batch = 0; i_batch < batch; ++i_batch)
+
+ for (int batch = 0; batch < batches; batch++)
{
- const int8_t *rhs = rhs_in;
- const int8_t *lhs = lhs_in + i_batch * rhs_cols;
+ const int8_t *rhs_ptr = &rhs[0];
+ const int32_t *effective_bias_ptr = &effective_bias[0];
#if defined(ARM_MATH_MVEI)
- const int32_t row_loop_cnt = rhs_rows / 4;
- for (int i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
+ for (size_t row_loop_cnt = rhs_rows / 4; row_loop_cnt != 0; --row_loop_cnt)
{
- int32_t acc_0 = 0;
- int32_t acc_1 = 0;
- int32_t acc_2 = 0;
- int32_t acc_3 = 0;
+ const int32_t col_loop_cnt = (rhs_cols + 15) / 16;
const int8_t *lhs_vec = lhs;
- const int8_t *rhs_0 = rhs;
- const int8_t *rhs_1 = rhs + rhs_cols;
- const int8_t *rhs_2 = rhs + 2 * rhs_cols;
- const int8_t *rhs_3 = rhs + 3 * rhs_cols;
-
- int32_t col_cnt = rhs_cols;
-
- while (col_cnt > 0)
+ const int8_t *rhs_0 = rhs_ptr;
+ rhs_ptr += rhs_cols;
+ const int8_t *rhs_1 = rhs_ptr;
+ rhs_ptr += rhs_cols;
+ const int8_t *rhs_2 = rhs_ptr;
+ rhs_ptr += rhs_cols;
+ const int8_t *rhs_3 = rhs_ptr;
+ rhs_ptr += rhs_cols;
+
+ int32_t acc_0 = *effective_bias_ptr++;
+ int32_t acc_1 = *effective_bias_ptr++;
+ int32_t acc_2 = *effective_bias_ptr++;
+ int32_t acc_3 = *effective_bias_ptr++;
+
+ uint32_t col_cnt = (uint32_t)rhs_cols;
+
+ for (int i = 0; i < col_loop_cnt; i++)
{
mve_pred16_t p = vctp8q(col_cnt);
col_cnt -= 16;
@@ -84,16 +88,16 @@ void arm_nn_vec_mat_mul_result_acc_s8(const int8_t *lhs_in,
const int8x16_t input = vldrbq_z_s8(lhs_vec, p);
const int8x16_t ker_0 = vldrbq_z_s8(rhs_0, p);
- acc_0 = vmladavaq_p_s8(acc_0, ker_0, input, p);
+ acc_0 = vmladavaq_s8(acc_0, ker_0, input);
const int8x16_t ker_1 = vldrbq_z_s8(rhs_1, p);
- acc_1 = vmladavaq_p_s8(acc_1, ker_1, input, p);
+ acc_1 = vmladavaq_s8(acc_1, ker_1, input);
const int8x16_t ker_2 = vldrbq_z_s8(rhs_2, p);
- acc_2 = vmladavaq_p_s8(acc_2, ker_2, input, p);
+ acc_2 = vmladavaq_s8(acc_2, ker_2, input);
const int8x16_t ker_3 = vldrbq_z_s8(rhs_3, p);
- acc_3 = vmladavaq_p_s8(acc_3, ker_3, input, p);
+ acc_3 = vmladavaq_s8(acc_3, ker_3, input);
lhs_vec += 16;
rhs_0 += 16;
@@ -101,16 +105,10 @@ void arm_nn_vec_mat_mul_result_acc_s8(const int8_t *lhs_in,
rhs_2 += 16;
rhs_3 += 16;
}
- rhs += 4 * rhs_cols;
int32x4_t acc = {acc_0, acc_1, acc_2, acc_3};
- int32x4_t b = vldrwq_s32(bias);
- acc = vaddq_s32(acc, b);
- bias += 4;
acc = arm_requantize_mve(acc, dst_multiplier, dst_shift);
- acc = vaddq_s32(acc, vdupq_n_s32(dst_offset));
-
acc = vaddq_s32(acc, vldrhq_s32(dst));
acc = vmaxq_s32(acc, vdupq_n_s32(NN_Q15_MIN));
@@ -120,79 +118,77 @@ void arm_nn_vec_mat_mul_result_acc_s8(const int8_t *lhs_in,
dst += 4;
}
- const int loop_cnt = rhs_rows % 4;
- for (int i_row_loop_cnt = 0; i_row_loop_cnt < loop_cnt; i_row_loop_cnt++)
+ for (size_t row_loop_cnt = rhs_rows % 4; row_loop_cnt != 0; --row_loop_cnt)
{
- int32_t acc_0 = 0;
+ int32_t acc_0 = *effective_bias_ptr++;
+
+ const int32_t col_loop_cnt = (rhs_cols + 15) / 16;
const int8_t *lhs_vec = lhs;
- const int8_t *rhs_0 = rhs;
- int32_t col_cnt = rhs_cols;
+ const int8_t *rhs_0 = rhs_ptr;
+ uint32_t col_cnt = (uint32_t)rhs_cols;
- while (col_cnt > 0)
+ for (int i = 0; i < col_loop_cnt; i++)
{
mve_pred16_t p = vctp8q(col_cnt);
col_cnt -= 16;
const int8x16_t input = vldrbq_z_s8(lhs_vec, p);
const int8x16_t ker_0 = vldrbq_z_s8(rhs_0, p);
- acc_0 = vmladavaq_p_s8(acc_0, ker_0, input, p);
+ acc_0 = vmladavaq_s8(acc_0, ker_0, input);
lhs_vec += 16;
rhs_0 += 16;
}
- rhs += rhs_cols;
-
- acc_0 += *bias;
- bias++;
+ rhs_ptr += rhs_cols;
acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
- acc_0 += dst_offset + *dst;
+ acc_0 += *dst;
// Clamp the result
- acc_0 = CLAMP(acc_0, NN_Q15_MAX, NN_Q15_MIN);
+ acc_0 = MAX(acc_0, NN_Q15_MIN);
+ acc_0 = MIN(acc_0, NN_Q15_MAX);
*dst++ = (int16_t)acc_0;
}
#elif defined(ARM_MATH_DSP)
- const int32_t row_loop_cnt = rhs_rows / 2;
- for (int32_t i = 0; i < row_loop_cnt; i++)
+ for (int32_t row_loop_cnt = rhs_rows / 2; row_loop_cnt != 0; --row_loop_cnt)
{
- int32_t acc_0 = *bias++;
- int32_t acc_1 = *bias++;
+ int32_t acc_0 = *effective_bias_ptr++;
+ int32_t acc_1 = *effective_bias_ptr++;
const int32_t col_loop_cnt = rhs_cols / 4;
const int8_t *lhs_vec = lhs;
- const int8_t *rhs_0 = rhs;
- const int8_t *rhs_1 = rhs + rhs_cols;
- rhs += 2 * rhs_cols;
+ const int8_t *rhs_0 = rhs_ptr;
+ rhs_ptr += rhs_cols;
+ const int8_t *rhs_1 = rhs_ptr;
+ rhs_ptr += rhs_cols;
for (int j = col_loop_cnt; j != 0; j--)
{
int32_t vec_0 = arm_nn_read_s8x4_ia(&lhs_vec);
int32_t vec_1 = SXTB16_RORn((uint32_t)vec_0, 8);
-
vec_0 = SXTB16(vec_0);
int32_t ker_0 = arm_nn_read_s8x4_ia(&rhs_0);
int32_t ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
- acc_0 = SMLAD(ker_1, vec_1, acc_0);
-
ker_0 = SXTB16(ker_0);
+
+ acc_0 = SMLAD(ker_1, vec_1, acc_0);
acc_0 = SMLAD(ker_0, vec_0, acc_0);
ker_0 = arm_nn_read_s8x4_ia(&rhs_1);
ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
- acc_1 = SMLAD(ker_1, vec_1, acc_1);
-
ker_0 = SXTB16(ker_0);
+
+ acc_1 = SMLAD(ker_1, vec_1, acc_1);
acc_1 = SMLAD(ker_0, vec_0, acc_1);
}
for (int k = col_loop_cnt * 4; k < rhs_cols; k++)
{
- const int32_t lhs_temp = *lhs_vec;
+ const int32_t lhs_temp = (*lhs_vec);
lhs_vec++;
acc_0 += lhs_temp * (*rhs_0);
rhs_0++;
@@ -204,24 +200,26 @@ void arm_nn_vec_mat_mul_result_acc_s8(const int8_t *lhs_in,
acc_1 = arm_nn_requantize(acc_1, dst_multiplier, dst_shift);
// Add offset
- acc_0 += dst_offset + *dst;
- acc_1 += dst_offset + dst[1];
+ acc_0 += *dst;
// Clamp the result
- acc_0 = CLAMP(acc_0, NN_Q15_MAX, NN_Q15_MIN);
- acc_1 = CLAMP(acc_1, NN_Q15_MAX, NN_Q15_MIN);
-
+ acc_0 = MAX(acc_0, NN_Q15_MIN);
+ acc_0 = MIN(acc_0, NN_Q15_MAX);
*dst++ = (int16_t)acc_0;
+
+ acc_1 += *dst;
+ acc_1 = MAX(acc_1, NN_Q15_MIN);
+ acc_1 = MIN(acc_1, NN_Q15_MAX);
+
*dst++ = (int16_t)acc_1;
}
if (rhs_rows & 0x1)
{
- int32_t acc_0 = *bias++;
-
+ int32_t acc_0 = *effective_bias_ptr++;
const int32_t col_loop_cnt = rhs_cols / 4;
const int8_t *lhs_vec = lhs;
- const int8_t *rhs_0 = rhs;
+ const int8_t *rhs_0 = rhs_ptr;
for (int i = col_loop_cnt; i != 0; i--)
{
@@ -239,16 +237,19 @@ void arm_nn_vec_mat_mul_result_acc_s8(const int8_t *lhs_in,
for (int j = col_loop_cnt * 4; j < rhs_cols; j++)
{
- const int32_t lhs_temp = *lhs_vec++;
- acc_0 += lhs_temp * (*rhs_0++);
+ const int32_t lhs_temp = (*lhs_vec);
+ lhs_vec++;
+ acc_0 += lhs_temp * (*rhs_0);
+ rhs_0++;
}
acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
- // Add offset
- acc_0 += dst_offset + *dst;
+ // Accumulate
+ acc_0 += dst[0];
// Clamp the result
- acc_0 = CLAMP(acc_0, NN_Q15_MAX, NN_Q15_MIN);
+ acc_0 = MAX(acc_0, NN_Q15_MIN);
+ acc_0 = MIN(acc_0, NN_Q15_MAX);
*dst++ = (int16_t)acc_0;
}
@@ -259,13 +260,16 @@ void arm_nn_vec_mat_mul_result_acc_s8(const int8_t *lhs_in,
for (int i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
{
const int8_t *lhs_ptr = lhs;
- const int8_t *rhs_ptr_0 = &rhs[0];
- const int8_t *rhs_ptr_1 = &rhs[rhs_cols];
- const int8_t *rhs_ptr_2 = &rhs[rhs_cols * 2];
+ const int8_t *rhs_ptr_0 = rhs_ptr;
+ rhs_ptr += rhs_cols;
+ const int8_t *rhs_ptr_1 = rhs_ptr;
+ rhs_ptr += rhs_cols;
+ const int8_t *rhs_ptr_2 = rhs_ptr;
+ rhs_ptr += rhs_cols;
- int32_t res00 = *bias++;
- int32_t res01 = *bias++;
- int32_t res02 = *bias++;
+ int32_t res00 = *effective_bias_ptr++;
+ int32_t res01 = *effective_bias_ptr++;
+ int32_t res02 = *effective_bias_ptr++;
for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
{
@@ -289,21 +293,17 @@ void arm_nn_vec_mat_mul_result_acc_s8(const int8_t *lhs_in,
res02 = arm_nn_requantize(res02, dst_multiplier, dst_shift);
// Add offset
- res00 += dst_offset + *dst;
- res01 += dst_offset + dst[1];
- res02 += dst_offset + dst[2];
-
- // Clamp the result
+ res00 += (int32_t)*dst;
res00 = CLAMP(res00, NN_Q15_MAX, NN_Q15_MIN);
- res01 = CLAMP(res01, NN_Q15_MAX, NN_Q15_MIN);
- res02 = CLAMP(res02, NN_Q15_MAX, NN_Q15_MIN);
+ *dst++ = (int16_t)res00;
- dst[0] = (int16_t)res00;
- dst[1] = (int16_t)res01;
- dst[2] = (int16_t)res02;
- dst += 3;
+ res01 += (int32_t)*dst;
+ res01 = CLAMP(res01, NN_Q15_MAX, NN_Q15_MIN);
+ *dst++ = (int16_t)res01;
- rhs += 3 * rhs_cols;
+ res02 += (int32_t)*dst;
+ res02 = CLAMP(res02, NN_Q15_MAX, NN_Q15_MIN);
+ *dst++ = (int16_t)res02;
}
const int loop_cnt = rhs_rows % 3;
@@ -311,37 +311,41 @@ void arm_nn_vec_mat_mul_result_acc_s8(const int8_t *lhs_in,
for (int i_loop_cnt = 0; i_loop_cnt < loop_cnt; i_loop_cnt++)
{
const int8_t *lhs_ptr = &lhs[0];
- const int8_t *rhs_ptr = &rhs[0];
+ const int8_t *rhs_ptr_0 = &rhs_ptr[0];
- int32_t res00 = *bias++;
+ int32_t res00 = *effective_bias_ptr++;
for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
{
- int32_t rhs_value0 = (int8_t)rhs_ptr[0];
+ int32_t rhs_value0 = (int8_t)rhs_ptr_0[0];
int32_t lhs_value = (int8_t)lhs_ptr[0];
res00 += lhs_value * rhs_value0;
- ++rhs_ptr;
+ ++rhs_ptr_0;
++lhs_ptr;
}
// Quantize down
res00 = arm_nn_requantize(res00, dst_multiplier, dst_shift);
- // Add offset
- res00 += dst_offset + *dst;
+ // Accumulate
+ res00 += (int32_t)dst[0];
// Clamp the result
res00 = CLAMP(res00, NN_Q15_MAX, NN_Q15_MIN);
*dst++ = (int16_t)res00;
- rhs += rhs_cols;
+ rhs_ptr += rhs_cols;
}
#endif
+
+ lhs += rhs_cols * batch_offset;
}
+
+ return ARM_CMSIS_NN_SUCCESS;
}
/**
* @} end of supportLSTM group
- */
+ */
\ No newline at end of file
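As a plain-C reference for the accumulating vector-by-matrix product above (no DSP/MVE intrinsics; ref_requantize and ref_sat16 are illustrative helpers, not CMSIS-NN APIs, and the rounding performed by arm_nn_requantize is omitted), one batch amounts to the following sketch:

#include <stdint.h>

/* Fixed-point multiply by a Q31 multiplier followed by a shift; rounding omitted for brevity. */
static int32_t ref_requantize(int64_t acc, int32_t multiplier, int32_t shift)
{
    int64_t v = (acc * (int64_t)multiplier) >> 31;
    return (int32_t)(shift >= 0 ? (v << shift) : (v >> -shift));
}

static int16_t ref_sat16(int32_t v)
{
    return (int16_t)(v > INT16_MAX ? INT16_MAX : (v < INT16_MIN ? INT16_MIN : v));
}

/* dst[r] += requantize(effective_bias[r] + sum_c lhs[c] * rhs[r][c]), saturated to Q15. */
static void ref_vec_mat_mul_result_acc(const int8_t *lhs, const int8_t *rhs, const int32_t *effective_bias,
                                       int16_t *dst, int32_t dst_multiplier, int32_t dst_shift,
                                       int32_t rhs_cols, int32_t rhs_rows)
{
    for (int32_t r = 0; r < rhs_rows; r++)
    {
        int64_t acc = effective_bias[r];
        for (int32_t c = 0; c < rhs_cols; c++)
        {
            acc += (int32_t)lhs[c] * (int32_t)rhs[r * rhs_cols + c];
        }
        dst[r] = ref_sat16((int32_t)dst[r] + ref_requantize(acc, dst_multiplier, dst_shift));
    }
}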
diff --git a/Tests/UnitTest/CMakeLists.txt b/Tests/UnitTest/CMakeLists.txt
index c6c0ef03..d8d8ab7a 100644
--- a/Tests/UnitTest/CMakeLists.txt
+++ b/Tests/UnitTest/CMakeLists.txt
@@ -97,7 +97,7 @@ add_subdirectory(TestCases/test_arm_fully_connected_s16)
add_subdirectory(TestCases/test_arm_fully_connected_s8)
add_subdirectory(TestCases/test_arm_fully_connected_s4)
add_subdirectory(TestCases/test_arm_grouped_convolve_s8)
-add_subdirectory(TestCases/test_arm_lstm_unidirectional_s16_s8)
+add_subdirectory(TestCases/test_arm_lstm_unidirectional_s8)
add_subdirectory(TestCases/test_arm_max_pool_s16)
add_subdirectory(TestCases/test_arm_max_pool_s8)
add_subdirectory(TestCases/test_arm_softmax_s16)
diff --git a/Tests/UnitTest/README.md b/Tests/UnitTest/README.md
index 5ccf6ff3..0863bfdb 100644
--- a/Tests/UnitTest/README.md
+++ b/Tests/UnitTest/README.md
@@ -61,7 +61,7 @@ Python package tflite_runtime can be installed with pip and it can also be built
Use the -h flag to get more info on supported interpreters.
##### tflite_micro
-This interpreter is partially supported. See this comment for more info: https://github.com/tensorflow/tflite-micro/issues/1484#issuecomment-1677842603.
+The Python package tflite_micro can be installed with pip or built locally. See this comment for more info: https://github.com/tensorflow/tflite-micro/issues/1484#issuecomment-1677842603. This interpreter is only partially supported; see *Tests depending on TFLM interpreter*.
## Getting started
@@ -126,20 +126,14 @@ When adding a new test data set, new c files should be added or existing c files
 The steps to add a new unit test are as follows. Add a new test set in the load_all_testdatasets() function. Run the generate script with that new test set as input. Add the new generated header files to an existing or new unit test.
-### Tests depending on specific TFL versions, patched TFL version or TFLM interpreter
-
+### Tests depending on TFLM interpreter
#### SVDF INT8
-This tests is depending on tflite_micro for its reference data. This is because the operator is only supported by TFLM.
-Note that tflite_micro interpreter is currently only supported for SVDF.
+This test depends on tflite_micro for its reference data, since the operator is only supported by TFLM.
#### LSTM
+This test depends on tflite_micro for its reference data, since the operator differs between TFLM and TFLite.
-The LSTM tests are using the tflite_runtime as interpreter.
-See [Using tflite_runtime](https://github.com/ARM-software/CMSIS-NN/blob/main/Tests/UnitTest/README.md#using-tflite_runtime) for more info.
-This patch is needed for the tflite_runtime (or tensorflow if using that):
-https://github.com/tensorflow/tflite-micro/pull/1253 - Note that this PR is for [TFLM](https://github.com/tensorflow/tflite-micro) so it has to be ported to [TFL](https://github.com/tensorflow/tensorflow) before building the tflite_runtime.
-The issue related to this is: https://github.com/tensorflow/tflite-micro/issues/1455
-
+Note that the tflite_micro interpreter is currently only supported for SVDF and LSTM.
## Overview of the Folders
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/cell_gate_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/cell_gate_bias_data.h
index a9cbb2db..43c4374a 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_1/cell_gate_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_1/cell_gate_bias_data.h
@@ -1,6 +1,6 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
 #include <stdint.h>
-const int32_t lstm_1_cell_gate_bias[11] = {6795, 15999, -6130, -12546, 10504, -8988, -30870, -19325, 70, 25310, 4688};
+const int32_t lstm_1_cell_gate_bias[11] = {-16190, 6797, 24062, 29971, -22780, 17656, 14698, 1849, 4054, 14590, -20709};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/cell_norm_coeff_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/cell_norm_coeff_data.h
index aec53e24..062bb680 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_1/cell_norm_coeff_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_1/cell_norm_coeff_data.h
@@ -1,5 +1,5 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
 #include <stdint.h>
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/cell_state_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/cell_state_data.h
index 1cbf293a..43589949 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_1/cell_state_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_1/cell_state_data.h
@@ -1,5 +1,5 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
 #include <stdint.h>
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/cell_to_forget_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/cell_to_forget_data.h
index ef4bc138..c7df28f9 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_1/cell_to_forget_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_1/cell_to_forget_data.h
@@ -1,5 +1,5 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
 #include <stdint.h>
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/cell_to_input_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/cell_to_input_data.h
index 5a6c3d29..a27ae43e 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_1/cell_to_input_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_1/cell_to_input_data.h
@@ -1,5 +1,5 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
 #include <stdint.h>
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/cell_to_output_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/cell_to_output_data.h
index 5012f469..30aaf78d 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_1/cell_to_output_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_1/cell_to_output_data.h
@@ -1,5 +1,5 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
 #include <stdint.h>
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/config_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/config_data.h
index 36400519..28263946 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_1/config_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_1/config_data.h
@@ -1,5 +1,5 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#define LSTM_1_BUFFER_SIZE 11
#define LSTM_1_INPUT_BATCHES 1
@@ -10,24 +10,29 @@
#define LSTM_1_TIME_MAJOR 1
#define LSTM_1_IN_ACTIVATION_MIN -32768
#define LSTM_1_IN_ACTIVATION_MAX 32767
-#define LSTM_1_IN_TO_INPUT_MULTIPLIER 1075971584
+#define LSTM_1_IN_TO_INPUT_MULTIPLIER 1075906048
#define LSTM_1_IN_TO_INPUT_SHIFT -2
-#define LSTM_1_IN_TO_FORGET_MULTIPLIER 1083903104
+#define LSTM_1_IN_TO_FORGET_MULTIPLIER 1085883136
#define LSTM_1_IN_TO_FORGET_SHIFT -2
-#define LSTM_1_IN_TO_CELL_MULTIPLIER 1082296832
+#define LSTM_1_IN_TO_CELL_MULTIPLIER 1084231552
#define LSTM_1_IN_TO_CELL_SHIFT -2
-#define LSTM_1_IN_TO_OUTPUT_MULTIPLIER 1080149504
+#define LSTM_1_IN_TO_OUTPUT_MULTIPLIER 1085274240
#define LSTM_1_IN_TO_OUTPUT_SHIFT -2
-#define LSTM_1_RECURRENT_TO_INPUT_MULTIPLIER 1164713600
+#define LSTM_1_RECURRENT_TO_INPUT_MULTIPLIER 1523696256
#define LSTM_1_RECURRENT_TO_INPUT_SHIFT -2
-#define LSTM_1_RECURRENT_TO_FORGET_MULTIPLIER 1155545344
+#define LSTM_1_RECURRENT_TO_FORGET_MULTIPLIER 1511291392
#define LSTM_1_RECURRENT_TO_FORGET_SHIFT -2
-#define LSTM_1_RECURRENT_TO_CELL_MULTIPLIER 1154073088
+#define LSTM_1_RECURRENT_TO_CELL_MULTIPLIER 1523716992
#define LSTM_1_RECURRENT_TO_CELL_SHIFT -2
-#define LSTM_1_RECURRENT_TO_OUTPUT_MULTIPLIER 1082469760
+#define LSTM_1_RECURRENT_TO_OUTPUT_MULTIPLIER 1525092864
#define LSTM_1_RECURRENT_TO_OUTPUT_SHIFT -2
-#define LSTM_1_HIDDEN_MULTIPLIER 1993146898
+#define LSTM_1_FORGET_MULTIPLIER 1073741824
+#define LSTM_1_FORGET_SHIFT -14
+#define LSTM_1_INPUT_MULTIPLIER 1073741824
+#define LSTM_1_INPUT_SHIFT -17
+#define LSTM_1_HIDDEN_MULTIPLIER 1522160019
#define LSTM_1_HIDDEN_SHIFT -22
-#define LSTM_1_HIDDEN_OFFSET 66
-#define LSTM_1_OUTPUT_STATE_OFFSET 66
+#define LSTM_1_HIDDEN_OFFSET -23
+#define LSTM_1_DATA_OFFSET 128
+#define LSTM_1_OUTPUT_STATE_OFFSET -23
#define LSTM_1_CELL_STATE_SHIFT -12
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/forget_gate_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/forget_gate_bias_data.h
index 2edce470..f7adf487 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_1/forget_gate_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_1/forget_gate_bias_data.h
@@ -1,7 +1,7 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
 #include <stdint.h>
const int32_t lstm_1_forget_gate_bias[11] =
- {-25627, -3119, 1396, -28731, 20399, -16645, -15073, 19313, 8401, -21968, -27812};
+ {-23170, -13466, -6110, 22504, -22652, 25549, -26211, -32267, 15774, -29318, 6943};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/forget_norm_coeff_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/forget_norm_coeff_data.h
index 6b56429e..e70544d5 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_1/forget_norm_coeff_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_1/forget_norm_coeff_data.h
@@ -1,5 +1,5 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
 #include <stdint.h>
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/input_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/input_data.h
index 9a71b026..ad595f7e 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_1/input_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_1/input_data.h
@@ -1,17 +1,17 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
 #include <stdint.h>
const int8_t lstm_1_input[220] = {
- 72, 126, -20, 19, 79, 71, 8, -72, -127, 118, 91, -27, 52, 121, 74, -4, 64, -14, -15, -105,
- -56, 36, -8, 49, -63, -119, -4, -31, -28, -8, -97, 82, -85, 68, -28, 98, -1, -65, 3, -99,
- -41, -11, -65, -55, 88, 66, -56, -2, -33, 108, -109, -105, 91, -101, -96, 119, -2, -109, -5, -29,
- 24, 43, -38, -128, -85, -114, -116, 12, -41, -117, 109, 104, 34, -49, -54, 96, 30, -28, -72, 3,
- 101, -84, -71, 1, 105, -65, -33, -49, 55, -49, 110, -92, 77, -93, 102, 28, 101, -113, -120, -88,
- 119, -124, -47, 98, -114, -122, 124, 57, -36, 18, 70, -72, -84, -75, -36, -110, -118, -1, 88, 67,
- -99, -115, 70, -16, -85, 49, 89, -104, 42, 86, -103, 20, -50, 113, -88, -81, 57, 45, 0, -33,
- 48, -27, 34, 73, 84, -64, 80, -115, 75, 58, -9, 32, 111, -64, -12, -40, 117, -97, -16, -33,
- 44, -51, 12, 15, -66, 82, -7, -38, 77, 116, 60, -39, 19, -29, 47, 84, -72, -109, -29, -86,
- -30, -69, -78, 90, -40, 104, -36, -110, 26, -93, -24, -4, -18, 61, 65, 77, 6, -40, 89, -92,
- 116, -89, 112, -8, -68, 28, 98, -56, 108, 63, 109, -72, -101, -19, -58, 86, -73, -46, -70, -89};
+ 121, -117, 50, -102, -68, -103, 9, -120, 48, 53, 25, 69, -77, 15, -34, 32, -70, 29, -118, -58,
+ -92, -17, -109, 96, -113, -5, 87, 45, -116, 98, 98, -41, 31, 40, 27, 65, -115, 35, 47, -96,
+ 20, 30, 101, 70, -16, 102, 70, -117, -22, -45, 49, 29, -38, 111, 35, 76, -91, 71, 1, 23,
+ -69, -32, -85, 13, 39, -32, -123, 69, 24, 110, -1, 22, 112, -79, 73, 125, 68, 53, -16, 100,
+ 24, 20, -61, 79, 88, 94, 8, -61, 26, 13, 21, 99, -119, -12, -100, -46, 79, 10, 26, 96,
+ -78, 75, -18, -42, 0, 65, -124, 43, 91, 42, -62, -95, 80, 19, -85, -42, -128, -106, -21, -90,
+ -75, 66, 94, -10, -122, -55, -8, -126, 51, 8, -21, 71, -93, 67, -23, 88, 3, -50, 16, 124,
+ -9, -69, 114, 73, 116, 113, 83, -3, -101, 34, -113, -18, -101, 13, -48, 48, 63, 56, -114, 80,
+ 8, -20, -23, 63, 115, -83, -59, -72, -21, -57, -97, -6, 100, 82, 118, 83, 126, -111, 36, -114,
+ 89, -38, 54, 113, 36, -1, -119, -122, -23, -71, -82, 87, 44, 101, -13, -3, -102, -97, 42, 44,
+ 61, 31, -89, -6, -33, -57, 16, -16, -27, -20, -52, -104, 120, 30, -125, -9, 96, 67, 77, 125};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/input_gate_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/input_gate_bias_data.h
index 0f945278..ed45a535 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_1/input_gate_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_1/input_gate_bias_data.h
@@ -1,7 +1,6 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
 #include <stdint.h>
-const int32_t lstm_1_input_gate_bias[11] =
- {13255, 22061, -13996, 8755, -14507, 4146, 10874, -20199, -24289, -30146, 28463};
+const int32_t lstm_1_input_gate_bias[11] = {-32410, -104, 21567, -21097, 12535, 259, 8401, 10604, 24974, 30367, -9986};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/input_norm_coeff_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/input_norm_coeff_data.h
index f47f1480..0b9ac91c 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_1/input_norm_coeff_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_1/input_norm_coeff_data.h
@@ -1,5 +1,5 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
 #include <stdint.h>
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_cell_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_cell_eff_bias_data.h
index f0f82cbe..a8c39211 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_cell_eff_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_cell_eff_bias_data.h
@@ -1,7 +1,7 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
 #include <stdint.h>
const int32_t lstm_1_input_to_cell_eff_bias[11] =
- {-23285, 35199, 3342, -20866, 1928, -43932, 15850, -46717, -20794, 54878, 11600};
+ {-70974, 15629, 50174, -12269, -26876, 14328, 56938, -90951, 37206, 49150, 37019};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_cell_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_cell_w_data.h
index b2d99b5c..bce86cc3 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_cell_w_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_cell_w_data.h
@@ -1,19 +1,18 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
 #include <stdint.h>
const int8_t lstm_1_input_to_cell_w[242] = {
- 88, 6, 55, 61, -92, 45, 32, -22, -111, 74, -20, 84, 127, -83, 65, -70, -58, -73, -49,
- -114, -115, -65, -64, 119, 121, -35, -69, -25, 5, -38, 46, 27, -80, 41, 83, 47, 4, 32,
- -95, -56, 72, 49, -100, 66, 68, 98, 45, -124, -126, 32, 45, -51, 105, -16, 32, 60, 23,
- 109, -104, -50, 47, -98, -8, 71, -100, 16, 80, 33, -10, -47, -48, 52, -35, 85, 16, 22,
- 109, 72, -78, -76, 49, -82, 7, -70, 0, -105, -39, 0, 38, 79, -47, 116, -23, -53, -33,
- -113, -98, -48, -3, 55, 16, -62, -3, 41, -60, 39, 19, 88, -110, 95, 103, 80, -4, -16,
- -92, -51, -78, -80, 112, 102, -116, -2, -91, -76, 89, 24, 99, -17, -118, 32, -46, -127, -7,
- 33, 52, 41, -51, 19, 73, 78, 13, 111, 99, -50, -78, -23, 99, -27, 90, -59, 1, 1,
- 27, -77, 76, -100, 14, -12, 112, -70, -56, -33, -14, -79, 46, -81, -41, 37, 97, -41, -98,
- 38, -120, 25, -3, 89, -111, -101, 123, -80, -94, 0, 97, -120, 62, 116, -8, 18, -37, -104,
- 57, -14, 1, -121, 89, -75, 113, 26, 95, -37, 45, -97, 35, 41, 62, 23, -70, 124, -94,
- 81, 125, 127, -51, -76, 17, 49, -80, -78, 76, -86, 103, 34, -69, -26, 50, -55, -66, 56,
- -52, -63, 48, -93, 120, 15, -77, 80, 27, 89, -37, -87, -23, 80};
+ -112, 13, 117, -90, -36, -7, -41, 19, -90, 44, 121, -52, -58, -77, -18, 126, -63, -105, -45, -53, 88,
+ -109, -99, -26, -36, 64, -3, -127, -109, 114, 59, 108, -56, 26, 102, -100, -48, -11, -90, 81, 49, 94,
+ -30, 107, 63, 65, 66, -74, -24, -1, -57, -109, -26, -75, 48, 93, -51, 8, 76, 94, 66, -76, -58,
+ -15, 98, 93, -95, 45, -23, 104, 60, -13, -35, -73, 10, -8, -51, -43, -100, 55, 101, -53, -85, 12,
+ -93, 0, 14, -59, -49, -3, 3, -74, -107, -124, -85, 72, -30, 104, -57, 18, 114, -93, -72, 22, 65,
+ 120, -34, -5, 120, 63, 3, -93, 69, 27, -124, -13, -93, 84, 32, -5, 88, -49, 93, 2, -72, -56,
+ 101, -30, 21, -1, -64, 54, 108, -86, -12, -88, 102, 39, 80, -12, 39, -1, 108, 10, -28, 31, -62,
+ 5, -107, 102, 56, -26, 3, 69, 115, -103, -73, -52, -96, 27, 102, -127, -118, -22, 45, 20, -60, -53,
+ -20, -47, -52, 43, -66, -114, -74, 0, -66, 80, 103, 28, 84, -17, -85, -81, 10, 109, -5, -2, -76,
+ 17, 85, -17, -51, 92, 76, -21, 7, -11, -87, 48, -27, -9, 124, 15, 120, -123, 64, -71, -59, 32,
+ -35, 90, 112, 67, 13, 13, 6, -1, 67, -89, 7, 93, 18, 125, 121, 48, -13, -91, 41, 30, 31,
+ -91, 65, 114, -98, -88, -6, -48, 18, 50, 115, 10};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_forget_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_forget_eff_bias_data.h
index 5c84d6b1..aba8cb81 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_forget_eff_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_forget_eff_bias_data.h
@@ -1,7 +1,7 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
 #include <stdint.h>
const int32_t lstm_1_input_to_forget_eff_bias[11] =
- {-58523, -69423, 59764, -63035, -12625, -52357, 2463, 52593, -6703, -113488, 114652};
+ {19454, 742, -39262, 13544, -61308, 48333, 43165, -115211, -17250, -34822, 33823};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_forget_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_forget_w_data.h
index c89123b8..fb6dab68 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_forget_w_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_forget_w_data.h
@@ -1,19 +1,19 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
 #include <stdint.h>
const int8_t lstm_1_input_to_forget_w[242] = {
- -98, -98, -102, 30, -76, 42, -40, 98, -37, 29, -43, 60, -37, 26, 114, -37, 100, -85, -62,
- 86, -71, -56, -83, -35, 105, 60, 17, -21, 109, -77, -122, -64, -113, -53, 49, -102, -84, -89,
- 101, 61, -99, -47, -64, 33, 120, -10, 64, 122, 71, 77, -15, 106, -9, -12, -50, 46, 49,
- 98, -126, -11, 81, -121, 52, -119, -5, 48, 17, -62, 20, -125, 39, 63, -126, -109, -102, -81,
- -31, 49, 84, 64, -45, 100, -1, -76, 103, -84, 19, 16, -110, 18, 71, -125, 66, -46, 45,
- 111, -4, 54, -106, -83, -42, 24, -26, -62, -46, -20, 70, -124, -1, 78, 8, -59, -87, 104,
- -54, 53, -38, 11, 104, -5, -113, 36, -62, -15, -117, -39, 48, -7, -44, 61, 22, -86, 50,
- 105, 72, -91, 29, -19, 19, 64, 17, -3, -110, -5, 122, -79, 79, 43, -15, -23, -69, -13,
- 11, -47, 19, -104, 82, 40, -6, 115, 83, 101, 112, -80, 18, -49, -70, 83, 98, -34, 75,
- -80, 91, -42, -113, -79, -103, 56, 25, -49, -106, -65, 123, -90, -127, 121, 27, -40, 82, -40,
- -13, -106, 100, 18, -28, -5, 86, 16, 64, 50, -71, -118, 2, -32, -2, 94, -125, -98, 4,
- -49, -26, -127, -101, 17, 121, -91, -11, -28, -114, -74, 118, 79, -27, 50, 120, 126, 91, -108,
- 102, 15, 88, 37, -7, 41, -90, 118, 84, -1, 68, 22, 88, 99};
+ 63, 63, 113, 40, -74, -51, -21, 79, -91, 95, 110, -57, -95, -80, -26, -40, 0, 87, -99,
+ 102, 93, 122, -96, 62, 14, 40, -90, 14, 125, 3, -2, 50, 113, -37, -109, 24, -22, 38,
+ 29, -80, 21, -80, 71, 23, -109, 107, 118, -109, -5, -62, 112, -24, 41, -24, -75, 57, -2,
+ 118, 111, -112, -125, -3, -22, -17, -123, -111, 114, -96, 66, -95, -123, 126, 20, -22, -47, 72,
+ -118, 50, -122, 18, 99, 70, -126, 3, 96, -74, -60, 79, -40, -115, -84, -83, 90, -77, -113,
+ 60, 10, -111, 40, 69, -89, -83, -34, 14, 3, -97, 97, 71, 72, 98, 73, 20, -31, -5,
+ 27, -79, -95, -44, -48, 123, 12, -95, -92, 17, 119, 25, 77, 127, 56, -4, 18, -23, 74,
+ 73, 12, 82, 69, 105, -106, 6, 51, 21, 32, 28, 116, 123, -118, -124, -47, -12, 66, 65,
+ 69, -43, -122, -67, 80, -112, -46, -24, -47, -1, -28, -73, -113, 116, -18, -81, -33, -18, 105,
+ -116, -71, -68, 63, 26, -98, -96, -46, 68, 64, 82, -10, 55, -95, -89, 62, -13, -84, 46,
+ -23, -75, 89, -41, 53, -21, -51, -35, 56, -116, 13, -74, 73, -75, 32, -55, -81, 74, 91,
+ 110, -78, 67, 22, 23, -92, 3, 66, -98, 87, -91, 89, -93, 37, 118, 124, 34, -61, 78,
+ 85, 18, -17, -14, -21, 3, -77, -63, -105, 126, 83, -78, 26, -82};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_input_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_input_eff_bias_data.h
index 5d88fdfa..6773e43c 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_input_eff_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_input_eff_bias_data.h
@@ -1,7 +1,7 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
 #include <stdint.h>
const int32_t lstm_1_input_to_input_eff_bias[11] =
- {73287, 57517, -29484, -49741, 12885, -64462, 41210, -61287, 48415, -27970, 67375};
+ {-123034, -15592, 20415, -22377, 65527, -19069, 31697, -19220, 96526, 111007, 17534};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_input_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_input_w_data.h
index acf45cc6..3431c7c9 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_input_w_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_input_w_data.h
@@ -1,19 +1,19 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
 #include <stdint.h>
const int8_t lstm_1_input_to_input_w[242] = {
- 69, 103, 18, 41, 114, 100, 107, 9, -38, 70, 15, 41, 54, 51, -106, 37, -25, -71, -101,
- 102, -85, -36, 106, -24, 120, 37, -86, 53, 17, 50, -94, -22, 42, 6, 109, -82, -27, -28,
- -66, 92, 47, 14, 127, -114, -83, -125, 93, -114, 7, 90, 57, -54, -110, -98, 103, 116, 70,
- 41, -57, -29, -80, 89, -102, -115, 56, 124, -52, -19, 63, 107, 101, 26, -26, -88, -6, -109,
- -21, 107, -21, 8, -40, -118, 30, -66, -118, 13, -113, -115, -55, -15, -71, 12, 120, 86, 107,
- -55, 17, -3, -99, 19, -85, -34, 25, 58, -83, 95, 24, 14, 67, 70, -12, -99, -99, -104,
- 40, 38, -28, -123, 44, -90, -95, 25, 117, 36, -44, 87, -28, 39, -102, -106, 90, -122, 39,
- -120, -2, 50, 84, -126, 115, 31, -58, 19, 85, 46, 88, -123, 55, 77, -9, -21, 94, 27,
- -90, -24, -67, -64, -22, 119, 46, 1, -15, 50, -19, -18, 21, -81, -71, -118, -113, -73, 44,
- -26, 48, 22, 78, -63, 107, -81, -79, -34, 70, 82, 94, -23, -46, 104, 38, -3, 36, 103,
- -4, 92, 45, -77, 70, 8, -34, 100, 15, -36, 83, -14, -17, 29, 106, 24, 91, 101, -33,
- -109, -53, 32, -41, -20, 72, 31, -118, -50, -79, 3, 91, -99, 124, 117, -16, -35, 65, -18,
- -115, 71, 33, -43, 74, -98, -37, 26, 124, -73, -19, -44, 66, 110};
+ -83, -99, -16, -47, -67, 80, -5, -85, -11, -108, 1, -44, -66, -100, 127, -123, 56, -65, 13,
+ 93, -124, -35, -93, -67, 36, 71, -50, -75, -48, 2, -97, 5, 45, 91, 75, 2, -83, -112,
+ 85, -65, 107, 62, 25, -37, -111, -74, -80, -26, -18, 76, -15, 50, -120, 75, 104, -76, 21,
+ 31, 65, -30, -116, 117, 123, 105, -42, -68, 31, -41, -58, 114, -28, 46, -103, -80, 34, 89,
+ 5, -90, 11, 86, -9, -102, 90, -54, 27, -29, 15, 36, 26, 16, 27, 64, 3, 51, 69,
+ 61, -123, 118, 23, -23, 60, 125, -87, -52, -108, 83, 94, 4, 93, -110, -35, -96, 124, 73,
+ -24, -121, -88, -62, 7, 87, 111, 38, 123, -88, 59, -27, -106, -60, -9, -112, 70, -15, -31,
+ -5, 88, 124, -44, 78, -81, 39, 78, -67, -113, -17, 81, 10, 8, 2, -8, 6, 40, 5,
+ -14, 3, 108, 69, 39, 64, 85, 64, 21, 52, -115, -99, -20, -61, -102, -122, 55, 18, -98,
+ -36, 0, -95, -78, 18, -81, -4, 33, 78, -23, 30, 124, 41, 71, 106, 101, 59, -81, -46,
+ -26, 92, 49, -6, 97, -37, -111, 93, 67, 126, 63, -36, 12, -29, 100, 44, 16, 77, -47,
+ 105, -54, 115, 17, 19, 99, 37, 44, -66, 29, -108, -53, 87, -30, -77, -44, 123, -111, 71,
+ -42, 34, 14, 91, 12, 118, 74, 64, -106, 106, -1, 6, -102, -19};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_output_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_output_eff_bias_data.h
index e3c21161..9ba660c8 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_output_eff_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_output_eff_bias_data.h
@@ -1,7 +1,7 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
const int32_t lstm_1_input_to_output_eff_bias[11] =
- {49470, -17340, 19595, 64019, -58643, 48766, -43591, -26418, 70460, 25909, -56678};
+ {-37894, 15877, 85451, -15991, 37318, 38600, -31336, -20295, -76464, 56669, -2404};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_output_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_output_w_data.h
index 3672d460..ef8e04f1 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_output_w_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_1/input_to_output_w_data.h
@@ -1,19 +1,19 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
const int8_t lstm_1_input_to_output_w[242] = {
- 25, -27, -110, 8, 91, -32, -98, 74, 81, -32, 106, -29, -57, -41, -17, 95, 25, 39, 124,
- -45, 61, -41, -37, -23, -125, 116, -110, 49, 28, -27, -34, 79, 9, 108, 27, -14, 82, -113,
- -82, -62, 55, -17, -30, -28, 22, 50, 97, 5, -72, 27, -36, 29, 20, -69, -79, -13, 46,
- 89, 56, 93, 42, -35, -45, 37, -58, -113, -94, 63, 76, -119, 58, 24, 127, 103, -88, -82,
- -37, 98, 104, 67, 18, 113, -65, -26, 39, 43, 127, -42, -88, 65, -62, -107, 64, -22, 23,
- -40, 93, 107, -3, -92, -10, -22, -69, 122, -51, 10, 0, 19, -102, -106, -93, -17, 23, 41,
- 59, 29, 124, 79, 74, 63, 114, 96, -85, 42, 112, -78, -6, -115, 9, -127, -80, -123, 60,
- 95, -42, -42, -114, -113, -100, -112, -28, 8, 48, -33, -48, 111, 55, 3, -80, 4, 42, -19,
- 50, 93, 39, 8, 94, 115, 12, -99, -118, -95, -23, -42, 30, 32, -58, -101, 69, -115, -1,
- -72, -122, 57, -11, 101, 60, -99, 28, 26, 84, -33, -126, -9, -104, -59, -29, 122, -30, 120,
- -33, 117, 27, 59, 104, 55, -21, 66, 67, -64, 52, 78, -120, -29, -88, -74, -6, 44, -101,
- -125, 123, -4, 14, 88, -17, 103, -6, 2, 11, 37, -22, -84, -125, 67, -121, 54, 23, -51,
- 113, -83, 15, -67, 10, 109, -88, 43, 32, -28, -79, 84, -74, 74};
+ -59, 85, -25, 59, -52, -15, 21, -113, -114, -7, -16, -78, -58, -59, -20, -46, 24, 4, -69,
+ 56, 25, -13, 107, 48, -102, 24, -41, -75, 97, 70, 44, 17, -116, 54, -101, 10, 8, -36,
+ -66, -44, -111, 114, 22, 70, 62, 48, -16, -6, 67, 121, 55, -72, 120, 58, 80, -105, -98,
+ 51, 93, -115, 110, 79, -4, -18, -51, 11, 117, -62, -4, 93, 23, 32, -60, 42, -121, 54,
+ 54, -83, 63, 55, -35, -8, -76, -114, 24, -108, 81, 45, 44, -15, 100, 52, -22, 25, 23,
+ -48, -38, -30, -127, 69, 18, -4, 84, 97, 16, 61, -52, 20, 70, 39, 69, -78, 78, -43,
+ -110, 95, 94, -50, 77, -5, -78, 21, 23, -32, 46, 40, -105, 17, 65, 40, 12, 0, 30,
+ -88, 119, 123, -120, -118, 21, 33, -58, 73, -117, -52, -10, -93, 72, -16, 49, -114, -94, 10,
+ 15, -78, -120, 10, -90, -9, -94, -28, -76, 80, -91, -63, -10, 78, -24, 104, 71, -17, 27,
+ 122, -5, 46, 27, -120, -21, -51, -29, -63, -67, 96, -115, 24, -117, -127, -52, 59, 73, 3,
+ 49, 33, -47, -20, -12, -16, 28, -4, -79, -94, -96, 123, -120, -57, 13, -18, 28, -48, 68,
+ 102, 24, 114, -118, 124, 60, -41, 97, 78, 11, 40, 38, 120, -91, -86, 9, -94, 120, -60,
+ 98, -119, -50, 49, -43, -82, -6, -38, 113, -98, -25, 113, -27, -101};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/output_gate_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/output_gate_bias_data.h
index f0cc061f..324e2dbc 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_1/output_gate_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_1/output_gate_bias_data.h
@@ -1,7 +1,7 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
const int32_t lstm_1_output_gate_bias[11] =
- {23870, 1732, 7691, -877, -23955, 30718, -22855, 11982, 28860, 27829, -31334};
+ {22266, 16773, 25291, -17527, -11578, 16072, 21528, 3001, -28336, 29661, 30876};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/output_norm_coeff_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/output_norm_coeff_data.h
index 4fd42944..58db3fa0 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_1/output_norm_coeff_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_1/output_norm_coeff_data.h
@@ -1,5 +1,5 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/output_ref_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/output_ref_data.h
index a01813a8..d0b3ac83 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_1/output_ref_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_1/output_ref_data.h
@@ -1,11 +1,11 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
const int8_t lstm_1_output_ref[110] = {
- 127, 124, 81, 62, 72, 51, 115, 63, -83, 119, 79, 115, 127, 94, 106, 54, 68, 104, 63, -77, 116, 85,
- 127, 127, 94, 127, 67, 108, 73, 57, -35, 120, 103, 48, 127, 44, 97, 16, 64, 89, 62, 85, 101, 74,
- -50, 88, 23, 60, 40, 47, 69, 55, 104, 99, 90, -101, 95, 33, 59, 71, 59, 76, 41, -54, 121, 127,
- -106, 109, 21, 41, 50, 52, 92, 40, -34, 78, 88, -84, 115, -7, 41, 58, 55, 117, 39, 109, 104, 89,
- -103, 115, 1, 56, 36, 55, 90, 44, 62, 107, 87, -68, 87, -44, 29, 47, 25, 76, 29, -41, 108, 89};
+ -24, -26, 42, -30, -91, 13, 9, -29, -16, 74, -38, -24, 37, 46, -38, -25, -10, -9, -37, -12, 110, 47,
+ -25, 52, 113, -35, -15, -22, 36, -41, -16, 101, 81, -26, 78, 119, -27, -83, -54, 14, -31, -19, 127, 79,
+ -27, 98, 118, -25, 66, -48, 61, -36, -9, 127, 85, -29, 100, 104, -71, 36, -24, 40, -28, 21, 121, 60,
+ -26, 113, 127, -39, -50, -33, 15, -47, -3, 120, 52, -31, 118, 127, -59, 34, -65, 42, -42, -11, 100, 53,
+ -26, 84, 67, -47, 36, -65, 29, -40, -16, 58, 101, -29, 85, 116, -73, -12, -115, 32, -41, -1, 104, 56};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/output_state_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/output_state_data.h
index 3855269c..8c1d1a91 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_1/output_state_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_1/output_state_data.h
@@ -1,6 +1,6 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
-int8_t lstm_1_output_state[11] = {66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66};
+int8_t lstm_1_output_state[11] = {-23, -23, -23, -23, -23, -23, -23, -23, -23, -23, -23};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/projection_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/projection_bias_data.h
index 15335fa2..2cdf2f40 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_1/projection_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_1/projection_bias_data.h
@@ -1,5 +1,5 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/projection_weights_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/projection_weights_data.h
index c8b1358f..daa15c33 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_1/projection_weights_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_1/projection_weights_data.h
@@ -1,5 +1,5 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_input_to_cell_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_input_to_cell_w_data.h
index a50291e8..4ad4495f 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_input_to_cell_w_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_input_to_cell_w_data.h
@@ -1,13 +1,12 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
const int8_t lstm_1_recurrent_input_to_cell_w[121] = {
- 111, 31, -22, 113, 102, -33, -15, 117, 58, 88, 6, -12, -19, 102, 37, 80, -109, -9,
- -83, 93, 110, -45, 57, -122, 0, -120, 57, 61, -102, -47, -7, -101, -126, 93, -13, 117,
- -51, -72, -33, 76, 92, -105, 0, -92, -56, -81, 37, 21, -124, 34, -78, -120, 30, -70,
- 53, 113, 57, -12, -11, 95, 117, 56, 47, -68, 110, 41, -72, -32, -85, 122, -17, -49,
- 111, -8, -55, -107, 40, -111, -94, 40, 111, 72, -16, 39, -123, -119, 60, 118, 62, -55,
- -100, -60, 6, -74, 113, -80, -72, -88, -18, -27, 11, 62, -99, 127, -54, 12, 41, 13,
- -76, -19, 7, 62, -29, -11, -94, -74, -53, -110, -25, -32, 74};
+ 119, -7, -15, -96, -39, 87, -17, 17, 77, -113, 50, -26, 91, 53, 14, 56, 84, 73, -61, 118, -9,
+ -50, -72, 30, -72, -18, 41, 13, 14, 41, 26, 5, -40, -15, 94, -16, 80, -11, -66, -30, 25, -115,
+ -51, -93, -5, 41, 96, 97, -52, -101, 54, 16, -73, 96, -55, -63, -127, -9, -83, -114, 7, 32, 70,
+ -17, -70, -40, 116, -10, -9, 20, -36, -14, 40, 25, 127, -63, -18, 74, -51, -125, -53, 119, -83, 75,
+ 0, -77, -60, -10, 74, -54, -37, 92, -7, -105, -38, 52, 23, 20, 52, 39, -74, 50, -4, 62, -87,
+ 64, -62, 72, 37, -68, -14, 92, -111, -80, 24, 18, -86, -59, -32, 125, 15};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_input_to_forget_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_input_to_forget_w_data.h
index 05bcd09b..f6ad4f00 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_input_to_forget_w_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_input_to_forget_w_data.h
@@ -1,12 +1,12 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
const int8_t lstm_1_recurrent_input_to_forget_w[121] = {
- -68, 67, 106, -51, 107, -86, -60, -25, -67, -44, -127, -1, 60, 73, -127, 91, -20, 95, -80, -108, 55,
- -40, -31, -73, 41, -101, -14, 85, 113, 0, 15, -71, -5, -64, 111, -98, -66, 109, 6, 44, -103, 6,
- 4, 104, 109, -29, 118, -55, 22, -35, 83, 6, -4, -65, -14, 102, 91, 121, 57, -120, -55, -60, -33,
- -79, 15, -103, 53, -122, -78, -54, -8, 115, -67, 27, -1, 74, -114, -22, -82, 28, -18, 32, -21, 84,
- -66, 14, -82, 53, -25, -26, -126, -73, -46, -86, 16, -118, 22, -42, 57, -90, 121, -80, 109, -55, -75,
- -94, -85, 80, 121, 45, -33, 9, 58, -95, 29, 73, -65, 33, 10, 66, 36};
+ 110, -108, -7, -84, 96, 93, -39, -46, -107, 73, 66, -74, 107, 49, 14, 27, 37, -74, -97, -69, 22,
+ 78, 47, 124, -103, -74, -120, -107, 71, 3, 81, 127, 86, 69, -68, 7, -119, -67, -9, 8, 60, 78,
+ 33, -59, 26, 64, 13, -51, 7, 0, 118, 37, 123, -104, -71, -18, 113, -107, -40, 33, 23, -42, -40,
+ 74, -57, 50, -22, -45, -120, 68, -11, -6, 36, 103, -73, 51, -3, -41, -22, -92, 92, -20, 122, -127,
+ -54, 73, 6, -47, -75, -74, 118, -19, -12, -97, 42, 31, -59, -45, -36, -108, 41, 91, -49, -1, -9,
+ 79, 68, 36, 13, -66, 102, 126, -3, -17, 58, -87, 120, 6, 68, 95, 65};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_input_to_input_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_input_to_input_w_data.h
index 86a2a835..ba796abf 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_input_to_input_w_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_input_to_input_w_data.h
@@ -1,12 +1,12 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
const int8_t lstm_1_recurrent_input_to_input_w[121] = {
- -75, -104, 8, 76, -10, -67, -120, -91, -15, 110, 48, 103, 85, -29, 34, -111, 98, 70, -67, 24, 72,
- -101, -87, 26, -61, 122, 14, 52, 99, 113, 127, 77, 6, 118, -88, -18, -94, 14, -38, 91, -2, -72,
- -86, 81, -122, 109, -19, -61, -11, -14, 116, 78, 4, 125, 73, 56, -66, -49, -23, 100, -102, 27, 75,
- 118, -124, -113, -16, 60, -30, -96, 47, 119, -63, 104, -123, -91, -97, -39, -127, -25, 1, 44, 8, -21,
- -25, 4, -52, -82, 107, -60, -125, -82, 64, 101, 24, -67, -27, -52, 119, 97, 26, -69, 92, 77, 16,
- -15, -11, -34, -8, -94, 6, -32, -60, -94, -50, 125, -2, -73, -118, -33, 96};
+ 98, -54, 29, 95, -124, -26, 17, 6, 17, 65, -105, 24, -101, -115, 49, 15, -32, 100, 51, -54, 125,
+ -43, -47, 39, -76, -114, 61, -106, 38, -115, 7, 38, 111, 1, -28, 117, 93, 121, -103, -37, -88, 121,
+ 95, 65, 50, 121, -66, 28, -125, -30, -19, -15, -70, -59, -119, -107, -43, 122, -68, -123, -101, -22, 127,
+ 81, -63, -126, -96, 74, -111, -102, -127, 61, -12, -60, 91, -109, 24, 3, -86, 6, -19, -51, -48, -113,
+ -50, 16, -5, -31, -76, -14, -24, 47, 112, -6, -19, 54, 51, 21, 122, 117, 118, -9, 11, 121, -84,
+ -74, 86, -90, 66, 51, -19, 10, 94, -60, 75, 57, -84, -30, -22, -3, 55};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_input_to_output_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_input_to_output_w_data.h
index d0a3c658..62b117c4 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_input_to_output_w_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_input_to_output_w_data.h
@@ -1,13 +1,12 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
const int8_t lstm_1_recurrent_input_to_output_w[121] = {
- -110, 10, -9, -112, -13, 8, 110, 122, 49, 124, -29, -20, -8, -66, -66, -60, -71, 10,
- 25, -107, -63, -43, -74, -41, 27, 127, 42, 9, -6, 97, -33, 71, -86, 115, 88, 32,
- -75, -117, 70, -108, 54, 42, 92, -73, -79, -123, -7, 58, -44, 77, -11, -84, -102, -8,
- 77, 42, -21, -100, 64, 43, 115, -42, 117, 66, 1, 100, -21, -122, -45, 57, 9, -36,
- -28, -94, -94, -90, -118, -123, -30, 42, 21, -56, -7, -125, -79, -16, 114, -68, 90, 69,
- -99, -43, 63, 31, -38, -22, 43, -52, -33, 37, -24, -81, 4, 88, 124, 57, 76, -73,
- 80, 22, -10, 10, 3, 32, -37, 43, -104, -113, -75, 81, -66};
+ 9, -100, 91, -34, 118, -95, 69, 47, 8, 79, 67, 6, 66, 122, -52, -3, 11, 57, -93, -70, 6,
+ 48, 7, 90, 44, -18, 74, -114, 57, -5, -110, 76, -117, -121, 59, -113, 68, 94, -64, 75, -81, 17,
+ -85, -65, -85, -13, -40, -5, -121, 26, -29, 54, -113, 37, -6, 127, -67, -12, 127, -23, -85, -32, 38,
+ 23, -123, 58, 108, -80, 68, -120, -101, 114, -107, 4, 84, 40, 85, -3, 123, 104, 125, -55, 91, 34,
+ -120, -80, -106, 9, -110, 48, 49, -36, 53, -94, 23, -23, -65, -26, -58, 82, -21, -20, 94, -31, 86,
+ 12, -89, 114, -62, 126, 27, 8, 45, -54, -122, 95, 7, -83, 19, 24, 44};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_to_cell_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_to_cell_eff_bias_data.h
index 8dcc2b5e..06b79d01 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_to_cell_eff_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_to_cell_eff_bias_data.h
@@ -1,7 +1,7 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
const int32_t lstm_1_recurrent_to_cell_eff_bias[11] =
- {-36696, -9570, 29700, -792, 23364, -35970, 10032, 1518, 24156, 594, 18810};
+ {1449, 7889, -736, -4554, 2622, -9522, 4094, -4393, 1656, 667, -2484};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_to_forget_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_to_forget_eff_bias_data.h
index 26a79590..599b9096 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_to_forget_eff_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_to_forget_eff_bias_data.h
@@ -1,7 +1,7 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
const int32_t lstm_1_recurrent_to_forget_eff_bias[11] =
- {16368, 132, 2706, -3498, -8976, 4224, 11550, 5280, 29502, 198, -7986};
+ {1081, 460, 3105, -1541, 3726, -253, -506, -2530, -5198, 2185, 12259};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_to_input_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_to_input_eff_bias_data.h
index 0d542ac8..3de075dc 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_to_input_eff_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_to_input_eff_bias_data.h
@@ -1,7 +1,7 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
const int32_t lstm_1_recurrent_to_input_eff_bias[11] =
- {15840, -11748, -32208, 6204, -18348, 6666, 12276, 20724, -132, -5082, 15510};
+ {414, 437, -3772, 8211, -6992, -7429, -8441, -8694, 6164, 7199, 1679};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_to_output_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_to_output_eff_bias_data.h
index 2e395eff..2f4d8378 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_to_output_eff_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_1/recurrent_to_output_eff_bias_data.h
@@ -1,7 +1,7 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
const int32_t lstm_1_recurrent_to_output_eff_bias[11] =
- {-9900, 30954, -8778, -7920, 16236, -25410, 38412, 21582, -594, -20460, 15576};
+ {5957, 2254, -368, -4968, -6785, 713, 2185, 2806, -5497, 6693, 230};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_1/test_data.h b/Tests/UnitTest/TestCases/TestData/lstm_1/test_data.h
index ecd25010..f1ab6ed2 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_1/test_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_1/test_data.h
@@ -1,5 +1,5 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#include "cell_gate_bias_data.h"
#include "cell_norm_coeff_data.h"
#include "cell_state_data.h"
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/cell_gate_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/cell_gate_bias_data.h
index 16a9bd84..7c52e627 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_2/cell_gate_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_2/cell_gate_bias_data.h
@@ -1,6 +1,6 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
-const int32_t lstm_2_cell_gate_bias[7] = {-18223, -7792, 2026, 24820, 5477, -24022, -32276};
+const int32_t lstm_2_cell_gate_bias[7] = {16402, 13081, 21684, 28843, 19606, 13836, -23310};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/cell_norm_coeff_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/cell_norm_coeff_data.h
index 84fb2981..711d66c2 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_2/cell_norm_coeff_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_2/cell_norm_coeff_data.h
@@ -1,5 +1,5 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/cell_state_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/cell_state_data.h
index 25d429c2..1f6e2166 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_2/cell_state_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_2/cell_state_data.h
@@ -1,6 +1,6 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
-int16_t lstm_2_cell_state[14] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
+int16_t lstm_2_cell_state[7] = {0, 0, 0, 0, 0, 0, 0};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/cell_to_forget_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/cell_to_forget_data.h
index 62963262..6341b7a2 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_2/cell_to_forget_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_2/cell_to_forget_data.h
@@ -1,5 +1,5 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/cell_to_input_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/cell_to_input_data.h
index 9abfc468..0500eabb 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_2/cell_to_input_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_2/cell_to_input_data.h
@@ -1,5 +1,5 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/cell_to_output_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/cell_to_output_data.h
index b5d67ed2..2f7d5c88 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_2/cell_to_output_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_2/cell_to_output_data.h
@@ -1,5 +1,5 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/config_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/config_data.h
index 2d125a88..da8a0fb4 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_2/config_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_2/config_data.h
@@ -1,33 +1,38 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
-#define LSTM_2_BUFFER_SIZE 14
-#define LSTM_2_INPUT_BATCHES 2
-#define LSTM_2_DST_SIZE 126
+#define LSTM_2_BUFFER_SIZE 7
+#define LSTM_2_INPUT_BATCHES 1
+#define LSTM_2_DST_SIZE 63
#define LSTM_2_TIME_STEPS 9
#define LSTM_2_NUMBER_UNITS 7
#define LSTM_2_NUMBER_INPUTS 6
#define LSTM_2_TIME_MAJOR 0
#define LSTM_2_IN_ACTIVATION_MIN -32768
#define LSTM_2_IN_ACTIVATION_MAX 32767
-#define LSTM_2_IN_TO_INPUT_MULTIPLIER 2143970816
+#define LSTM_2_IN_TO_INPUT_MULTIPLIER 2146016128
#define LSTM_2_IN_TO_INPUT_SHIFT -3
-#define LSTM_2_IN_TO_FORGET_MULTIPLIER 2037819904
+#define LSTM_2_IN_TO_FORGET_MULTIPLIER 2131870080
#define LSTM_2_IN_TO_FORGET_SHIFT -3
-#define LSTM_2_IN_TO_CELL_MULTIPLIER 2115009024
-#define LSTM_2_IN_TO_CELL_SHIFT -3
-#define LSTM_2_IN_TO_OUTPUT_MULTIPLIER 2088505856
-#define LSTM_2_IN_TO_OUTPUT_SHIFT -3
-#define LSTM_2_RECURRENT_TO_INPUT_MULTIPLIER 1470197120
+#define LSTM_2_IN_TO_CELL_MULTIPLIER 1079288192
+#define LSTM_2_IN_TO_CELL_SHIFT -2
+#define LSTM_2_IN_TO_OUTPUT_MULTIPLIER 1077017984
+#define LSTM_2_IN_TO_OUTPUT_SHIFT -2
+#define LSTM_2_RECURRENT_TO_INPUT_MULTIPLIER 1351661952
#define LSTM_2_RECURRENT_TO_INPUT_SHIFT -3
-#define LSTM_2_RECURRENT_TO_FORGET_MULTIPLIER 1504444160
+#define LSTM_2_RECURRENT_TO_FORGET_MULTIPLIER 1367683200
#define LSTM_2_RECURRENT_TO_FORGET_SHIFT -3
-#define LSTM_2_RECURRENT_TO_CELL_MULTIPLIER 1454177792
+#define LSTM_2_RECURRENT_TO_CELL_MULTIPLIER 1374072192
#define LSTM_2_RECURRENT_TO_CELL_SHIFT -3
-#define LSTM_2_RECURRENT_TO_OUTPUT_MULTIPLIER 1455750656
+#define LSTM_2_RECURRENT_TO_OUTPUT_MULTIPLIER 1379062144
#define LSTM_2_RECURRENT_TO_OUTPUT_SHIFT -3
-#define LSTM_2_HIDDEN_MULTIPLIER 1537938273
+#define LSTM_2_FORGET_MULTIPLIER 1073741824
+#define LSTM_2_FORGET_SHIFT -14
+#define LSTM_2_INPUT_MULTIPLIER 1073741824
+#define LSTM_2_INPUT_SHIFT -17
+#define LSTM_2_HIDDEN_MULTIPLIER 1667670651
#define LSTM_2_HIDDEN_SHIFT -21
-#define LSTM_2_HIDDEN_OFFSET 74
-#define LSTM_2_OUTPUT_STATE_OFFSET 74
-#define LSTM_2_CELL_STATE_SHIFT -13
+#define LSTM_2_HIDDEN_OFFSET -24
+#define LSTM_2_DATA_OFFSET 128
+#define LSTM_2_OUTPUT_STATE_OFFSET -24
+#define LSTM_2_CELL_STATE_SHIFT -12
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/forget_gate_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/forget_gate_bias_data.h
index 020999f7..69bd9b74 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_2/forget_gate_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_2/forget_gate_bias_data.h
@@ -1,6 +1,6 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
-const int32_t lstm_2_forget_gate_bias[7] = {-15296, 3060, 33125, -4714, -14467, 22732, -4683};
+const int32_t lstm_2_forget_gate_bias[7] = {-25576, -22402, 15452, -12818, 31310, 27442, 4985};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/forget_norm_coeff_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/forget_norm_coeff_data.h
index 794b3962..558c2b81 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_2/forget_norm_coeff_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_2/forget_norm_coeff_data.h
@@ -1,5 +1,5 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/input_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/input_data.h
index 5c1bb385..8ab020fa 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_2/input_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_2/input_data.h
@@ -1,12 +1,9 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
-const int8_t lstm_2_input[108] = {-123, -88, 71, -16, -31, -9, 3, 71, 41, -43, 69, -51, -74, 31, -127, -3,
- 63, 32, -23, -75, -95, 80, -96, -113, -4, -13, -6, -94, 6, 19, -80, 3,
- 85, 73, -43, -17, 96, 67, 7, 31, -128, 117, -113, 24, -90, 113, -36, -13,
- 108, -54, 115, -106, -3, -58, 122, 68, -52, -2, -46, -2, 9, 3, -122, 71,
- 40, 111, 103, 2, 120, -53, 26, 125, 41, -94, 107, 51, 74, 7, 15, -27,
- 9, 33, -117, -41, 50, 106, 66, -62, 117, -93, -107, -47, -96, 124, -6, -90,
- 7, -3, 13, -18, 99, -56, -37, 104, 66, 99, -1, -32};
+const int8_t lstm_2_input[54] = {-77, 86, 115, -22, 114, -93, -95, 61, 100, 107, 33, -79, -11, -63,
+ -73, -90, -57, 40, 8, 23, 126, 82, 123, 105, 2, -102, -62, -78,
+ 111, 1, 49, 16, -77, 84, -4, 64, 27, 103, 102, 62, -101, -93,
+ -115, -103, -4, -35, -26, 36, 17, -78, 23, -106, 74, 4};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/input_gate_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/input_gate_bias_data.h
index 9b4b6854..72a5db25 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_2/input_gate_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_2/input_gate_bias_data.h
@@ -1,6 +1,6 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
-const int32_t lstm_2_input_gate_bias[7] = {26212, 18263, 849, 17710, 22563, -17022, 4727};
+const int32_t lstm_2_input_gate_bias[7] = {-475, 24405, 15813, -18608, 27695, -32747, 5436};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/input_norm_coeff_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/input_norm_coeff_data.h
index c61ba1fb..868049cb 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_2/input_norm_coeff_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_2/input_norm_coeff_data.h
@@ -1,5 +1,5 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_cell_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_cell_eff_bias_data.h
index 11ad38e5..abc8cbc9 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_cell_eff_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_cell_eff_bias_data.h
@@ -1,6 +1,6 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
-const int32_t lstm_2_input_to_cell_eff_bias[7] = {-24495, 6800, -29846, 4212, -18843, -18518, -34196};
+const int32_t lstm_2_input_to_cell_eff_bias[7] = {57106, -25319, -15180, 23851, 27030, 48140, -10254};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_cell_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_cell_w_data.h
index 466f6884..7b8dd0b6 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_cell_w_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_cell_w_data.h
@@ -1,8 +1,8 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
-const int8_t lstm_2_input_to_cell_w[42] = {-37, -62, 88, 15, -84, 31, 14, -46, 97, 33, 86, -70, -112, -75,
- -45, -13, -22, 18, 26, -10, -117, -82, -45, 67, 106, 60, -127, -33,
- -118, -78, -90, 51, 10, -23, 43, 52, 23, -64, 79, 75, -121, -7};
+const int8_t lstm_2_input_to_cell_w[42] = {68, 111, 95, 35, -39, 48, -62, -74, -91, -28, -32, -13, -14, -92,
+ -54, -121, -9, 2, -70, 15, 66, -6, 26, -70, -52, -50, 67, 33,
+ 3, 57, 76, 127, 110, -70, 74, -49, 35, -42, 56, -22, 13, 62};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_forget_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_forget_eff_bias_data.h
index b17afd3b..f4e40fad 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_forget_eff_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_forget_eff_bias_data.h
@@ -1,6 +1,6 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
-const int32_t lstm_2_input_to_forget_eff_bias[7] = {-19136, 14836, 19173, 16790, 12029, -6452, 14645};
+const int32_t lstm_2_input_to_forget_eff_bias[7] = {-34024, -60418, 47452, -16274, 15950, 23218, 30457};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_forget_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_forget_w_data.h
index 91613c16..a072ca22 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_forget_w_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_forget_w_data.h
@@ -1,8 +1,8 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
-const int8_t lstm_2_input_to_forget_w[42] = {127, -100, -44, 35, -24, -24, -26, 27, -21, -44, 116, 40, -43, 70,
- -101, -126, 78, 13, -31, -40, 121, -67, 108, 77, -56, 93, 10, 74,
- 56, 30, -104, 47, -14, 60, -120, -97, 19, -63, 84, 111, -53, 53};
+const int8_t lstm_2_input_to_forget_w[42] = {
+ -88, -23, 72, -119, -3, 95, -112, 13, -29, -78, 28, -119, 37, 45, 10, 67, 127, -36, 23, -116, 61,
+ -51, -54, 110, 55, 28, -116, 111, -109, -89, 106, 10, 2, -119, 89, -121, -33, 70, -105, 123, 45, 99};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_input_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_input_eff_bias_data.h
index f4d218a2..c0d06764 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_input_eff_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_input_eff_bias_data.h
@@ -1,6 +1,6 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
-const int32_t lstm_2_input_to_input_eff_bias[7] = {50916, 22359, -33583, -2258, 24355, -7934, 11255};
+const int32_t lstm_2_input_to_input_eff_bias[7] = {-219, 25045, 20293, -41904, 43567, -33259, 29756};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_input_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_input_w_data.h
index 2062c9ee..f0de95cf 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_input_w_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_input_w_data.h
@@ -1,8 +1,8 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
-const int8_t lstm_2_input_to_input_w[42] = {92, 7, -36, 85, 13, 32, -115, 84, -64, 124, 86, -83, 6, 127,
- -106, -116, -80, -100, -124, -2, 46, 54, -56, -74, -20, 83, -113, -59,
- 36, 87, -86, -31, -30, 5, 122, 91, -73, 96, 11, -7, 101, -77};
+const int8_t lstm_2_input_to_input_w[42] = {-105, -10, -45, 110, -23, 75, 17, -104, -35, 70, -66, 123, -113, 126,
+ 39, -127, -4, 114, 110, 0, -94, -24, -63, -111, -59, -13, 102, 116,
+ 7, -29, 40, 73, -113, -5, -85, 86, 124, 86, 125, 51, -92, -104};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_output_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_output_eff_bias_data.h
index d6401012..442871c2 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_output_eff_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_output_eff_bias_data.h
@@ -1,6 +1,6 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
-const int32_t lstm_2_input_to_output_eff_bias[7] = {-23260, -64953, 21616, 44956, -55885, -8659, -2372};
+const int32_t lstm_2_input_to_output_eff_bias[7] = {15991, -35409, -4890, -55904, -28433, -10231, -19317};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_output_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_output_w_data.h
index aa7a60a4..bda0d883 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_output_w_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_2/input_to_output_w_data.h
@@ -1,8 +1,8 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
-const int8_t lstm_2_input_to_output_w[42] = {-13, -107, 95, -90, 57, -71, -122, -38, -121, -44, -68, -69, 109, -62,
- 126, -53, 43, -26, -101, 123, 50, 82, -46, 39, -81, 41, -58, -52,
- -101, 40, 13, -45, -15, -90, 19, -127, -6, -21, -38, 14, 25, -44};
+const int8_t lstm_2_input_to_output_w[42] = {-6, 11, 12, -127, -29, 19, -90, 49, 106, 71, -81, -124, 16, 1,
+ -121, -66, -79, -7, -102, -86, -40, 124, -88, 0, 82, -108, -7, 4,
+ 9, -75, 50, -120, -28, -63, 127, -56, -14, 41, 35, -98, -31, -105};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/output_gate_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/output_gate_bias_data.h
index 2c9a5e45..e9794eb7 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_2/output_gate_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_2/output_gate_bias_data.h
@@ -1,6 +1,6 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
-const int32_t lstm_2_output_gate_bias[7] = {-6748, -5817, 4080, 26140, -28877, 22701, 6588};
+const int32_t lstm_2_output_gate_bias[7] = {31351, -26577, 27878, -31328, -16273, 1289, 2699};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/output_norm_coeff_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/output_norm_coeff_data.h
index 84c4b276..7ad90609 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_2/output_norm_coeff_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_2/output_norm_coeff_data.h
@@ -1,5 +1,5 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/output_ref_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/output_ref_data.h
index 9184badf..8b6351db 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_2/output_ref_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_2/output_ref_data.h
@@ -1,12 +1,9 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
-const int8_t lstm_2_output_ref[126] = {
- 51, 90, 58, 50, 43, 44, 16, -6, 83, 7, 78, 50, 40, -40, 6, 50, 37, 127, 38, 49, -77,
- 19, 45, 36, 127, 56, 31, -51, -1, 60, 19, 127, 46, 33, -55, 32, 69, 41, 127, 48, 45, -44,
- 47, 65, 20, 127, 50, 53, -3, 32, 48, 46, 127, 39, 43, -51, -6, 75, 6, 127, 59, 27, -33,
- 23, 72, -6, 117, 87, 40, 7, 16, 58, 28, 127, 52, 50, -52, 4, 72, 25, 127, 54, 46, -33,
- -15, 81, 46, 102, 61, 34, -44, 20, 66, 34, 127, 56, 39, -16, -28, 81, -11, 105, 59, 35, -73,
- -1, 56, 48, 127, 49, 36, -95, -19, 75, 27, 127, 56, 26, -90, 18, 71, 30, 121, 51, 50, -77};
+const int8_t lstm_2_output_ref[63] = {
+ 71, -80, -81, -19, 29, 0, -82, 101, -122, -107, -10, 65, 12, -83, 127, -39, -128, -7, 83, 33, -48,
+ 127, -92, -94, -13, 53, 40, -7, 115, -49, -128, -10, 86, 81, 6, 127, -76, -128, -6, 55, 44, 1,
+ 120, -128, -128, -3, 64, 30, -9, 127, -53, -128, 6, 80, 23, -2, 127, -63, -128, -10, 74, 61, 15};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/output_state_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/output_state_data.h
index b09faec6..c83812e9 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_2/output_state_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_2/output_state_data.h
@@ -1,6 +1,6 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
-int8_t lstm_2_output_state[14] = {74, 74, 74, 74, 74, 74, 74, 74, 74, 74, 74, 74, 74, 74};
+int8_t lstm_2_output_state[7] = {-24, -24, -24, -24, -24, -24, -24};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/projection_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/projection_bias_data.h
index 17722140..c6d005b4 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_2/projection_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_2/projection_bias_data.h
@@ -1,5 +1,5 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/projection_weights_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/projection_weights_data.h
index 6bc42d53..b01c0ffe 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_2/projection_weights_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_2/projection_weights_data.h
@@ -1,5 +1,5 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_input_to_cell_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_input_to_cell_w_data.h
index 48195152..9b9330bf 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_input_to_cell_w_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_input_to_cell_w_data.h
@@ -1,9 +1,9 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
-const int8_t lstm_2_recurrent_input_to_cell_w[49] = {111, -117, 121, -81, -85, -44, -107, 118, 83, 51, 69, 16, 110,
- 60, -36, 114, -55, -87, -60, -115, -8, -98, -91, -45, -94, -10,
- 49, -127, 17, 112, 126, -21, 86, -80, 82, 44, 90, -20, -38,
- 55, -67, 39, 125, 104, 80, 102, -110, -4, -5};
+const int8_t lstm_2_recurrent_input_to_cell_w[49] = {
+ -113, -40, -36, -21, -5, 30, 42, -66, -108, -52, -106, 39, -87, -24, -57, -52, 55,
+ 83, 16, 8, -92, 53, -119, 83, -60, 127, -90, 39, -87, 45, 91, -6, -33, 79,
+ -26, -82, 102, 97, -99, -97, 84, -55, 31, -42, -63, -94, 62, 83, 19};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_input_to_forget_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_input_to_forget_w_data.h
index 8a54759a..f309730f 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_input_to_forget_w_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_input_to_forget_w_data.h
@@ -1,9 +1,9 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
const int8_t lstm_2_recurrent_input_to_forget_w[49] = {
- 46, 72, -126, -19, -98, -127, -83, 6, 74, 110, 109, 69, -82, 71, 100, 84, -37,
- -88, 60, -86, -111, 14, -64, -95, 41, 87, -27, -67, -110, 59, -66, 88, -62, 61,
- 121, 60, -52, -17, 113, -102, 53, -103, 127, -56, -57, 45, 37, 125, 32};
+ -97, 94, -76, 43, 27, 4, 5, -65, 116, -99, -15, 38, 55, -124, -30, 52, 41,
+ -72, -31, 107, -109, 125, 112, -88, 36, 44, -121, 34, 21, -37, -20, -49, 68, -38,
+ 127, 1, -85, 60, -9, -117, -104, 34, -84, -48, 126, -33, 6, -105, 105};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_input_to_input_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_input_to_input_w_data.h
index f5ff6047..a1426c08 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_input_to_input_w_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_input_to_input_w_data.h
@@ -1,9 +1,9 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
const int8_t lstm_2_recurrent_input_to_input_w[49] = {
- -41, 46, -126, -116, -34, -18, -107, -105, -6, -68, 56, 36, -2, 113, 76, -65, 67,
- -123, -125, 74, 92, 31, 116, -108, 114, -123, -60, -59, 29, -23, 4, -116, 114, -94,
- -79, 79, -13, -70, -99, -17, 67, 120, 127, 41, -71, 59, 80, 53, -111};
+ -5, -98, 108, 28, -6, 114, 11, -61, 6, -75, -115, -115, -21, -95, 43, 36, 72,
+ 33, 91, 10, 52, 58, 19, 47, -20, 17, -68, -25, 25, -87, -5, 125, 68, -42,
+ 75, -16, 28, -13, 23, 50, 111, -114, 58, -23, 127, 101, 43, -107, -28};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_input_to_output_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_input_to_output_w_data.h
index 22ad7793..5d405188 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_input_to_output_w_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_input_to_output_w_data.h
@@ -1,9 +1,9 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
const int8_t lstm_2_recurrent_input_to_output_w[49] = {
- -38, 81, -44, -57, 126, -127, 50, -3, 54, 71, 74, -51, -9, -3, -122, 51, -26,
- 92, -18, -124, -91, 73, -68, 109, -33, 24, -55, -48, 124, -83, 12, -113, 68, 23,
- 32, 88, -115, 113, -109, 28, 61, -37, -31, -62, -83, 74, 73, -23, -123};
+ 67, -45, -7, 83, 73, 93, 116, 86, 40, 19, 114, 121, -34, 106, -76, -75, -99,
+ 24, -8, -71, 109, -4, -1, -125, -21, -125, 96, 33, 29, -113, -76, -51, -107, -28,
+ -120, -10, -10, -91, 40, -5, 117, -70, -94, -122, 127, 10, 102, 82, 29};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_to_cell_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_to_cell_eff_bias_data.h
index 08190320..22be2cd4 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_to_cell_eff_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_to_cell_eff_bias_data.h
@@ -1,6 +1,6 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
-const int32_t lstm_2_recurrent_to_cell_eff_bias[7] = {14948, -37518, 18278, 30784, -23828, -7622, -21608};
+const int32_t lstm_2_recurrent_to_cell_eff_bias[7] = {-3432, -9696, -936, 792, 1512, -1200, -96};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_to_forget_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_to_forget_eff_bias_data.h
index 5398098e..1c9768d0 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_to_forget_eff_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_to_forget_eff_bias_data.h
@@ -1,6 +1,6 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
-const int32_t lstm_2_recurrent_to_forget_eff_bias[7] = {24790, -26418, 5772, 8214, -6734, 3552, -18722};
+const int32_t lstm_2_recurrent_to_forget_eff_bias[7] = {0, -2256, -1008, 3408, 1728, -5280, -792};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_to_input_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_to_input_eff_bias_data.h
index 56a21435..7d6bc02b 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_to_input_eff_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_to_input_eff_bias_data.h
@@ -1,6 +1,6 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
-const int32_t lstm_2_recurrent_to_input_eff_bias[7] = {29304, -1776, 296, 6586, 12210, -4958, -13172};
+const int32_t lstm_2_recurrent_to_input_eff_bias[7] = {3648, -11424, 8088, 672, 3816, 1656, 4104};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_to_output_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_to_output_eff_bias_data.h
index 483dcb45..01aefc8d 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_to_output_eff_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_2/recurrent_to_output_eff_bias_data.h
@@ -1,6 +1,6 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
-const int32_t lstm_2_recurrent_to_output_eff_bias[7] = {666, -9842, 17612, -148, -4662, -2146, 12950};
+const int32_t lstm_2_recurrent_to_output_eff_bias[7] = {9120, 10848, -4704, -3528, -11184, -696, 3216};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_2/test_data.h b/Tests/UnitTest/TestCases/TestData/lstm_2/test_data.h
index ecd25010..f1ab6ed2 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_2/test_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_2/test_data.h
@@ -1,5 +1,5 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#include "cell_gate_bias_data.h"
#include "cell_norm_coeff_data.h"
#include "cell_state_data.h"
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_gate_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_gate_bias_data.h
index cd215406..6b100d83 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_gate_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_gate_bias_data.h
@@ -1,6 +1,6 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
-const int32_t lstm_one_time_step_cell_gate_bias[3] = {-15651, -31126, 21566};
+const int32_t lstm_one_time_step_cell_gate_bias[3] = {322, -28994, -4077};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_norm_coeff_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_norm_coeff_data.h
index 7a966059..1803b8ce 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_norm_coeff_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_norm_coeff_data.h
@@ -1,5 +1,5 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_state_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_state_data.h
index e81e6523..928b5312 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_state_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_state_data.h
@@ -1,5 +1,5 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_to_forget_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_to_forget_data.h
index 53071eec..784f00bf 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_to_forget_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_to_forget_data.h
@@ -1,5 +1,5 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_to_input_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_to_input_data.h
index e24a3ddf..e63ba379 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_to_input_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_to_input_data.h
@@ -1,5 +1,5 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_to_output_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_to_output_data.h
index 05729680..bf8c40aa 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_to_output_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/cell_to_output_data.h
@@ -1,5 +1,5 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/config_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/config_data.h
index f18c7e93..d719d33b 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/config_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/config_data.h
@@ -1,5 +1,5 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#define LSTM_ONE_TIME_STEP_BUFFER_SIZE 9
#define LSTM_ONE_TIME_STEP_INPUT_BATCHES 3
@@ -10,24 +10,29 @@
#define LSTM_ONE_TIME_STEP_TIME_MAJOR 0
#define LSTM_ONE_TIME_STEP_IN_ACTIVATION_MIN -32768
#define LSTM_ONE_TIME_STEP_IN_ACTIVATION_MAX 32767
-#define LSTM_ONE_TIME_STEP_IN_TO_INPUT_MULTIPLIER 1081402240
-#define LSTM_ONE_TIME_STEP_IN_TO_INPUT_SHIFT -2
-#define LSTM_ONE_TIME_STEP_IN_TO_FORGET_MULTIPLIER 2077292928
+#define LSTM_ONE_TIME_STEP_IN_TO_INPUT_MULTIPLIER 2118812544
+#define LSTM_ONE_TIME_STEP_IN_TO_INPUT_SHIFT -3
+#define LSTM_ONE_TIME_STEP_IN_TO_FORGET_MULTIPLIER 2143159040
#define LSTM_ONE_TIME_STEP_IN_TO_FORGET_SHIFT -3
-#define LSTM_ONE_TIME_STEP_IN_TO_CELL_MULTIPLIER 2082823168
-#define LSTM_ONE_TIME_STEP_IN_TO_CELL_SHIFT -3
-#define LSTM_ONE_TIME_STEP_IN_TO_OUTPUT_MULTIPLIER 2119604096
+#define LSTM_ONE_TIME_STEP_IN_TO_CELL_MULTIPLIER 1079340672
+#define LSTM_ONE_TIME_STEP_IN_TO_CELL_SHIFT -2
+#define LSTM_ONE_TIME_STEP_IN_TO_OUTPUT_MULTIPLIER 2144143488
#define LSTM_ONE_TIME_STEP_IN_TO_OUTPUT_SHIFT -3
-#define LSTM_ONE_TIME_STEP_RECURRENT_TO_INPUT_MULTIPLIER 1242172928
+#define LSTM_ONE_TIME_STEP_RECURRENT_TO_INPUT_MULTIPLIER 1158760960
#define LSTM_ONE_TIME_STEP_RECURRENT_TO_INPUT_SHIFT -3
-#define LSTM_ONE_TIME_STEP_RECURRENT_TO_FORGET_MULTIPLIER 1241717376
+#define LSTM_ONE_TIME_STEP_RECURRENT_TO_FORGET_MULTIPLIER 1197046528
#define LSTM_ONE_TIME_STEP_RECURRENT_TO_FORGET_SHIFT -3
-#define LSTM_ONE_TIME_STEP_RECURRENT_TO_CELL_MULTIPLIER 1075447936
+#define LSTM_ONE_TIME_STEP_RECURRENT_TO_CELL_MULTIPLIER 1244768512
#define LSTM_ONE_TIME_STEP_RECURRENT_TO_CELL_SHIFT -3
-#define LSTM_ONE_TIME_STEP_RECURRENT_TO_OUTPUT_MULTIPLIER 1153133824
+#define LSTM_ONE_TIME_STEP_RECURRENT_TO_OUTPUT_MULTIPLIER 1216274304
#define LSTM_ONE_TIME_STEP_RECURRENT_TO_OUTPUT_SHIFT -3
-#define LSTM_ONE_TIME_STEP_HIDDEN_MULTIPLIER 1833641157
+#define LSTM_ONE_TIME_STEP_FORGET_MULTIPLIER 1073741824
+#define LSTM_ONE_TIME_STEP_FORGET_SHIFT -14
+#define LSTM_ONE_TIME_STEP_INPUT_MULTIPLIER 1073741824
+#define LSTM_ONE_TIME_STEP_INPUT_SHIFT -13
+#define LSTM_ONE_TIME_STEP_HIDDEN_MULTIPLIER 1828318324
#define LSTM_ONE_TIME_STEP_HIDDEN_SHIFT -21
-#define LSTM_ONE_TIME_STEP_HIDDEN_OFFSET 106
-#define LSTM_ONE_TIME_STEP_OUTPUT_STATE_OFFSET 106
-#define LSTM_ONE_TIME_STEP_CELL_STATE_SHIFT -15
+#define LSTM_ONE_TIME_STEP_HIDDEN_OFFSET -16
+#define LSTM_ONE_TIME_STEP_DATA_OFFSET 128
+#define LSTM_ONE_TIME_STEP_OUTPUT_STATE_OFFSET -16
+#define LSTM_ONE_TIME_STEP_CELL_STATE_SHIFT -16
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/forget_gate_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/forget_gate_bias_data.h
index 99818b75..3ec7b9ee 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/forget_gate_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/forget_gate_bias_data.h
@@ -1,6 +1,6 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
-const int32_t lstm_one_time_step_forget_gate_bias[3] = {25787, -14479, 16570};
+const int32_t lstm_one_time_step_forget_gate_bias[3] = {-19492, -713, -23531};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/forget_norm_coeff_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/forget_norm_coeff_data.h
index 9bfd0b9a..27e6baf2 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/forget_norm_coeff_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/forget_norm_coeff_data.h
@@ -1,5 +1,5 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_data.h
index 1b349352..165ed208 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_data.h
@@ -1,9 +1,10 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
const int8_t lstm_one_time_step_input[66] = {
- -75, -122, 117, 85, -108, -63, 57, -6, -128, -27, 51, -48, 35, -103, 6, -66, -35, 64, 82, 24, 29, -39,
- -53, 47, -74, 87, 124, 102, -93, -8, -128, 54, 45, -26, -107, -21, -60, -40, -45, 96, -50, 75, 36, -73,
- -93, -88, -21, -32, -112, -113, 73, -88, -101, 93, -70, 6, -36, 67, 86, 42, 66, -70, 10, 8, -63, -38};
+ 83, 52, 7, 91, -102, 29, 52, -108, -93, 73, -46, 27, 120, -117, -57, -86, -124,
+ -112, -36, 26, -127, -39, -97, -59, -111, 98, 8, 121, 30, 51, 18, 15, -92, -41,
+ -71, 42, -96, 66, -18, -127, -91, 59, -29, 71, 96, -44, 83, 63, 51, 55, -8,
+ -73, -63, -53, -66, -14, -113, 104, -118, -116, 50, -94, 94, -23, 88, -119};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_gate_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_gate_bias_data.h
index a931d2c1..0623f3b0 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_gate_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_gate_bias_data.h
@@ -1,6 +1,6 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
-const int32_t lstm_one_time_step_input_gate_bias[3] = {1420, -28495, -28358};
+const int32_t lstm_one_time_step_input_gate_bias[3] = {-33100, 3078, 31928};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_norm_coeff_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_norm_coeff_data.h
index d84a6460..f5bb6ea9 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_norm_coeff_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_norm_coeff_data.h
@@ -1,5 +1,5 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_cell_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_cell_eff_bias_data.h
index 6a3057d8..70557c6c 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_cell_eff_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_cell_eff_bias_data.h
@@ -1,6 +1,6 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
-const int32_t lstm_one_time_step_input_to_cell_eff_bias[3] = {-30243, 17130, 70846};
+const int32_t lstm_one_time_step_input_to_cell_eff_bias[3] = {16450, 17086, -36717};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_cell_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_cell_w_data.h
index 39b5df84..fa886aee 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_cell_w_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_cell_w_data.h
@@ -1,9 +1,9 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
const int8_t lstm_one_time_step_input_to_cell_w[66] = {
- 1, 110, 28, -87, 44, -118, -44, 88, -45, 85, -24, -33, 47, -85, 106, 10, -86, -121, -4, 86, -26, -46,
- -78, -109, 50, 91, 69, 110, -90, 86, 22, 102, -117, 79, 111, -114, 106, -127, 47, 106, 99, 30, -116, 20,
- 25, -19, 125, -91, 12, 82, -41, 24, 84, 108, -108, -102, -26, -117, 19, 59, 117, 127, -101, 22, 93, 93};
+ 42, -24, 36, 70, 39, 63, 95, -98, -113, 13, 43, 47, -107, -30, -117, -81, 116, 5, 122, -47, -37, 89,
+ -7, 74, -4, 114, -115, 119, 55, -24, -73, -97, 52, -26, 127, 32, 69, 76, 105, 14, -116, 27, -87, 45,
+ -104, -61, -110, -30, 44, -102, -93, 103, 17, 100, 14, -27, -100, 9, 13, 55, -50, 28, -1, -91, 19, 112};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_forget_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_forget_eff_bias_data.h
index f1297854..d6d22ab5 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_forget_eff_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_forget_eff_bias_data.h
@@ -1,6 +1,6 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
-const int32_t lstm_one_time_step_input_to_forget_eff_bias[3] = {-2117, -71439, 29626};
+const int32_t lstm_one_time_step_input_to_forget_eff_bias[3] = {-49060, -20937, 76309};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_forget_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_forget_w_data.h
index 8924c8c7..33fb5aec 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_forget_w_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_forget_w_data.h
@@ -1,9 +1,9 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
const int8_t lstm_one_time_step_input_to_forget_w[66] = {
- -10, -66, -18, 87, -5, -12, 23, -75, -72, -13, 37, 18, 30, -28, -63, -20, 13, 117, 55, -103, -16, -97,
- -40, -124, 47, -78, -55, 104, -115, 86, -65, 39, -49, 56, 97, -62, -1, -11, 94, 67, -104, -110, -108, -113,
- 6, -71, 85, -84, 9, 48, 39, -65, -37, 9, -47, 123, -84, 127, -114, 90, 69, -4, -48, 0, -21, 72};
+ 34, -119, 10, 10, -70, -40, 37, 24, 124, -124, -64, 37, -107, 56, 84, -57, -33, 8, 10, 22, -98, 25,
+ -9, -47, -95, -21, 38, -124, -102, 12, 69, -48, -127, 96, -34, 39, 21, -104, 85, 58, 89, -66, 53, 59,
+ -34, 82, -62, 4, 44, 31, -22, -23, 34, 69, 0, 4, 114, 83, -72, 110, 91, 103, 94, -41, 117, 54};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_input_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_input_eff_bias_data.h
index 9ffc4240..5a9e26f3 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_input_eff_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_input_eff_bias_data.h
@@ -1,6 +1,6 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include
-const int32_t lstm_one_time_step_input_to_input_eff_bias[3] = {23436, -59215, -13894};
+const int32_t lstm_one_time_step_input_to_input_eff_bias[3] = {-12620, -12154, 952};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_input_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_input_w_data.h
index 5182e6c3..6e827f5c 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_input_w_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_input_w_data.h
@@ -1,9 +1,9 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
const int8_t lstm_one_time_step_input_to_input_w[66] = {
- 127, 42, 67, 72, -34, 103, 27, 64, -65, -23, -122, -105, -49, -63, 71, 30, -9, -5, 103, -16, 42, -85,
- 40, 37, -115, -19, 45, -49, 72, -26, 37, -122, 99, 16, 31, -94, -97, 92, -76, -111, -6, -82, -9, 97,
- 8, -23, 42, 108, 13, 13, 98, 50, 40, -52, -124, 7, 17, 117, 66, -84, -88, -97, -70, -4, 22, 54};
+ -51, -22, 74, 11, -47, -57, 65, -89, 98, 7, 33, -95, -94, 116, -74, 11, 126, -1, 107, 29, -33, 46,
+ 72, -127, 99, -53, 51, -15, 5, -66, 14, 14, -76, -105, -48, -114, 116, 38, -1, 96, 43, -94, -91, 123,
+ -103, -77, 64, -65, -94, -18, -88, -75, 25, 90, -125, 74, 52, -29, -7, 7, 56, 54, -83, -66, 75, 91};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_output_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_output_eff_bias_data.h
index cd6d40dc..39cf6ba3 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_output_eff_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_output_eff_bias_data.h
@@ -1,6 +1,6 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
-const int32_t lstm_one_time_step_input_to_output_eff_bias[3] = {48450, 36258, -48261};
+const int32_t lstm_one_time_step_input_to_output_eff_bias[3] = {28514, 56537, 50815};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_output_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_output_w_data.h
index b21f223b..e77c7456 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_output_w_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/input_to_output_w_data.h
@@ -1,10 +1,9 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
const int8_t lstm_one_time_step_input_to_output_w[66] = {
- 100, -56, 93, -52, 62, 127, 102, -91, 33, -120, -107, 81, -94, 104, -120, 6, -126,
- -47, 112, 92, 13, 55, 84, 127, -91, 6, 53, -19, -33, -120, 26, 124, 25, 78,
- 72, -122, -79, -62, 102, -6, 55, 125, -96, -53, -39, 1, -42, 107, 44, -56, -40,
- 41, -104, 16, 35, -107, -124, 118, 91, -111, -127, 67, 52, -66, -20, 75};
+ 74, 127, -51, 71, 113, -78, -47, -24, -11, -61, 74, -44, 65, -70, -50, -88, 102, -6, -78, 97, 78, -22,
+ -53, 100, -54, 54, 74, 98, 22, -106, -78, 100, 114, 123, -1, 0, 63, 75, 14, -51, 7, 19, -50, 16,
+ 127, 7, -44, -68, 65, 88, 28, 23, 33, 125, 94, 51, 24, -108, 93, 63, -22, -75, 52, -22, -104, -120};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/output_gate_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/output_gate_bias_data.h
index 2e4918e1..e7a83a6c 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/output_gate_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/output_gate_bias_data.h
@@ -1,6 +1,6 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
-const int32_t lstm_one_time_step_output_gate_bias[3] = {27074, 11170, -24069};
+const int32_t lstm_one_time_step_output_gate_bias[3] = {6626, -5671, 11135};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/output_norm_coeff_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/output_norm_coeff_data.h
index 2a938b16..62ab428a 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/output_norm_coeff_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/output_norm_coeff_data.h
@@ -1,5 +1,5 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/output_ref_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/output_ref_data.h
index 93ed5692..a2ae0559 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/output_ref_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/output_ref_data.h
@@ -1,6 +1,6 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
-const int8_t lstm_one_time_step_output_ref[9] = {-70, 122, 127, -63, 125, 127, 25, 110, 127};
+const int8_t lstm_one_time_step_output_ref[9] = {29, 98, -128, 38, 54, -70, 127, -63, -111};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/output_state_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/output_state_data.h
index 8f4631e7..3d66fa1e 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/output_state_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/output_state_data.h
@@ -1,6 +1,6 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
-int8_t lstm_one_time_step_output_state[9] = {106, 106, 106, 106, 106, 106, 106, 106, 106};
+int8_t lstm_one_time_step_output_state[9] = {-16, -16, -16, -16, -16, -16, -16, -16, -16};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/projection_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/projection_bias_data.h
index 1c3140ec..7a90e907 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/projection_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/projection_bias_data.h
@@ -1,5 +1,5 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/projection_weights_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/projection_weights_data.h
index 62246b71..4266770c 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/projection_weights_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/projection_weights_data.h
@@ -1,5 +1,5 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_input_to_cell_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_input_to_cell_w_data.h
index 9b4a48e5..92e79347 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_input_to_cell_w_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_input_to_cell_w_data.h
@@ -1,6 +1,6 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
-const int8_t lstm_one_time_step_recurrent_input_to_cell_w[9] = {-42, -127, -4, -88, -56, 107, -109, -112, 11};
+const int8_t lstm_one_time_step_recurrent_input_to_cell_w[9] = {-91, -61, 124, 5, -127, -65, -103, -113, -53};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_input_to_forget_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_input_to_forget_w_data.h
index 41d12737..7b8c939a 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_input_to_forget_w_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_input_to_forget_w_data.h
@@ -1,6 +1,6 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
-const int8_t lstm_one_time_step_recurrent_input_to_forget_w[9] = {-120, -120, -31, -90, 20, -54, 58, 127, -17};
+const int8_t lstm_one_time_step_recurrent_input_to_forget_w[9] = {-102, -67, -12, 7, -127, -96, 96, 16, -57};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_input_to_input_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_input_to_input_w_data.h
index cfae7c8d..2623531a 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_input_to_input_w_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_input_to_input_w_data.h
@@ -1,6 +1,6 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
-const int8_t lstm_one_time_step_recurrent_input_to_input_w[9] = {107, 126, 64, 47, 16, -71, 123, -127, -118};
+const int8_t lstm_one_time_step_recurrent_input_to_input_w[9] = {55, 103, 98, -127, -114, -93, 31, -36, 122};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_input_to_output_w_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_input_to_output_w_data.h
index f1c194e9..72364ff6 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_input_to_output_w_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_input_to_output_w_data.h
@@ -1,6 +1,6 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
-const int8_t lstm_one_time_step_recurrent_input_to_output_w[9] = {-90, -45, 111, 100, -20, 66, 27, -127, 67};
+const int8_t lstm_one_time_step_recurrent_input_to_output_w[9] = {-127, -81, 94, -112, -25, -70, -68, 1, -16};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_to_cell_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_to_cell_eff_bias_data.h
index 707b1408..0e17501e 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_to_cell_eff_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_to_cell_eff_bias_data.h
@@ -1,6 +1,6 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
-const int32_t lstm_one_time_step_recurrent_to_cell_eff_bias[3] = {18338, 3922, 22260};
+const int32_t lstm_one_time_step_recurrent_to_cell_eff_bias[3] = {-448, -2992, -4304};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_to_forget_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_to_forget_eff_bias_data.h
index 181eb477..76b468d1 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_to_forget_eff_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_to_forget_eff_bias_data.h
@@ -1,6 +1,6 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
-const int32_t lstm_one_time_step_recurrent_to_forget_eff_bias[3] = {28726, 13144, -17808};
+const int32_t lstm_one_time_step_recurrent_to_forget_eff_bias[3] = {-2896, -3456, 880};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_to_input_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_to_input_eff_bias_data.h
index 66278394..8121c87e 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_to_input_eff_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_to_input_eff_bias_data.h
@@ -1,6 +1,6 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
-const int32_t lstm_one_time_step_recurrent_to_input_eff_bias[3] = {-31482, 848, 12932};
+const int32_t lstm_one_time_step_recurrent_to_input_eff_bias[3] = {4096, -5344, 1872};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_to_output_eff_bias_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_to_output_eff_bias_data.h
index 491b3bed..a0d73385 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_to_output_eff_bias_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/recurrent_to_output_eff_bias_data.h
@@ -1,6 +1,6 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#pragma once
#include <stdint.h>
-const int32_t lstm_one_time_step_recurrent_to_output_eff_bias[3] = {2544, -15476, 3498};
+const int32_t lstm_one_time_step_recurrent_to_output_eff_bias[3] = {-1824, -3312, -1328};
diff --git a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/test_data.h b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/test_data.h
index ecd25010..f1ab6ed2 100644
--- a/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/test_data.h
+++ b/Tests/UnitTest/TestCases/TestData/lstm_one_time_step/test_data.h
@@ -1,5 +1,5 @@
-// Generated by generate_test_data.py using tensorflow version 2.10.0 (Keras version 2.10.0).
-// Interpreter from tflite_runtime version 2.11.0 and revision 0.6.0-134012-g31cfa135ac4.
+// Generated by test_settings.py using tensorflow version 2.15.0 (Keras version 2.15.0).
+// Interpreter from tflite_micro version 0.dev20231214171620-g47f77ab and revision None.
#include "cell_gate_bias_data.h"
#include "cell_norm_coeff_data.h"
#include "cell_state_data.h"
diff --git a/Tests/UnitTest/TestCases/test_arm_ds_cnn_l_s8/test_arm_ds_cnn_l_s8.c b/Tests/UnitTest/TestCases/test_arm_ds_cnn_l_s8/test_arm_ds_cnn_l_s8.c
index f17a7f0e..ca21593e 100644
--- a/Tests/UnitTest/TestCases/test_arm_ds_cnn_l_s8/test_arm_ds_cnn_l_s8.c
+++ b/Tests/UnitTest/TestCases/test_arm_ds_cnn_l_s8/test_arm_ds_cnn_l_s8.c
@@ -473,7 +473,7 @@ void ds_cnn_l_s8_inference(void)
bias_dims.c = in_out_dim_1.c;
#if defined(ARM_MATH_MVEI)
- arm_vector_sum_s8(ctx.buf, conv_filter_dims.n, in_out_dim_1.c, ds_cnn_l_layer_14_fully_connected_weights);
+ arm_vector_sum_s8(ctx.buf, conv_filter_dims.n, in_out_dim_1.c, ds_cnn_l_layer_14_fully_connected_weights, 1, NULL);
#endif
status |= arm_fully_connected_s8(&ctx,
diff --git a/Tests/UnitTest/TestCases/test_arm_fully_connected_s8/test_arm_fully_connected_s8.c b/Tests/UnitTest/TestCases/test_arm_fully_connected_s8/test_arm_fully_connected_s8.c
index c5e0713c..be5be41d 100644
--- a/Tests/UnitTest/TestCases/test_arm_fully_connected_s8/test_arm_fully_connected_s8.c
+++ b/Tests/UnitTest/TestCases/test_arm_fully_connected_s8/test_arm_fully_connected_s8.c
@@ -1,5 +1,5 @@
/*
- * SPDX-FileCopyrightText: Copyright 2010-2023 Arm Limited and/or its affiliates
+ * SPDX-FileCopyrightText: Copyright 2010-2024 Arm Limited and/or its affiliates
*
* SPDX-License-Identifier: Apache-2.0
*
@@ -70,7 +70,7 @@ void fully_connected_arm_fully_connected_s8(void)
#if defined(ARM_MATH_MVEI)
int32_t *buf = ctx.buf;
- TEST_ASSERT_EQUAL(expected, arm_vector_sum_s8(buf, filter_dims.n, output_dims.c, kernel_data));
+ TEST_ASSERT_EQUAL(expected, arm_vector_sum_s8(buf, filter_dims.n, output_dims.c, kernel_data, 1, NULL));
#endif
arm_cmsis_nn_status result = arm_fully_connected_s8(&ctx,
@@ -133,7 +133,7 @@ void fully_connected_mve_0_arm_fully_connected_s8(void)
#if defined(ARM_MATH_MVEI)
int32_t *buf = ctx.buf;
- TEST_ASSERT_EQUAL(expected, arm_vector_sum_s8(buf, filter_dims.n, output_dims.c, kernel_data));
+ TEST_ASSERT_EQUAL(expected, arm_vector_sum_s8(buf, filter_dims.n, output_dims.c, kernel_data, 1, NULL));
#endif
arm_cmsis_nn_status result = arm_fully_connected_s8(&ctx,
@@ -195,7 +195,7 @@ void fully_connected_mve_1_arm_fully_connected_s8(void)
#if defined(ARM_MATH_MVEI)
int32_t *buf = ctx.buf;
- TEST_ASSERT_EQUAL(expected, arm_vector_sum_s8(buf, filter_dims.n, output_dims.c, kernel_data));
+ TEST_ASSERT_EQUAL(expected, arm_vector_sum_s8(buf, filter_dims.n, output_dims.c, kernel_data, 1, NULL));
#endif
arm_cmsis_nn_status result = arm_fully_connected_s8(&ctx,
@@ -268,7 +268,7 @@ void fully_connected_null_bias_0_arm_fully_connected_s8(void)
#if defined(ARM_MATH_MVEI)
int32_t *buf = ctx.buf;
- TEST_ASSERT_EQUAL(expected, arm_vector_sum_s8(buf, filter_dims.n, output_dims.c, kernel_data));
+ TEST_ASSERT_EQUAL(expected, arm_vector_sum_s8(buf, filter_dims.n, output_dims.c, kernel_data, 1, NULL));
#endif
arm_cmsis_nn_status result = arm_fully_connected_s8(&ctx,
@@ -330,7 +330,7 @@ void fully_connected_out_activation_arm_fully_connected_s8(void)
#if defined(ARM_MATH_MVEI)
int32_t *buf = ctx.buf;
- TEST_ASSERT_EQUAL(expected, arm_vector_sum_s8(buf, filter_dims.n, output_dims.c, kernel_data));
+ TEST_ASSERT_EQUAL(expected, arm_vector_sum_s8(buf, filter_dims.n, output_dims.c, kernel_data, 1, NULL));
#endif
arm_cmsis_nn_status result = arm_fully_connected_s8(&ctx,
diff --git a/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s16_s8/test_arm_lstm_unidirectional_s16_s8.c b/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s16_s8/test_arm_lstm_unidirectional_s16_s8.c
deleted file mode 100644
index d90da214..00000000
--- a/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s16_s8/test_arm_lstm_unidirectional_s16_s8.c
+++ /dev/null
@@ -1,328 +0,0 @@
-/*
- * SPDX-FileCopyrightText: Copyright 2022-2023 Arm Limited and/or its affiliates
- *
- * SPDX-License-Identifier: Apache-2.0
- *
- * Licensed under the Apache License, Version 2.0 (the License); you may
- * not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include
-#include
-
-#include
-#include
-
-#include "../TestData/lstm_1/test_data.h"
-#include "../TestData/lstm_2/test_data.h"
-#include "../TestData/lstm_one_time_step/test_data.h"
-#include "../Utils/validate.h"
-
-#if (LSTM_2_BUFFER_SIZE < LSTM_1_BUFFER_SIZE) || (LSTM_2_BUFFER_SIZE < LSTM_ONE_TIME_STEP_BUFFER_SIZE)
- #error "Test buffers too small."
-#endif
-
-// Update the buffer size if adding a unit test with larger buffer.
-#define LARGEST_BUFFER_SIZE LSTM_2_BUFFER_SIZE
-
-int16_t buffer0[LARGEST_BUFFER_SIZE];
-int16_t buffer1[LARGEST_BUFFER_SIZE];
-int16_t buffer2[LARGEST_BUFFER_SIZE];
-int16_t buffer3[LARGEST_BUFFER_SIZE];
-
-void lstm_1_arm_lstm_unidirectional_s16_s8(void)
-{
- const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS;
- const bool time_major = (bool)LSTM_1_TIME_MAJOR;
-
- int8_t output[LSTM_1_DST_SIZE] = {0};
-
- cmsis_nn_lstm_context scratch_buffers = {};
- cmsis_nn_lstm_dims lstm_dims = {};
- cmsis_nn_lstm_params lstm = {};
-
- scratch_buffers.input_gate = buffer0;
- scratch_buffers.forget_gate = buffer1;
- scratch_buffers.cell_gate = buffer2;
- scratch_buffers.output_gate = buffer3;
-
- lstm_dims.num_batches = LSTM_1_INPUT_BATCHES;
- lstm_dims.num_inputs = LSTM_1_NUMBER_INPUTS;
- lstm_dims.max_time = LSTM_1_TIME_STEPS;
- lstm_dims.num_outputs = LSTM_1_NUMBER_UNITS;
-
- lstm.time_major = time_major;
- lstm.input_to_input_scaling.multiplier = LSTM_1_IN_TO_INPUT_MULTIPLIER;
- lstm.input_to_input_scaling.shift = LSTM_1_IN_TO_INPUT_SHIFT;
- lstm.input_to_forget_scaling.multiplier = LSTM_1_IN_TO_FORGET_MULTIPLIER;
- lstm.input_to_forget_scaling.shift = LSTM_1_IN_TO_FORGET_SHIFT;
- lstm.input_to_cell_scaling.multiplier = LSTM_1_IN_TO_CELL_MULTIPLIER;
- lstm.input_to_cell_scaling.shift = LSTM_1_IN_TO_CELL_SHIFT;
- lstm.input_to_output_scaling.multiplier = LSTM_1_IN_TO_OUTPUT_MULTIPLIER;
- lstm.input_to_output_scaling.shift = LSTM_1_IN_TO_OUTPUT_SHIFT;
-
- lstm.recurrent_to_input_scaling.multiplier = LSTM_1_RECURRENT_TO_INPUT_MULTIPLIER;
- lstm.recurrent_to_input_scaling.shift = LSTM_1_RECURRENT_TO_INPUT_SHIFT;
- lstm.recurrent_to_cell_scaling.multiplier = LSTM_1_RECURRENT_TO_CELL_MULTIPLIER;
- lstm.recurrent_to_cell_scaling.shift = LSTM_1_RECURRENT_TO_CELL_SHIFT;
- lstm.recurrent_to_forget_scaling.multiplier = LSTM_1_RECURRENT_TO_FORGET_MULTIPLIER;
- lstm.recurrent_to_forget_scaling.shift = LSTM_1_RECURRENT_TO_FORGET_SHIFT;
- lstm.recurrent_to_output_scaling.multiplier = LSTM_1_RECURRENT_TO_OUTPUT_MULTIPLIER;
- lstm.recurrent_to_output_scaling.shift = LSTM_1_RECURRENT_TO_OUTPUT_SHIFT;
-
- lstm.i2i_effective_bias = lstm_1_input_to_input_eff_bias;
- lstm.i2f_effective_bias = lstm_1_input_to_forget_eff_bias;
- lstm.i2c_effective_bias = lstm_1_input_to_cell_eff_bias;
- lstm.i2o_effective_bias = lstm_1_input_to_output_eff_bias;
-
- lstm.r2i_effective_bias = lstm_1_recurrent_to_input_eff_bias;
- lstm.r2f_effective_bias = lstm_1_recurrent_to_forget_eff_bias;
- lstm.r2c_effective_bias = lstm_1_recurrent_to_cell_eff_bias;
- lstm.r2o_effective_bias = lstm_1_recurrent_to_output_eff_bias;
-
- lstm.input_gate_bias = lstm_1_input_gate_bias;
- lstm.forget_gate_bias = lstm_1_forget_gate_bias;
- lstm.cell_gate_bias = lstm_1_cell_gate_bias;
- lstm.output_gate_bias = lstm_1_output_gate_bias;
-
- lstm.activation.min = LSTM_1_IN_ACTIVATION_MIN;
- lstm.activation.max = LSTM_1_IN_ACTIVATION_MAX;
-
- lstm.hidden_scaling.multiplier = LSTM_1_HIDDEN_MULTIPLIER;
- lstm.hidden_scaling.shift = LSTM_1_HIDDEN_SHIFT;
-
- lstm.hidden_offset = LSTM_1_HIDDEN_OFFSET;
-
- lstm.cell_state_shift = LSTM_1_CELL_STATE_SHIFT;
- lstm.output_state_offset = LSTM_1_OUTPUT_STATE_OFFSET;
-
- const int8_t *input_data = lstm_1_input;
- const int8_t *output_ref = lstm_1_output_ref;
- const int32_t output_ref_size = LSTM_1_DST_SIZE;
-
- arm_cmsis_nn_status result = arm_lstm_unidirectional_s16_s8(&scratch_buffers,
- input_data,
- &lstm_dims,
- lstm_1_input_to_input_w,
- lstm_1_input_to_forget_w,
- lstm_1_input_to_cell_w,
- lstm_1_input_to_output_w,
- lstm_1_recurrent_input_to_input_w,
- lstm_1_recurrent_input_to_forget_w,
- lstm_1_recurrent_input_to_cell_w,
- lstm_1_recurrent_input_to_output_w,
- lstm_1_cell_to_input,
- lstm_1_cell_to_forget,
- lstm_1_cell_to_output,
- lstm_1_projection_weights,
- &lstm,
- lstm_1_output_state,
- lstm_1_cell_state,
- output);
-
- TEST_ASSERT_EQUAL(expected, result);
- TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
-}
-
-void lstm_2_arm_lstm_unidirectional_s16_s8(void)
-{
- const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS;
- const bool time_major = (bool)LSTM_2_TIME_MAJOR;
-
- int8_t output[LSTM_2_DST_SIZE] = {0};
-
- cmsis_nn_lstm_context scratch_buffers = {};
- cmsis_nn_lstm_dims lstm_dims = {};
- cmsis_nn_lstm_params lstm = {};
-
- scratch_buffers.input_gate = buffer0;
- scratch_buffers.forget_gate = buffer1;
- scratch_buffers.cell_gate = buffer2;
- scratch_buffers.output_gate = buffer3;
-
- lstm_dims.num_batches = LSTM_2_INPUT_BATCHES;
- lstm_dims.num_inputs = LSTM_2_NUMBER_INPUTS;
- lstm_dims.max_time = LSTM_2_TIME_STEPS;
- lstm_dims.num_outputs = LSTM_2_NUMBER_UNITS;
-
- lstm.time_major = time_major;
- lstm.input_to_input_scaling.multiplier = LSTM_2_IN_TO_INPUT_MULTIPLIER;
- lstm.input_to_input_scaling.shift = LSTM_2_IN_TO_INPUT_SHIFT;
- lstm.input_to_forget_scaling.multiplier = LSTM_2_IN_TO_FORGET_MULTIPLIER;
- lstm.input_to_forget_scaling.shift = LSTM_2_IN_TO_FORGET_SHIFT;
- lstm.input_to_cell_scaling.multiplier = LSTM_2_IN_TO_CELL_MULTIPLIER;
- lstm.input_to_cell_scaling.shift = LSTM_2_IN_TO_CELL_SHIFT;
- lstm.input_to_output_scaling.multiplier = LSTM_2_IN_TO_OUTPUT_MULTIPLIER;
- lstm.input_to_output_scaling.shift = LSTM_2_IN_TO_OUTPUT_SHIFT;
-
- lstm.recurrent_to_input_scaling.multiplier = LSTM_2_RECURRENT_TO_INPUT_MULTIPLIER;
- lstm.recurrent_to_input_scaling.shift = LSTM_2_RECURRENT_TO_INPUT_SHIFT;
- lstm.recurrent_to_cell_scaling.multiplier = LSTM_2_RECURRENT_TO_CELL_MULTIPLIER;
- lstm.recurrent_to_cell_scaling.shift = LSTM_2_RECURRENT_TO_CELL_SHIFT;
- lstm.recurrent_to_forget_scaling.multiplier = LSTM_2_RECURRENT_TO_FORGET_MULTIPLIER;
- lstm.recurrent_to_forget_scaling.shift = LSTM_2_RECURRENT_TO_FORGET_SHIFT;
- lstm.recurrent_to_output_scaling.multiplier = LSTM_2_RECURRENT_TO_OUTPUT_MULTIPLIER;
- lstm.recurrent_to_output_scaling.shift = LSTM_2_RECURRENT_TO_OUTPUT_SHIFT;
-
- lstm.i2i_effective_bias = lstm_2_input_to_input_eff_bias;
- lstm.i2f_effective_bias = lstm_2_input_to_forget_eff_bias;
- lstm.i2c_effective_bias = lstm_2_input_to_cell_eff_bias;
- lstm.i2o_effective_bias = lstm_2_input_to_output_eff_bias;
-
- lstm.r2i_effective_bias = lstm_2_recurrent_to_input_eff_bias;
- lstm.r2f_effective_bias = lstm_2_recurrent_to_forget_eff_bias;
- lstm.r2c_effective_bias = lstm_2_recurrent_to_cell_eff_bias;
- lstm.r2o_effective_bias = lstm_2_recurrent_to_output_eff_bias;
-
- lstm.input_gate_bias = lstm_2_input_gate_bias;
- lstm.forget_gate_bias = lstm_2_forget_gate_bias;
- lstm.cell_gate_bias = lstm_2_cell_gate_bias;
- lstm.output_gate_bias = lstm_2_output_gate_bias;
-
- lstm.activation.min = LSTM_2_IN_ACTIVATION_MIN;
- lstm.activation.max = LSTM_2_IN_ACTIVATION_MAX;
-
- lstm.hidden_scaling.multiplier = LSTM_2_HIDDEN_MULTIPLIER;
- lstm.hidden_scaling.shift = LSTM_2_HIDDEN_SHIFT;
-
- lstm.hidden_offset = LSTM_2_HIDDEN_OFFSET;
-
- lstm.cell_state_shift = LSTM_2_CELL_STATE_SHIFT;
- lstm.output_state_offset = LSTM_2_OUTPUT_STATE_OFFSET;
-
- const int8_t *input_data = lstm_2_input;
- const int8_t *output_ref = lstm_2_output_ref;
- const int32_t output_ref_size = LSTM_2_DST_SIZE;
-
- arm_cmsis_nn_status result = arm_lstm_unidirectional_s16_s8(&scratch_buffers,
- input_data,
- &lstm_dims,
- lstm_2_input_to_input_w,
- lstm_2_input_to_forget_w,
- lstm_2_input_to_cell_w,
- lstm_2_input_to_output_w,
- lstm_2_recurrent_input_to_input_w,
- lstm_2_recurrent_input_to_forget_w,
- lstm_2_recurrent_input_to_cell_w,
- lstm_2_recurrent_input_to_output_w,
- lstm_2_cell_to_input,
- lstm_2_cell_to_forget,
- lstm_2_cell_to_output,
- lstm_2_projection_weights,
- &lstm,
- lstm_2_output_state,
- lstm_2_cell_state,
- output);
-
- TEST_ASSERT_EQUAL(expected, result);
- TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
-}
-
-void lstm_one_time_step_arm_lstm_unidirectional_s16_s8(void)
-{
- const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS;
- const bool time_major = (bool)LSTM_ONE_TIME_STEP_TIME_MAJOR;
-
- int8_t output[LSTM_ONE_TIME_STEP_DST_SIZE] = {0};
-
- int16_t cell_state[LSTM_ONE_TIME_STEP_DST_SIZE];
- int8_t output_state[LSTM_ONE_TIME_STEP_DST_SIZE];
-
- memcpy(output_state, lstm_one_time_step_output_state, LSTM_ONE_TIME_STEP_DST_SIZE * sizeof(int8_t));
- memcpy(cell_state, lstm_one_time_step_cell_state, LSTM_ONE_TIME_STEP_DST_SIZE * sizeof(int16_t));
-
- cmsis_nn_lstm_context scratch_buffers = {};
- cmsis_nn_lstm_dims lstm_dims = {};
- cmsis_nn_lstm_params lstm = {};
-
- scratch_buffers.input_gate = buffer0;
- scratch_buffers.forget_gate = buffer1;
- scratch_buffers.cell_gate = buffer2;
- scratch_buffers.output_gate = buffer3;
-
- lstm_dims.num_batches = LSTM_ONE_TIME_STEP_INPUT_BATCHES;
- lstm_dims.num_inputs = LSTM_ONE_TIME_STEP_NUMBER_INPUTS;
- lstm_dims.max_time = LSTM_ONE_TIME_STEP_TIME_STEPS;
- lstm_dims.num_outputs = LSTM_ONE_TIME_STEP_NUMBER_UNITS;
-
- lstm.time_major = time_major;
- lstm.input_to_input_scaling.multiplier = LSTM_ONE_TIME_STEP_IN_TO_INPUT_MULTIPLIER;
- lstm.input_to_input_scaling.shift = LSTM_ONE_TIME_STEP_IN_TO_INPUT_SHIFT;
- lstm.input_to_forget_scaling.multiplier = LSTM_ONE_TIME_STEP_IN_TO_FORGET_MULTIPLIER;
- lstm.input_to_forget_scaling.shift = LSTM_ONE_TIME_STEP_IN_TO_FORGET_SHIFT;
- lstm.input_to_cell_scaling.multiplier = LSTM_ONE_TIME_STEP_IN_TO_CELL_MULTIPLIER;
- lstm.input_to_cell_scaling.shift = LSTM_ONE_TIME_STEP_IN_TO_CELL_SHIFT;
- lstm.input_to_output_scaling.multiplier = LSTM_ONE_TIME_STEP_IN_TO_OUTPUT_MULTIPLIER;
- lstm.input_to_output_scaling.shift = LSTM_ONE_TIME_STEP_IN_TO_OUTPUT_SHIFT;
-
- lstm.recurrent_to_input_scaling.multiplier = LSTM_ONE_TIME_STEP_RECURRENT_TO_INPUT_MULTIPLIER;
- lstm.recurrent_to_input_scaling.shift = LSTM_ONE_TIME_STEP_RECURRENT_TO_INPUT_SHIFT;
- lstm.recurrent_to_cell_scaling.multiplier = LSTM_ONE_TIME_STEP_RECURRENT_TO_CELL_MULTIPLIER;
- lstm.recurrent_to_cell_scaling.shift = LSTM_ONE_TIME_STEP_RECURRENT_TO_CELL_SHIFT;
- lstm.recurrent_to_forget_scaling.multiplier = LSTM_ONE_TIME_STEP_RECURRENT_TO_FORGET_MULTIPLIER;
- lstm.recurrent_to_forget_scaling.shift = LSTM_ONE_TIME_STEP_RECURRENT_TO_FORGET_SHIFT;
- lstm.recurrent_to_output_scaling.multiplier = LSTM_ONE_TIME_STEP_RECURRENT_TO_OUTPUT_MULTIPLIER;
- lstm.recurrent_to_output_scaling.shift = LSTM_ONE_TIME_STEP_RECURRENT_TO_OUTPUT_SHIFT;
-
- lstm.i2i_effective_bias = lstm_one_time_step_input_to_input_eff_bias;
- lstm.i2f_effective_bias = lstm_one_time_step_input_to_forget_eff_bias;
- lstm.i2c_effective_bias = lstm_one_time_step_input_to_cell_eff_bias;
- lstm.i2o_effective_bias = lstm_one_time_step_input_to_output_eff_bias;
-
- lstm.r2i_effective_bias = lstm_one_time_step_recurrent_to_input_eff_bias;
- lstm.r2f_effective_bias = lstm_one_time_step_recurrent_to_forget_eff_bias;
- lstm.r2c_effective_bias = lstm_one_time_step_recurrent_to_cell_eff_bias;
- lstm.r2o_effective_bias = lstm_one_time_step_recurrent_to_output_eff_bias;
-
- lstm.input_gate_bias = lstm_one_time_step_input_gate_bias;
- lstm.forget_gate_bias = lstm_one_time_step_forget_gate_bias;
- lstm.cell_gate_bias = lstm_one_time_step_cell_gate_bias;
- lstm.output_gate_bias = lstm_one_time_step_output_gate_bias;
-
- lstm.activation.min = LSTM_ONE_TIME_STEP_IN_ACTIVATION_MIN;
- lstm.activation.max = LSTM_ONE_TIME_STEP_IN_ACTIVATION_MAX;
-
- lstm.hidden_scaling.multiplier = LSTM_ONE_TIME_STEP_HIDDEN_MULTIPLIER;
- lstm.hidden_scaling.shift = LSTM_ONE_TIME_STEP_HIDDEN_SHIFT;
-
- lstm.hidden_offset = LSTM_ONE_TIME_STEP_HIDDEN_OFFSET;
-
- lstm.cell_state_shift = LSTM_ONE_TIME_STEP_CELL_STATE_SHIFT;
- lstm.output_state_offset = LSTM_ONE_TIME_STEP_OUTPUT_STATE_OFFSET;
-
- const int8_t *input_data = lstm_one_time_step_input;
- const int8_t *output_ref = lstm_one_time_step_output_ref;
- const int32_t output_ref_size = LSTM_ONE_TIME_STEP_DST_SIZE;
-
- arm_cmsis_nn_status result = arm_lstm_unidirectional_s16_s8(&scratch_buffers,
- input_data,
- &lstm_dims,
- lstm_one_time_step_input_to_input_w,
- lstm_one_time_step_input_to_forget_w,
- lstm_one_time_step_input_to_cell_w,
- lstm_one_time_step_input_to_output_w,
- lstm_one_time_step_recurrent_input_to_input_w,
- lstm_one_time_step_recurrent_input_to_forget_w,
- lstm_one_time_step_recurrent_input_to_cell_w,
- lstm_one_time_step_recurrent_input_to_output_w,
- lstm_one_time_step_cell_to_input,
- lstm_one_time_step_cell_to_forget,
- lstm_one_time_step_cell_to_output,
- lstm_one_time_step_projection_weights,
- &lstm,
- lstm_one_time_step_output_state,
- lstm_one_time_step_cell_state,
- output);
-
- TEST_ASSERT_EQUAL(expected, result);
- TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
-}
diff --git a/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s16_s8/CMakeLists.txt b/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s8/CMakeLists.txt
similarity index 67%
rename from Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s16_s8/CMakeLists.txt
rename to Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s8/CMakeLists.txt
index e51e4e9f..ea73e4ff 100644
--- a/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s16_s8/CMakeLists.txt
+++ b/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s8/CMakeLists.txt
@@ -1,5 +1,5 @@
#
-# SPDX-FileCopyrightText: Copyright 2010-2022 Arm Limited and/or its affiliates
+# SPDX-FileCopyrightText: Copyright 2010-2022, 2024 Arm Limited and/or its affiliates
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -16,8 +16,8 @@
# limitations under the License.
#
-add_cmsis_nn_unit_test_executable(test_arm_lstm_unidirectional_s16_s8)
+add_cmsis_nn_unit_test_executable(test_arm_lstm_unidirectional_s8)
-target_sources(test_arm_lstm_unidirectional_s16_s8 PRIVATE
- Unity/unity_test_arm_lstm_unidirectional_s16_s8.c
- Unity/TestRunner/unity_test_arm_lstm_unidirectional_s16_s8_runner.c)
+target_sources(test_arm_lstm_unidirectional_s8 PRIVATE
+ Unity/unity_test_arm_lstm_unidirectional_s8.c
+ Unity/TestRunner/unity_test_arm_lstm_unidirectional_s8_runner.c)
diff --git a/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s16_s8/Unity/unity_test_arm_lstm_unidirectional_s16_s8.c b/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s8/Unity/unity_test_arm_lstm_unidirectional_s8.c
similarity index 69%
rename from Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s16_s8/Unity/unity_test_arm_lstm_unidirectional_s16_s8.c
rename to Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s8/Unity/unity_test_arm_lstm_unidirectional_s8.c
index 05276f38..28dadf9b 100644
--- a/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s16_s8/Unity/unity_test_arm_lstm_unidirectional_s16_s8.c
+++ b/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s8/Unity/unity_test_arm_lstm_unidirectional_s8.c
@@ -1,5 +1,5 @@
/*
- * SPDX-FileCopyrightText: Copyright 2022 Arm Limited and/or its affiliates
+ * SPDX-FileCopyrightText: Copyright 2022, 2024 Arm Limited and/or its affiliates
*
* SPDX-License-Identifier: Apache-2.0
*
@@ -22,7 +22,7 @@
#include <stdint.h>
#include <stdlib.h>
-#include "../test_arm_lstm_unidirectional_s16_s8.c"
+#include "../test_arm_lstm_unidirectional_s8.c"
#include "unity.h"
#ifdef USING_FVP_CORSTONE_300
@@ -44,9 +44,6 @@ void setUp(void)
*/
void tearDown(void) {}
-void test_lstm_1_arm_lstm_unidirectional_s16_s8(void) { lstm_1_arm_lstm_unidirectional_s16_s8(); }
-void test_lstm_2_arm_lstm_unidirectional_s16_s8(void) { lstm_2_arm_lstm_unidirectional_s16_s8(); }
-void test_lstm_one_time_step_arm_lstm_unidirectional_s16_s8(void)
-{
- lstm_one_time_step_arm_lstm_unidirectional_s16_s8();
-}
+void test_lstm_1_arm_lstm_unidirectional_s8(void) { lstm_1_arm_lstm_unidirectional_s8(); }
+void test_lstm_2_arm_lstm_unidirectional_s8(void) { lstm_2_arm_lstm_unidirectional_s8(); }
+void test_lstm_one_time_step_arm_lstm_unidirectional_s8(void) { lstm_one_time_step_arm_lstm_unidirectional_s8(); }
diff --git a/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s8/test_arm_lstm_unidirectional_s8.c b/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s8/test_arm_lstm_unidirectional_s8.c
new file mode 100644
index 00000000..1d132686
--- /dev/null
+++ b/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s8/test_arm_lstm_unidirectional_s8.c
@@ -0,0 +1,495 @@
+/*
+ * SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the License); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <arm_nnfunctions.h>
+#include <stdbool.h>
+#include <stdint.h>
+#include <stdlib.h>
+#include <unity.h>
+
+#include "../TestData/lstm_1/test_data.h"
+#include "../TestData/lstm_2/test_data.h"
+#include "../TestData/lstm_one_time_step/test_data.h"
+#include "../Utils/validate.h"
+
+#if (LSTM_2_BUFFER_SIZE > LSTM_1_BUFFER_SIZE) || (LSTM_1_BUFFER_SIZE < LSTM_ONE_TIME_STEP_BUFFER_SIZE)
+ #error "Test buffers too small."
+#endif
+
+// Update the buffer size if adding a unit test with a larger buffer.
+#define LARGEST_BUFFER_SIZE LSTM_1_BUFFER_SIZE
+
+int16_t buffer1[LARGEST_BUFFER_SIZE];
+int16_t buffer2[LARGEST_BUFFER_SIZE];
+int16_t buffer3[LARGEST_BUFFER_SIZE];
+
+void lstm_1_arm_lstm_unidirectional_s8(void)
+{
+ int8_t output[LSTM_1_DST_SIZE] = {0};
+ const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS;
+ const int8_t *output_ref = lstm_1_output_ref;
+ const int32_t output_ref_size = LSTM_1_DST_SIZE;
+
+ // Calculate kernel sums if using MVE-extension
+ int32_t input_data_kernel_sum[LSTM_1_NUMBER_UNITS];
+ int32_t forget_data_kernel_sum[LSTM_1_NUMBER_UNITS];
+ int32_t cell_data_kernel_sum[LSTM_1_NUMBER_UNITS];
+ int32_t output_data_kernel_sum[LSTM_1_NUMBER_UNITS];
+
+ int32_t input_hidden_kernel_sum[LSTM_1_NUMBER_UNITS];
+ int32_t forget_hidden_kernel_sum[LSTM_1_NUMBER_UNITS];
+ int32_t cell_hidden_kernel_sum[LSTM_1_NUMBER_UNITS];
+ int32_t output_hidden_kernel_sum[LSTM_1_NUMBER_UNITS];
+
+ int32_t size_data = LSTM_1_NUMBER_INPUTS;
+ int32_t size_hidden = LSTM_1_NUMBER_UNITS;
+
+ arm_vector_sum_s8(&input_data_kernel_sum[0],
+ size_data,
+ size_hidden,
+ &lstm_1_input_to_input_w[0],
+ LSTM_1_DATA_OFFSET,
+ &lstm_1_input_gate_bias[0]);
+ arm_vector_sum_s8(&forget_data_kernel_sum[0],
+ size_data,
+ size_hidden,
+ &lstm_1_input_to_forget_w[0],
+ LSTM_1_DATA_OFFSET,
+ &lstm_1_forget_gate_bias[0]);
+ arm_vector_sum_s8(&cell_data_kernel_sum[0],
+ size_data,
+ size_hidden,
+ &lstm_1_input_to_cell_w[0],
+ LSTM_1_DATA_OFFSET,
+ &lstm_1_cell_gate_bias[0]);
+ arm_vector_sum_s8(&output_data_kernel_sum[0],
+ size_data,
+ size_hidden,
+ &lstm_1_input_to_output_w[0],
+ LSTM_1_DATA_OFFSET,
+ &lstm_1_output_gate_bias[0]);
+
+ arm_vector_sum_s8(&input_hidden_kernel_sum[0],
+ size_hidden,
+ size_hidden,
+ &lstm_1_recurrent_input_to_input_w[0],
+ -LSTM_1_HIDDEN_OFFSET,
+ NULL);
+ arm_vector_sum_s8(&forget_hidden_kernel_sum[0],
+ size_hidden,
+ size_hidden,
+ &lstm_1_recurrent_input_to_forget_w[0],
+ -LSTM_1_HIDDEN_OFFSET,
+ NULL);
+ arm_vector_sum_s8(&cell_hidden_kernel_sum[0],
+ size_hidden,
+ size_hidden,
+ &lstm_1_recurrent_input_to_cell_w[0],
+ -LSTM_1_HIDDEN_OFFSET,
+ NULL);
+ arm_vector_sum_s8(&output_hidden_kernel_sum[0],
+ size_hidden,
+ size_hidden,
+ &lstm_1_recurrent_input_to_output_w[0],
+ -LSTM_1_HIDDEN_OFFSET,
+ NULL);
+
+ // INPUT GATE
+ const cmsis_nn_lstm_gate gate_input = {LSTM_1_IN_TO_INPUT_MULTIPLIER,
+ LSTM_1_IN_TO_INPUT_SHIFT,
+ &lstm_1_input_to_input_w[0],
+ &input_data_kernel_sum[0],
+ LSTM_1_RECURRENT_TO_INPUT_MULTIPLIER,
+ LSTM_1_RECURRENT_TO_INPUT_SHIFT,
+ &lstm_1_recurrent_input_to_input_w[0],
+ &input_hidden_kernel_sum[0],
+ &lstm_1_input_gate_bias[0],
+ ARM_SIGMOID};
+
+ // FORGET GATE
+ const cmsis_nn_lstm_gate gate_forget = {LSTM_1_IN_TO_FORGET_MULTIPLIER,
+ LSTM_1_IN_TO_FORGET_SHIFT,
+ &lstm_1_input_to_forget_w[0],
+ &forget_data_kernel_sum[0],
+ LSTM_1_RECURRENT_TO_FORGET_MULTIPLIER,
+ LSTM_1_RECURRENT_TO_FORGET_SHIFT,
+ &lstm_1_recurrent_input_to_forget_w[0],
+ &forget_hidden_kernel_sum[0],
+ &lstm_1_forget_gate_bias[0],
+ ARM_SIGMOID};
+
+ // CELL GATE
+ const cmsis_nn_lstm_gate gate_cell = {LSTM_1_IN_TO_CELL_MULTIPLIER,
+ LSTM_1_IN_TO_CELL_SHIFT,
+ &lstm_1_input_to_cell_w[0],
+ &cell_data_kernel_sum[0],
+ LSTM_1_RECURRENT_TO_CELL_MULTIPLIER,
+ LSTM_1_RECURRENT_TO_CELL_SHIFT,
+ &lstm_1_recurrent_input_to_cell_w[0],
+ &cell_hidden_kernel_sum[0],
+ &lstm_1_cell_gate_bias[0],
+ ARM_TANH};
+
+ // OUTPUT GATE
+ const cmsis_nn_lstm_gate gate_output = {LSTM_1_IN_TO_OUTPUT_MULTIPLIER,
+ LSTM_1_IN_TO_OUTPUT_SHIFT,
+ &lstm_1_input_to_output_w[0],
+ &output_data_kernel_sum[0],
+ LSTM_1_RECURRENT_TO_OUTPUT_MULTIPLIER,
+ LSTM_1_RECURRENT_TO_OUTPUT_SHIFT,
+ &lstm_1_recurrent_input_to_output_w[0],
+ &output_hidden_kernel_sum[0],
+ &lstm_1_output_gate_bias[0],
+ ARM_SIGMOID};
+
+ // LSTM DATA
+ const cmsis_nn_lstm_params params = {LSTM_1_TIME_MAJOR,
+ LSTM_1_INPUT_BATCHES,
+ LSTM_1_TIME_STEPS,
+ LSTM_1_NUMBER_INPUTS,
+ LSTM_1_NUMBER_UNITS,
+ LSTM_1_DATA_OFFSET,
+ LSTM_1_FORGET_MULTIPLIER,
+ LSTM_1_FORGET_SHIFT,
+ LSTM_1_INPUT_MULTIPLIER,
+ LSTM_1_INPUT_SHIFT,
+ LSTM_1_IN_ACTIVATION_MAX,
+ LSTM_1_CELL_STATE_SHIFT,
+ LSTM_1_HIDDEN_MULTIPLIER,
+ LSTM_1_HIDDEN_SHIFT,
+ LSTM_1_HIDDEN_OFFSET,
+ gate_forget,
+ gate_input,
+ gate_cell,
+ gate_output};
+
+ // BUFFERS
+ cmsis_nn_lstm_context buffers;
+ buffers.temp1 = buffer1;
+ buffers.temp2 = buffer2;
+ buffers.cell_state = buffer3;
+
+ arm_cmsis_nn_status result = arm_lstm_unidirectional_s8(lstm_1_input, output, &params, &buffers);
+
+ TEST_ASSERT_EQUAL(expected, result);
+ TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
+}
+
+void lstm_2_arm_lstm_unidirectional_s8(void)
+{
+ const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS;
+ const int8_t *output_ref = lstm_2_output_ref;
+ const int32_t output_ref_size = LSTM_2_DST_SIZE;
+
+ // Calculate kernel sums if using MVE-extension
+ int32_t input_data_kernel_sum[LSTM_2_NUMBER_UNITS];
+ int32_t forget_data_kernel_sum[LSTM_2_NUMBER_UNITS];
+ int32_t cell_data_kernel_sum[LSTM_2_NUMBER_UNITS];
+ int32_t output_data_kernel_sum[LSTM_2_NUMBER_UNITS];
+
+ int32_t input_hidden_kernel_sum[LSTM_2_NUMBER_UNITS];
+ int32_t forget_hidden_kernel_sum[LSTM_2_NUMBER_UNITS];
+ int32_t cell_hidden_kernel_sum[LSTM_2_NUMBER_UNITS];
+ int32_t output_hidden_kernel_sum[LSTM_2_NUMBER_UNITS];
+
+ int32_t size_data = LSTM_2_NUMBER_INPUTS;
+ int32_t size_hidden = LSTM_2_NUMBER_UNITS;
+
+ arm_vector_sum_s8(&input_data_kernel_sum[0],
+ size_data,
+ size_hidden,
+ &lstm_2_input_to_input_w[0],
+ LSTM_2_DATA_OFFSET,
+ &lstm_2_input_gate_bias[0]);
+ arm_vector_sum_s8(&forget_data_kernel_sum[0],
+ size_data,
+ size_hidden,
+ &lstm_2_input_to_forget_w[0],
+ LSTM_2_DATA_OFFSET,
+ &lstm_2_forget_gate_bias[0]);
+ arm_vector_sum_s8(&cell_data_kernel_sum[0],
+ size_data,
+ size_hidden,
+ &lstm_2_input_to_cell_w[0],
+ LSTM_2_DATA_OFFSET,
+ &lstm_2_cell_gate_bias[0]);
+ arm_vector_sum_s8(&output_data_kernel_sum[0],
+ size_data,
+ size_hidden,
+ &lstm_2_input_to_output_w[0],
+ LSTM_2_DATA_OFFSET,
+ &lstm_2_output_gate_bias[0]);
+
+ arm_vector_sum_s8(&input_hidden_kernel_sum[0],
+ size_hidden,
+ size_hidden,
+ &lstm_2_recurrent_input_to_input_w[0],
+ -LSTM_2_HIDDEN_OFFSET,
+ NULL);
+ arm_vector_sum_s8(&forget_hidden_kernel_sum[0],
+ size_hidden,
+ size_hidden,
+ &lstm_2_recurrent_input_to_forget_w[0],
+ -LSTM_2_HIDDEN_OFFSET,
+ NULL);
+ arm_vector_sum_s8(&cell_hidden_kernel_sum[0],
+ size_hidden,
+ size_hidden,
+ &lstm_2_recurrent_input_to_cell_w[0],
+ -LSTM_2_HIDDEN_OFFSET,
+ NULL);
+ arm_vector_sum_s8(&output_hidden_kernel_sum[0],
+ size_hidden,
+ size_hidden,
+ &lstm_2_recurrent_input_to_output_w[0],
+ -LSTM_2_HIDDEN_OFFSET,
+ NULL);
+
+ // INPUT GATE
+ const cmsis_nn_lstm_gate gate_input = {LSTM_2_IN_TO_INPUT_MULTIPLIER,
+ LSTM_2_IN_TO_INPUT_SHIFT,
+ &lstm_2_input_to_input_w[0],
+ &input_data_kernel_sum[0],
+ LSTM_2_RECURRENT_TO_INPUT_MULTIPLIER,
+ LSTM_2_RECURRENT_TO_INPUT_SHIFT,
+ &lstm_2_recurrent_input_to_input_w[0],
+ &input_hidden_kernel_sum[0],
+ &lstm_2_input_gate_bias[0],
+ ARM_SIGMOID};
+
+ // FORGET GATE
+ const cmsis_nn_lstm_gate gate_forget = {LSTM_2_IN_TO_FORGET_MULTIPLIER,
+ LSTM_2_IN_TO_FORGET_SHIFT,
+ &lstm_2_input_to_forget_w[0],
+ &forget_data_kernel_sum[0],
+ LSTM_2_RECURRENT_TO_FORGET_MULTIPLIER,
+ LSTM_2_RECURRENT_TO_FORGET_SHIFT,
+ &lstm_2_recurrent_input_to_forget_w[0],
+ &forget_hidden_kernel_sum[0],
+ &lstm_2_forget_gate_bias[0],
+ ARM_SIGMOID};
+
+ // CELL GATE
+ const cmsis_nn_lstm_gate gate_cell = {LSTM_2_IN_TO_CELL_MULTIPLIER,
+ LSTM_2_IN_TO_CELL_SHIFT,
+ &lstm_2_input_to_cell_w[0],
+ &cell_data_kernel_sum[0],
+ LSTM_2_RECURRENT_TO_CELL_MULTIPLIER,
+ LSTM_2_RECURRENT_TO_CELL_SHIFT,
+ &lstm_2_recurrent_input_to_cell_w[0],
+ &cell_hidden_kernel_sum[0],
+ &lstm_2_cell_gate_bias[0],
+ ARM_TANH};
+
+ // OUTPUT GATE
+ const cmsis_nn_lstm_gate gate_output = {LSTM_2_IN_TO_OUTPUT_MULTIPLIER,
+ LSTM_2_IN_TO_OUTPUT_SHIFT,
+ &lstm_2_input_to_output_w[0],
+ &output_data_kernel_sum[0],
+ LSTM_2_RECURRENT_TO_OUTPUT_MULTIPLIER,
+ LSTM_2_RECURRENT_TO_OUTPUT_SHIFT,
+ &lstm_2_recurrent_input_to_output_w[0],
+ &output_hidden_kernel_sum[0],
+ &lstm_2_output_gate_bias[0],
+ ARM_SIGMOID};
+
+ // LSTM DATA
+ const cmsis_nn_lstm_params params = {LSTM_2_TIME_MAJOR,
+ LSTM_2_INPUT_BATCHES,
+ LSTM_2_TIME_STEPS,
+ LSTM_2_NUMBER_INPUTS,
+ LSTM_2_NUMBER_UNITS,
+ LSTM_2_DATA_OFFSET,
+ LSTM_2_FORGET_MULTIPLIER,
+ LSTM_2_FORGET_SHIFT,
+ LSTM_2_INPUT_MULTIPLIER,
+ LSTM_2_INPUT_SHIFT,
+ LSTM_2_IN_ACTIVATION_MAX,
+ LSTM_2_CELL_STATE_SHIFT,
+ LSTM_2_HIDDEN_MULTIPLIER,
+ LSTM_2_HIDDEN_SHIFT,
+ LSTM_2_HIDDEN_OFFSET,
+ gate_forget,
+ gate_input,
+ gate_cell,
+ gate_output};
+
+ // BUFFERS
+ cmsis_nn_lstm_context buffers;
+ buffers.temp1 = buffer1;
+ buffers.temp2 = buffer2;
+ buffers.cell_state = buffer3;
+
+ int8_t output[LSTM_2_DST_SIZE] = {0};
+ arm_cmsis_nn_status result = arm_lstm_unidirectional_s8(lstm_2_input, output, &params, &buffers);
+
+ TEST_ASSERT_EQUAL(expected, result);
+ TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
+}
+
+void lstm_one_time_step_arm_lstm_unidirectional_s8(void)
+{
+ int8_t output[LSTM_ONE_TIME_STEP_DST_SIZE] = {0};
+ const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS;
+ const int8_t *output_ref = lstm_one_time_step_output_ref;
+ const int32_t output_ref_size = LSTM_ONE_TIME_STEP_DST_SIZE;
+
+ // Calculate kernel sums if using MVE-extension
+ int32_t input_data_kernel_sum[LSTM_ONE_TIME_STEP_NUMBER_UNITS];
+ int32_t forget_data_kernel_sum[LSTM_ONE_TIME_STEP_NUMBER_UNITS];
+ int32_t cell_data_kernel_sum[LSTM_ONE_TIME_STEP_NUMBER_UNITS];
+ int32_t output_data_kernel_sum[LSTM_ONE_TIME_STEP_NUMBER_UNITS];
+
+ int32_t input_hidden_kernel_sum[LSTM_ONE_TIME_STEP_NUMBER_UNITS];
+ int32_t forget_hidden_kernel_sum[LSTM_ONE_TIME_STEP_NUMBER_UNITS];
+ int32_t cell_hidden_kernel_sum[LSTM_ONE_TIME_STEP_NUMBER_UNITS];
+ int32_t output_hidden_kernel_sum[LSTM_ONE_TIME_STEP_NUMBER_UNITS];
+
+ int32_t size_data = LSTM_ONE_TIME_STEP_NUMBER_INPUTS;
+ int32_t size_hidden = LSTM_ONE_TIME_STEP_NUMBER_UNITS;
+
+ arm_vector_sum_s8(&input_data_kernel_sum[0],
+ size_data,
+ size_hidden,
+ &lstm_one_time_step_input_to_input_w[0],
+ LSTM_ONE_TIME_STEP_DATA_OFFSET,
+ &lstm_one_time_step_input_gate_bias[0]);
+ arm_vector_sum_s8(&forget_data_kernel_sum[0],
+ size_data,
+ size_hidden,
+ &lstm_one_time_step_input_to_forget_w[0],
+ LSTM_ONE_TIME_STEP_DATA_OFFSET,
+ &lstm_one_time_step_forget_gate_bias[0]);
+ arm_vector_sum_s8(&cell_data_kernel_sum[0],
+ size_data,
+ size_hidden,
+ &lstm_one_time_step_input_to_cell_w[0],
+ LSTM_ONE_TIME_STEP_DATA_OFFSET,
+ &lstm_one_time_step_cell_gate_bias[0]);
+ arm_vector_sum_s8(&output_data_kernel_sum[0],
+ size_data,
+ size_hidden,
+ &lstm_one_time_step_input_to_output_w[0],
+ LSTM_ONE_TIME_STEP_DATA_OFFSET,
+ &lstm_one_time_step_output_gate_bias[0]);
+
+ arm_vector_sum_s8(&input_hidden_kernel_sum[0],
+ size_hidden,
+ size_hidden,
+ &lstm_one_time_step_recurrent_input_to_input_w[0],
+ -LSTM_ONE_TIME_STEP_HIDDEN_OFFSET,
+ NULL);
+ arm_vector_sum_s8(&forget_hidden_kernel_sum[0],
+ size_hidden,
+ size_hidden,
+ &lstm_one_time_step_recurrent_input_to_forget_w[0],
+ -LSTM_ONE_TIME_STEP_HIDDEN_OFFSET,
+ NULL);
+ arm_vector_sum_s8(&cell_hidden_kernel_sum[0],
+ size_hidden,
+ size_hidden,
+ &lstm_one_time_step_recurrent_input_to_cell_w[0],
+ -LSTM_ONE_TIME_STEP_HIDDEN_OFFSET,
+ NULL);
+ arm_vector_sum_s8(&output_hidden_kernel_sum[0],
+ size_hidden,
+ size_hidden,
+ &lstm_one_time_step_recurrent_input_to_output_w[0],
+ -LSTM_ONE_TIME_STEP_HIDDEN_OFFSET,
+ NULL);
+
+ // INPUT GATE
+ const cmsis_nn_lstm_gate gate_input = {LSTM_ONE_TIME_STEP_IN_TO_INPUT_MULTIPLIER,
+ LSTM_ONE_TIME_STEP_IN_TO_INPUT_SHIFT,
+ &lstm_one_time_step_input_to_input_w[0],
+ input_data_kernel_sum,
+ LSTM_ONE_TIME_STEP_RECURRENT_TO_INPUT_MULTIPLIER,
+ LSTM_ONE_TIME_STEP_RECURRENT_TO_INPUT_SHIFT,
+ &lstm_one_time_step_recurrent_input_to_input_w[0],
+ input_hidden_kernel_sum,
+ &lstm_one_time_step_input_gate_bias[0],
+ ARM_SIGMOID};
+
+ // FORGET GATE
+ const cmsis_nn_lstm_gate gate_forget = {LSTM_ONE_TIME_STEP_IN_TO_FORGET_MULTIPLIER,
+ LSTM_ONE_TIME_STEP_IN_TO_FORGET_SHIFT,
+ &lstm_one_time_step_input_to_forget_w[0],
+ forget_data_kernel_sum,
+ LSTM_ONE_TIME_STEP_RECURRENT_TO_FORGET_MULTIPLIER,
+ LSTM_ONE_TIME_STEP_RECURRENT_TO_FORGET_SHIFT,
+ &lstm_one_time_step_recurrent_input_to_forget_w[0],
+ forget_hidden_kernel_sum,
+ &lstm_one_time_step_forget_gate_bias[0],
+ ARM_SIGMOID};
+
+ // CELL GATE
+ const cmsis_nn_lstm_gate gate_cell = {LSTM_ONE_TIME_STEP_IN_TO_CELL_MULTIPLIER,
+ LSTM_ONE_TIME_STEP_IN_TO_CELL_SHIFT,
+ &lstm_one_time_step_input_to_cell_w[0],
+ cell_data_kernel_sum,
+ LSTM_ONE_TIME_STEP_RECURRENT_TO_CELL_MULTIPLIER,
+ LSTM_ONE_TIME_STEP_RECURRENT_TO_CELL_SHIFT,
+ &lstm_one_time_step_recurrent_input_to_cell_w[0],
+ cell_hidden_kernel_sum,
+ &lstm_one_time_step_cell_gate_bias[0],
+ ARM_TANH};
+
+ // OUTPUT GATE
+ const cmsis_nn_lstm_gate gate_output = {LSTM_ONE_TIME_STEP_IN_TO_OUTPUT_MULTIPLIER,
+ LSTM_ONE_TIME_STEP_IN_TO_OUTPUT_SHIFT,
+ &lstm_one_time_step_input_to_output_w[0],
+ output_data_kernel_sum,
+ LSTM_ONE_TIME_STEP_RECURRENT_TO_OUTPUT_MULTIPLIER,
+ LSTM_ONE_TIME_STEP_RECURRENT_TO_OUTPUT_SHIFT,
+ &lstm_one_time_step_recurrent_input_to_output_w[0],
+ output_hidden_kernel_sum,
+ &lstm_one_time_step_output_gate_bias[0],
+ ARM_SIGMOID};
+
+ // LSTM DATA
+ const cmsis_nn_lstm_params params = {LSTM_ONE_TIME_STEP_TIME_MAJOR,
+ LSTM_ONE_TIME_STEP_INPUT_BATCHES,
+ LSTM_ONE_TIME_STEP_TIME_STEPS,
+ LSTM_ONE_TIME_STEP_NUMBER_INPUTS,
+ LSTM_ONE_TIME_STEP_NUMBER_UNITS,
+ LSTM_ONE_TIME_STEP_DATA_OFFSET,
+ LSTM_ONE_TIME_STEP_FORGET_MULTIPLIER,
+ LSTM_ONE_TIME_STEP_FORGET_SHIFT,
+ LSTM_ONE_TIME_STEP_INPUT_MULTIPLIER,
+ LSTM_ONE_TIME_STEP_INPUT_SHIFT,
+ LSTM_ONE_TIME_STEP_IN_ACTIVATION_MAX,
+ LSTM_ONE_TIME_STEP_CELL_STATE_SHIFT,
+ LSTM_ONE_TIME_STEP_HIDDEN_MULTIPLIER,
+ LSTM_ONE_TIME_STEP_HIDDEN_SHIFT,
+ LSTM_ONE_TIME_STEP_HIDDEN_OFFSET,
+ gate_forget,
+ gate_input,
+ gate_cell,
+ gate_output};
+
+ // BUFFERS
+ cmsis_nn_lstm_context buffers;
+ buffers.temp1 = buffer1;
+ buffers.temp2 = buffer2;
+ buffers.cell_state = buffer3;
+
+ arm_cmsis_nn_status result = arm_lstm_unidirectional_s8(lstm_one_time_step_input, output, &params, &buffers);
+
+ TEST_ASSERT_EQUAL(expected, result);
+ TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size));
+}
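
The arm_vector_sum_s8 calls in the new tests pre-compute one 32-bit sum per output row of each weight matrix, folding in the data or hidden offset and, for the input-side weights, the gate bias. The plain-C sketch below illustrates what such a pre-computation amounts to; it is a reference illustration only, and the row-major int8 weight layout and the folding dst[row] = offset * row_sum + bias[row] are assumptions drawn from the call sites in this patch, not a statement of the library's exact implementation.

    /* Reference-only sketch of the kernel-sum pre-computation exercised by the
     * tests above. Assumes row-major int8 weights and the folding
     *   dst[row] = offset * (sum of weights in row) + optional bias[row],
     * which is what the arm_vector_sum_s8 call sites in this patch suggest. */
    #include <stdint.h>
    #include <stdio.h>

    static void reference_vector_sum_s8(int32_t *dst,
                                        int32_t cols,
                                        int32_t rows,
                                        const int8_t *weights,
                                        int32_t offset,
                                        const int32_t *bias) /* may be NULL */
    {
        for (int32_t r = 0; r < rows; r++)
        {
            int32_t row_sum = 0;
            for (int32_t c = 0; c < cols; c++)
            {
                row_sum += weights[r * cols + c];
            }
            dst[r] = offset * row_sum + (bias ? bias[r] : 0);
        }
    }

    int main(void)
    {
        const int8_t weights[2][3] = {{1, -2, 3}, {4, 5, -6}};
        const int32_t bias[2] = {10, 20};
        int32_t sums[2];

        /* 2 rows of 3 weights, offset 128: expected sums are 266 and 404. */
        reference_vector_sum_s8(sums, 3, 2, &weights[0][0], 128, bias);
        printf("%d %d\n", (int)sums[0], (int)sums[1]);
        return 0;
    }

The recurrent-weight calls above pass a negated hidden offset and a NULL bias, matching the same folding with the bias term omitted.
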
diff --git a/Tests/UnitTest/TestCases/test_arm_svdf_s8/test_arm_svdf_s8.c b/Tests/UnitTest/TestCases/test_arm_svdf_s8/test_arm_svdf_s8.c
index 43f7d26e..857e35ca 100644
--- a/Tests/UnitTest/TestCases/test_arm_svdf_s8/test_arm_svdf_s8.c
+++ b/Tests/UnitTest/TestCases/test_arm_svdf_s8/test_arm_svdf_s8.c
@@ -1,5 +1,5 @@
/*
- * SPDX-FileCopyrightText: Copyright 2023 Arm Limited and/or its affiliates
+ * SPDX-FileCopyrightText: Copyright 2023-2024 Arm Limited and/or its affiliates
*
* SPDX-License-Identifier: Apache-2.0
*
@@ -76,7 +76,7 @@ void svdf_int8_arm_svdf_s8(void)
#if defined(ARM_MATH_MVEI)
int32_t *kernel_sum_buf = ctx.buf;
- arm_vector_sum_s8(kernel_sum_buf, input_dims.h, weights_feature_dims.n, weights_feature_data);
+ arm_vector_sum_s8(kernel_sum_buf, input_dims.h, weights_feature_dims.n, weights_feature_data, 1, NULL);
#endif
// + SVDF_INT8_TIME_BATCHES additional bytes to make sure it is not overwritten
@@ -191,7 +191,7 @@ void svdf_int8_2_arm_svdf_s8(void)
#if defined(ARM_MATH_MVEI)
int32_t *kernel_sum_buf = ctx.buf;
- arm_vector_sum_s8(kernel_sum_buf, input_dims.h, weights_feature_dims.n, weights_feature_data);
+ arm_vector_sum_s8(kernel_sum_buf, input_dims.h, weights_feature_dims.n, weights_feature_data, 1, NULL);
#endif
const int state_data_size = sizeof(svdf_int8_2_state);
diff --git a/Tests/UnitTest/generate_test_data.py b/Tests/UnitTest/generate_test_data.py
index b82c53d7..c1e6ca6d 100755
--- a/Tests/UnitTest/generate_test_data.py
+++ b/Tests/UnitTest/generate_test_data.py
@@ -721,7 +721,8 @@ def load_testdata_sets(regenerate_input, regenerate_weights, regenerate_biases,
dilation_x=3,
dilation_y=3,
pad=True,
- interpreter=interpreter) dataset = 'basic_int4'
+ interpreter=interpreter)
+ dataset = 'basic_int4'
testdata_sets[dataset] = ConvSettings(dataset,
type_of_test,
regenerate_weights,
@@ -2848,7 +2849,7 @@ def load_testdata_sets(regenerate_input, regenerate_weights, regenerate_biases,
regenerate_input,
regenerate_biases,
schema_file,
- batches=2,
+ batches=1,
time_steps=9,
number_inputs=6,
number_units=7,
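
The *_MULTIPLIER/*_SHIFT pairs used throughout the test data above, and emitted by quantize_scale in the lstm_settings.py changes that follow, encode each floating-point effective scale as a fixed-point multiplier plus a power-of-two shift. The sketch below shows the usual TFLite-style derivation of such a pair; treating quantize_scale as following that convention (scale ≈ multiplier * 2^(shift - 31)) is an assumption here, not something the patch states.

    /* Sketch of a common way a floating-point effective scale is turned into
     * the (MULTIPLIER, SHIFT) pairs seen in the generated config headers.
     * Assumption: a Q31 significand plus a power-of-two shift, i.e.
     *   scale ~= multiplier * 2^(shift - 31). */
    #include <math.h>
    #include <stdint.h>
    #include <stdio.h>

    static void quantize_scale_ref(double scale, int32_t *multiplier, int32_t *shift)
    {
        if (scale == 0.0)
        {
            *multiplier = 0;
            *shift = 0;
            return;
        }
        int exponent;
        const double significand = frexp(scale, &exponent); /* scale = significand * 2^exponent */
        int64_t q = (int64_t)llround(significand * (1ll << 31));
        if (q == (1ll << 31)) /* rounding pushed the significand up to 1.0 */
        {
            q /= 2;
            exponent += 1;
        }
        *multiplier = (int32_t)q;
        *shift = exponent;
    }

    int main(void)
    {
        int32_t multiplier, shift;
        /* Example input only: 2^-15 is the kind of effective scale factor used above. */
        quantize_scale_ref(pow(2, -15), &multiplier, &shift);
        printf("multiplier=%d shift=%d\n", (int)multiplier, (int)shift); /* 1073741824, -14 */
        return 0;
    }
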
diff --git a/Tests/UnitTest/lstm_settings.py b/Tests/UnitTest/lstm_settings.py
index 7846a571..4a0ce39a 100644
--- a/Tests/UnitTest/lstm_settings.py
+++ b/Tests/UnitTest/lstm_settings.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2010-2023 Arm Limited and/or its affiliates
+# SPDX-FileCopyrightText: Copyright 2010-2024 Arm Limited and/or its affiliates
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -252,6 +252,7 @@ def generate_data(self, input_data=None, weights=None, hidden_weights=None, bias
self.generate_c_array("output_norm_coeff", interpreter.get_tensor(output_norm_coeff['index']))
input_scale = input_data_for_index['quantization_parameters']['scales'][0]
+ self.data_zp = input_data_for_index['quantization_parameters']['zero_points'][0]
cell_scale = cell_state['quantization_parameters']['scales'][0]
output_state_scale = output_state['quantization_parameters']['scales'][0]
input_zp = input_data_for_index['quantization_parameters']['zero_points'][0]
@@ -263,7 +264,7 @@ def generate_data(self, input_data=None, weights=None, hidden_weights=None, bias
tmp = math.log(cell_scale) * (1 / math.log(2))
self.cell_state_shift = int(round(tmp))
- self.calc_scales(input_scale, output_state_scale)
+ self.calc_scales(input_scale, output_state_scale, cell_scale)
# Calculate effective biases.
input_zp = -input_zp
@@ -293,14 +294,21 @@ def generate_data(self, input_data=None, weights=None, hidden_weights=None, bias
self.generate_c_array("recurrent_to_output_eff_bias", recurrent_to_output_eff_bias, datatype='int32_t')
# Generate reference
- interpreter.invoke()
- output_data = interpreter.get_tensor(output_details[0]["index"])
+ if self.use_tflite_micro_interpreter:
+ interpreter = self.tflite_micro.runtime.Interpreter.from_file(model_path=str(self.model_path_tflite))
+ interpreter.set_input(tf.cast(input_data, tf.int8), input_details[0]["index"])
+ interpreter.invoke()
+ output_data = interpreter.get_output(0)
+ else:
+ interpreter.invoke()
+ output_data = interpreter.get_tensor(output_details[0]["index"])
+
self.generate_c_array(self.output_data_file_prefix, output_data, datatype='int8_t')
self.write_c_config_header()
self.write_c_header_wrapper()
- def calc_scales(self, input_scale, output_state_scale):
+ def calc_scales(self, input_scale, output_state_scale, cell_scale):
intermediate_scale = pow(2, -12)
if self.time_major:
@@ -308,6 +316,9 @@ def calc_scales(self, input_scale, output_state_scale):
else:
time_major_offset = 0
+
+ self.effective_forget_scale = pow(2, -15) / cell_scale * cell_scale
+ self.effective_input_scale = pow(2, -15) / cell_scale * pow(2, -15)
self.effective_hidden_scale = pow(2, -15) / output_state_scale * pow(2, -15)
self.i2i_effective_scale = input_scale * self.lstm_scales[self.input_to_input_w_index + time_major_offset][0] \
@@ -393,11 +404,21 @@ def write_c_config_header(self) -> None:
f.write("#define {}_RECURRENT_TO_OUTPUT_MULTIPLIER {}\n".format(prefix, multiplier))
f.write("#define {}_RECURRENT_TO_OUTPUT_SHIFT {}\n".format(prefix, shift))
+
+ (multiplier, shift) = self.quantize_scale(self.effective_forget_scale)
+ f.write("#define {}_FORGET_MULTIPLIER {}\n".format(prefix, multiplier))
+ f.write("#define {}_FORGET_SHIFT {}\n".format(prefix, shift))
+
+ (multiplier, shift) = self.quantize_scale(self.effective_input_scale)
+ f.write("#define {}_INPUT_MULTIPLIER {}\n".format(prefix, multiplier))
+ f.write("#define {}_INPUT_SHIFT {}\n".format(prefix, shift))
+
(multiplier, shift) = self.quantize_scale(self.effective_hidden_scale)
f.write("#define {}_HIDDEN_MULTIPLIER {}\n".format(prefix, multiplier))
f.write("#define {}_HIDDEN_SHIFT {}\n".format(prefix, shift))
f.write("#define {}_HIDDEN_OFFSET {}\n".format(prefix, self.hidden_zp))
+ f.write("#define {}_DATA_OFFSET {}\n".format(prefix, -self.data_zp))
f.write("#define {}_OUTPUT_STATE_OFFSET {}\n".format(prefix, self.output_state_offset))
f.write("#define {}_CELL_STATE_SHIFT {}\n".format(prefix, self.cell_state_shift))