Skip to content

Commit

Permalink
Reimplement arm_lstm_unidirectional_s8 (ARM-software#102)
Browse files Browse the repository at this point in the history
- API changes
- Optimized for scalar, DSP and MVE
- Bit exact to TFLM reference kernel
- Less scratch-buffer usage
  • Loading branch information
AdrianLundell authored Feb 7, 2024
1 parent ffeca90 commit 601d96c
Show file tree
Hide file tree
Showing 133 changed files with 1,830 additions and 1,889 deletions.
9 changes: 4 additions & 5 deletions ARM.CMSIS-NN.pdsc
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
<file category="source" name="Source/PoolingFunctions/arm_avgpool_s16.c"/>
<file category="source" name="Source/BasicMathFunctions/arm_elementwise_mul_s8.c"/>
<file category="source" name="Source/BasicMathFunctions/arm_elementwise_mul_s16.c"/>
<file category="source" name="Source/BasicMathFunctions/arm_elementwise_mul_acc_s16.c"/>
<file category="source" name="Source/BasicMathFunctions/arm_elementwise_add_s8.c"/>
<file category="source" name="Source/BasicMathFunctions/arm_elementwise_add_s16.c"/>
<file category="source" name="Source/BasicMathFunctions/arm_elementwise_mul_s16_s8.c"/>
Expand All @@ -107,18 +108,16 @@
<file category="source" name="Source/NNSupportFunctions/arm_nntables.c"/>
<file category="source" name="Source/NNSupportFunctions/arm_nn_depthwise_conv_nt_t_s8.c"/>
<file category="source" name="Source/NNSupportFunctions/arm_nn_mat_mul_core_1x_s8.c"/>
<file category="source" name="Source/NNSupportFunctions/arm_nn_lstm_update_output_s8_s16.c"/>
<file category="source" name="Source/NNSupportFunctions/arm_nn_lstm_step_s8_s16.c"/>
<file category="source" name="Source/NNSupportFunctions/arm_nn_lstm_update_cell_state_s16.c"/>
<file category="source" name="Source/NNSupportFunctions/arm_nn_lstm_step_s8.c"/>
<file category="source" name="Source/NNSupportFunctions/arm_nn_lstm_calculate_gate_s8_s16.c"/>
<file category="source" name="Source/NNSupportFunctions/arm_nn_vec_mat_mul_result_acc_s8.c"/>
<file category="source" name="Source/NNSupportFunctions/arm_nn_vec_mat_mul_result_acc_s8_s16.c"/>
<file category="source" name="Source/FullyConnectedFunctions/arm_fully_connected_s4.c"/>
<file category="source" name="Source/FullyConnectedFunctions/arm_fully_connected_s8.c"/>
<file category="source" name="Source/FullyConnectedFunctions/arm_fully_connected_s16.c"/>
<file category="source" name="Source/FullyConnectedFunctions/arm_fully_connected_get_buffer_sizes_s16.c"/>
<file category="source" name="Source/FullyConnectedFunctions/arm_fully_connected_get_buffer_sizes_s8.c"/>
<file category="source" name="Source/FullyConnectedFunctions/arm_vector_sum_s8.c"/>
<file category="source" name="Source/LSTMFunctions/arm_lstm_unidirectional_s8_s16.c"/>
<file category="source" name="Source/LSTMFunctions/arm_lstm_unidirectional_s8.c"/>
<file category="source" name="Source/SoftmaxFunctions/arm_softmax_s8.c"/>
<file category="source" name="Source/SoftmaxFunctions/arm_nn_softmax_common_s8.c"/>
<file category="source" name="Source/SoftmaxFunctions/arm_softmax_s8_s16.c"/>
Expand Down
134 changes: 49 additions & 85 deletions Include/arm_nn_types.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <[email protected]>
* SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <[email protected]>
*
* SPDX-License-Identifier: Apache-2.0
*
Expand All @@ -22,8 +22,8 @@
* Description: Public header file to contain the CMSIS-NN structs for the
* TensorFlowLite micro compliant functions
*
* $Date: 9 January 2024
* $Revision: V.2.6.2
* $Date: 19 January 2024
* $Revision: V.3.0.0
*
* Target : Arm(R) M-Profile Architecture
* -------------------------------------------------------------------- */
Expand All @@ -40,7 +40,6 @@
* @{
*/


/** Enum for specifying activation function types */
typedef enum
{
Expand Down Expand Up @@ -180,102 +179,67 @@ typedef struct
const int16_t *one_by_one_lut;
} cmsis_nn_softmax_lut_s16;

/** LSTM guard parameters */
typedef struct
{
int32_t input_variance;
int32_t forget_variance;
int32_t cell_variance;
int32_t output_variance;
} cmsis_nn_lstm_guard_params;

/** LSTM scratch buffer container */
typedef struct
{
int16_t *input_gate;
int16_t *forget_gate;
int16_t *cell_gate;
int16_t *output_gate;
} cmsis_nn_lstm_context;

/** Quantized clip value for cell and projection of LSTM input. Zero value means no clipping. */
typedef struct
{
int16_t cell;
int8_t projection;
} cmsis_nn_lstm_clip_params;

/** CMSIS-NN object for quantization parameters */
typedef struct
{
int32_t multiplier; /**< Multiplier value */
int32_t shift; /**< Shift value */
} cmsis_nn_scaling;

/** CMSIS-NN norm layer coefficients */
/** CMSIS-NN object for LSTM gate parameters*/
typedef struct
{
int16_t *input_weight;
int16_t *forget_weight;
int16_t *cell_weight;
int16_t *output_weight;
} cmsis_nn_layer_norm;
int32_t input_multiplier;
int32_t input_shift;
const int8_t *input_weights;
const int32_t *input_effective_bias; /**< Bias added with precomputed kernel_sum * lhs_offset*/

/** Parameters for integer LSTM, as defined in TFLM */
int32_t hidden_multiplier;
int32_t hidden_shift;
const int8_t *hidden_weights;
const int32_t *hidden_effective_bias; /**< Precomputed kernel_sum * lhs_offset*/

const int32_t *bias;
arm_nn_activation_type activation_type;
} cmsis_nn_lstm_gate;

/** CMSIS-NN object for LSTM parameters*/
typedef struct
{
int32_t time_major; /**< Nonzero (true) if first row of data is timestamps for input */
cmsis_nn_scaling input_to_input_scaling;
cmsis_nn_scaling input_to_forget_scaling;
cmsis_nn_scaling input_to_cell_scaling;
cmsis_nn_scaling input_to_output_scaling;
cmsis_nn_scaling recurrent_to_input_scaling;
cmsis_nn_scaling recurrent_to_forget_scaling;
cmsis_nn_scaling recurrent_to_cell_scaling;
cmsis_nn_scaling recurrent_to_output_scaling;
cmsis_nn_scaling cell_to_input_scaling;
cmsis_nn_scaling cell_to_forget_scaling;
cmsis_nn_scaling cell_to_output_scaling;
cmsis_nn_scaling projection_scaling;
cmsis_nn_scaling hidden_scaling;
cmsis_nn_scaling layer_norm_input_scaling; /**< layer normalization for input layer */
cmsis_nn_scaling layer_norm_forget_scaling; /**< layer normalization for forget gate */
cmsis_nn_scaling layer_norm_cell_scaling; /**< layer normalization for cell */
cmsis_nn_scaling layer_norm_output_scaling; /**< layer normalization for outpus layer */

int32_t cell_state_shift;
int32_t hidden_offset;
int32_t output_state_offset;

cmsis_nn_lstm_clip_params clip;
cmsis_nn_lstm_guard_params guard;
cmsis_nn_layer_norm layer_norm;

/* Effective bias is precalculated as bias + zero_point * weight.
Only applicable to when input/output are s8 and weights are s16 */
const int32_t *i2i_effective_bias; /**< input to input effective bias */
const int32_t *i2f_effective_bias; /**< input to forget gate effective bias */
const int32_t *i2c_effective_bias; /**< input to cell effective bias */
const int32_t *i2o_effective_bias; /**< input to output effective bias */

const int32_t *r2i_effective_bias; /**< recurrent gate to input effective bias */
const int32_t *r2f_effective_bias; /**< recurrent gate to forget gate effective bias */
const int32_t *r2c_effective_bias; /**< recurrent gate to cell effective bias */
const int32_t *r2o_effective_bias; /**< recurrent gate to output effective bias */

const int32_t *projection_effective_bias;

/* Not precalculated bias */
const int32_t *input_gate_bias;
const int32_t *forget_gate_bias;
const int32_t *cell_gate_bias;
const int32_t *output_gate_bias;

/* Activation min and max */
cmsis_nn_activation activation;
int32_t time_major; /**< 0 if first dimension is batch, else first dimension is time */
int32_t batch_size;
int32_t time_steps;
int32_t input_size; /**< Size of new data input into the LSTM cell*/
int32_t
hidden_size; /**< Size of output from the LSTM cell, used as output and recursively into the next time step*/

int32_t input_offset;

int32_t forget_to_cell_multiplier;
int32_t forget_to_cell_shift;
int32_t input_to_cell_multiplier;
int32_t input_to_cell_shift;
int32_t cell_clip; /**< Min/max value of cell output*/
int32_t cell_scale_power;

int32_t output_multiplier;
int32_t output_shift;
int32_t output_offset;

cmsis_nn_lstm_gate forget_gate;
cmsis_nn_lstm_gate input_gate;
cmsis_nn_lstm_gate cell_gate;
cmsis_nn_lstm_gate output_gate;
} cmsis_nn_lstm_params;

/** CMSIS-NN object for LSTM scratch buffers*/
typedef struct
{
void *temp1;
void *temp2;
void *cell_state;
} cmsis_nn_lstm_context;

/**
* @} // end group genPubTypes
*/
Expand Down
101 changes: 33 additions & 68 deletions Include/arm_nnfunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@
* Title: arm_nnfunctions.h
* Description: Public header file for CMSIS NN Library
*
* $Date: 11 January 2024
* $Revision: V.12.6.0
* $Date: 19 January 2024
* $Revision: V.13.0.0
*
* Target : Arm(R) M-Profile Architecture
* -------------------------------------------------------------------- */
Expand Down Expand Up @@ -1514,19 +1515,23 @@ arm_cmsis_nn_status arm_fully_connected_s8(const cmsis_nn_context *ctx,
int8_t *output_data);

/**
* @brief Calculate vector sums that may be required by arm_fully_connected_s8().
* @brief Calculate the sum of each row in vector_data, multiply by lhs_offset and optionally add bias_data.
* @param[in, out] vector_sum_buf Buffer for vector sums
* @param[in] vector_cols Number of vector columns
* @param[in] vector_rows Number of vector rows
* @param[in] vector_data Vector or weigths data
* @param[in] vector_data Vector of weigths data
* @param[in] lhs_offset Constant multiplied with each sum
* @param[in] bias_data Vector of bias data, added to each sum.
* @return The function returns
* <code>ARM_CMSIS_NN_SUCCESS</code> - Successful operation
* <code>ARM_CMSIS_NN_ARG_ERROR</code> - If not for Arm(R) Helium Architecture case.
*/
arm_cmsis_nn_status arm_vector_sum_s8(int32_t *vector_sum_buf,
const int32_t vector_cols,
const int32_t vector_rows,
const int8_t *vector_data);
const int8_t *vector_data,
const int32_t lhs_offset,
const int32_t *bias_data);

/**
* @brief Get size of additional buffer required by arm_fully_connected_s8().
Expand Down Expand Up @@ -1802,21 +1807,23 @@ void arm_relu_q15(int16_t *data, uint16_t size);

/**
* @brief s16 neural network activation function using direct table look-up
* @param[in] input pointer to input data
* @param[in] input pointer to input data
* @param[out] output pointer to output
* @param[in] size number of elements
* @param[in] left_shift bit-width of the integer part, assume to be smaller than 3
* @param[in] left_shift bit-width of the integer part, assumed to be smaller than 3.
* @param[in] type type of activation functions
* @return The function returns <code>ARM_CMSIS_NN_SUCCESS</code>
*
* @details Supported framework: TensorFlow Lite for Microcontrollers.
* This activation function must be bit precise congruent with the corresponding TFLM tanh and sigmoid actication
* This activation function must be bit precise congruent with the corresponding TFLM tanh and sigmoid activation
* functions
*/
void arm_nn_activation_s16(const int16_t *input,
int16_t *output,
const uint16_t size,
const uint16_t left_shift,
const arm_nn_activation_type type);
arm_cmsis_nn_status arm_nn_activation_s16(const int16_t *input,
int16_t *output,
const int32_t size,
const int32_t left_shift,
const arm_nn_activation_type type);

/**
* @defgroup Pooling Pooling Functions
Expand Down Expand Up @@ -2441,67 +2448,25 @@ arm_cmsis_nn_status arm_svdf_state_s16_s8(const cmsis_nn_context *input_ctx,
*/

/**
* @brief LSTM unidirectional function with 8 bit input and output and 16 bit gate output
* Peephole connections, projection, clipping, combined input/forget gate and layer normalization are not supported.
*
* @param[in] scratch_buffers Struct containing scratch buffers
* Expected size for each scratch buffer is
* lstm_dims->num_batches * lstm_dims->num_outputs.
* @param[in] input_data Pointer to input data
* @param[in] lstm_dims LSTM input parameters related to dimensions
* @param[in] input_to_input_weights Input to input weights
* @param[in] input_to_forget_weights Input to forget weights
* @param[in] input_to_cell_weights Input to cell weights
* @param[in] input_to_output_weights Input to output weights
* @param[in] recurrent_to_input_weights Recurrent to input weights
* @param[in] recurrent_to_forget_weights Recurrent to forget weights
* @param[in] recurrent_to_cell_weights Recurrent to cell weights
* @param[in] recurrent_to_output_weights Recurrent to output weights
* @param[in] cell_to_input_weights Cell to input weights. Not used.
* @param[in] cell_to_forget_weights Cell to forget weights. Not used.
* @param[in] cell_to_output_weights Cell to output weights. Not used.
* @param[in] projection_weights Projection weights. Not used.
* @param[in] lstm LSTM parameters. See struct declaration
* @param[in] output_state Pointer to (recurrent) output state
* @param[in] cell_state Pointer to cell state
* @param[in] output_data Pointer to output state
*
* @note Following assumptions are done based on LSTM functionality as supported by
* Keras version 2.9.0 at the time of development. As stated here,
* https://github.com/tensorflow/community/blob/master/rfcs/20180920-unify-rnn-interface.md
* Keras's LSTMCell is equivalent to TensorFlow's BasicLSTMCell,
* which does not support peephole, clipping or projection.
* Layer normalization and combined input/forget gate are not supported either.
*
* 1 Input to input weight can not be nullptr. Otherwise nullptr for combined input/forgat gate.
* 2 Cell weights are not used and should be nullptr. Otherwise needed for peephole connections.
* 3 Projection weight is not used and should be nullpr. Otherwise needed for projection.
* @brief LSTM unidirectional function with 8 bit input and output and 16 bit gate output.
*
* @param[in] input Pointer to input data
* @param[out] output Pointer to output data
* @param[in] params Struct containing all information about the lstm operator, see arm_nn_types.
* @param[in] buffers Struct containing pointers to all temporary scratch buffers needed for the
* lstm operator, see arm_nn_types.
*
*
* @return The function returns <code>ARM_CMSIS_NN_SUCCESS</code>
*
* @details
* 1. Supported framework: TensorFlow Lite micro
* 1. Supported framework: TensorFlow Lite Micro
*
*/
arm_cmsis_nn_status arm_lstm_unidirectional_s16_s8(cmsis_nn_lstm_context *scratch_buffers,
const int8_t *input_data,
const cmsis_nn_lstm_dims *lstm_dims,
const int8_t *input_to_input_weights,
const int8_t *input_to_forget_weights,
const int8_t *input_to_cell_weights,
const int8_t *input_to_output_weights,
const int8_t *recurrent_to_input_weights,
const int8_t *recurrent_to_forget_weights,
const int8_t *recurrent_to_cell_weights,
const int8_t *recurrent_to_output_weights,
const int16_t *cell_to_input_weights,
const int16_t *cell_to_forget_weights,
const int16_t *cell_to_output_weights,
const int8_t *projection_weights,
const cmsis_nn_lstm_params *lstm,
int8_t *output_state,
int16_t *cell_state,
int8_t *output_data);
arm_cmsis_nn_status arm_lstm_unidirectional_s8(const int8_t *input,
int8_t *output,
const cmsis_nn_lstm_params *params,
cmsis_nn_lstm_context *buffers);

/**
* @brief Get size of additional buffer required by arm_svdf_s8().
Expand Down
Loading

0 comments on commit 601d96c

Please sign in to comment.