diff --git a/backends/cadence/hifi/kernels/CMakeLists.txt b/backends/cadence/hifi/kernels/CMakeLists.txt index 5c610c15b4..ff13b02cc7 100644 --- a/backends/cadence/hifi/kernels/CMakeLists.txt +++ b/backends/cadence/hifi/kernels/CMakeLists.txt @@ -11,6 +11,7 @@ add_library( ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/matmul_asym8uxasym8u_asym8u.cpp ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_broadcast_32.c ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_add_f32_broadcast.c + ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_atan2_f32.c ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_div_f32_broadcast.c ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_div_mode_f32_broadcast.c ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_mul_f32_broadcast.c diff --git a/backends/cadence/hifi/kernels/kernels.h b/backends/cadence/hifi/kernels/kernels.h index bc7aa70185..62bb816f35 100644 --- a/backends/cadence/hifi/kernels/kernels.h +++ b/backends/cadence/hifi/kernels/kernels.h @@ -28,6 +28,11 @@ extern "C" WORD32 xa_nn_elm_add_broadcast_4D_f32xf32_f32(FLOAT32 * __restrict__ const WORD32 *const p_inp1_shape, const FLOAT32 * __restrict__ p_inp2, const WORD32 *const p_inp2_shape); + +extern "C" void xa_nn_elm_atan2_f32(FLOAT32 * z, + const FLOAT32 * y, + const FLOAT32 * x, + WORD32 N ); extern "C" WORD32 xa_nn_elm_div_broadcast_4D_f32xf32_f32(FLOAT32 * __restrict__ p_out, const WORD32 *const p_out_shape, diff --git a/backends/cadence/hifi/operators/op_atan2.cpp b/backends/cadence/hifi/operators/op_atan2.cpp index 6d804fd230..5d6f7c360a 100644 --- a/backends/cadence/hifi/operators/op_atan2.cpp +++ b/backends/cadence/hifi/operators/op_atan2.cpp @@ -58,7 +58,7 @@ Tensor& atan2_out( WORD32 num_elm = out.numel(); - if (!optimized) { + if (optimized) { if (broadcast) { WORD32* __restrict__ ptr1 = (WORD32* __restrict__)malloc(num_elm * sizeof(WORD32)); @@ -70,26 +70,16 @@ Tensor& atan2_out( WORD32* __restrict__ pin2 = (WORD32* __restrict__)b.const_data_ptr(); - WORD32 p_out_shape[max_dim]; - WORD32 p_inp1_shape[max_dim]; - WORD32 p_inp2_shape[max_dim]; - - for (int i = 0; i < kNnlibMaxDim; i++) { - p_inp1_shape[i] = 1; - p_inp2_shape[i] = 1; - p_out_shape[i] = 1; - } - - int off_o = max_dim - out_dim; - int off_a = max_dim - a_dim; - int off_b = max_dim - b_dim; + WORD32 p_out_shape[kNnlibMaxDim]; + WORD32 p_inp1_shape[kNnlibMaxDim]; + WORD32 p_inp2_shape[kNnlibMaxDim]; for (int i = 0; i < out_dim; i++) - p_out_shape[i + off_o] = out.size(i); + p_out_shape[i] = out.size(i); for (int i = 0; i < a_dim; i++) - p_inp1_shape[i + off_a] = a.size(i); + p_inp1_shape[i] = a.size(i); for (int i = 0; i < b_dim; i++) - p_inp2_shape[i + off_b] = b.size(i); + p_inp2_shape[i] = b.size(i); xa_nn_broadcast_32_32(ptr1, p_out_shape, pin1, p_inp1_shape, out_dim); @@ -100,7 +90,7 @@ Tensor& atan2_out( const FLOAT32* __restrict__ p_inp1 = (const FLOAT32* __restrict__)ptr1; const FLOAT32* __restrict__ p_inp2 = (const FLOAT32* __restrict__)ptr2; - vecatan2f(p_out, p_inp1, p_inp2, num_elm); + xa_nn_elm_atan2_f32(p_out, p_inp1, p_inp2, num_elm); free(ptr1); free(ptr2); @@ -111,21 +101,13 @@ Tensor& atan2_out( FLOAT32* __restrict__ pin1 = (FLOAT32* __restrict__)a.const_data_ptr(); - WORD32 p_out_shape[max_dim]; - WORD32 p_inp1_shape[max_dim]; - - for (int i = 0; i < max_dim; i++) { - p_inp1_shape[i] = 1; - p_out_shape[i] = 1; - } - - int off_o = max_dim - out_dim; - int off_a = max_dim - a_dim; + WORD32 p_out_shape[kNnlibMaxDim]; + WORD32 p_inp1_shape[kNnlibMaxDim]; for (int i = 0; i < out_dim; i++) - p_out_shape[i + off_o] = out.size(i); + p_out_shape[i] = out.size(i); for (int i = 0; i < a_dim; i++) - p_inp1_shape[i + off_a] = a.size(i); + p_inp1_shape[i] = a.size(i); xa_nn_broadcast_32_32( (WORD32*)ptr1, p_out_shape, (WORD32*)pin1, p_inp1_shape, out_dim); @@ -136,7 +118,7 @@ Tensor& atan2_out( const FLOAT32* __restrict__ p_inp2 = (const FLOAT32* __restrict__)b.const_data_ptr(); - vecatan2f(p_out, p_inp1, p_inp2, num_elm); + xa_nn_elm_atan2_f32(p_out, p_inp1, p_inp2, num_elm); free(ptr1); } else if (b_is_broadcasted && (!a_is_broadcasted)) { @@ -146,21 +128,13 @@ Tensor& atan2_out( WORD32* __restrict__ pin1 = (WORD32* __restrict__)b.const_data_ptr(); - WORD32 p_out_shape[max_dim]; - WORD32 p_inp1_shape[max_dim]; - - for (int i = 0; i < max_dim; i++) { - p_inp1_shape[i] = 1; - p_out_shape[i] = 1; - } - - int off_o = max_dim - out_dim; - int off_b = max_dim - b_dim; + WORD32 p_out_shape[kNnlibMaxDim]; + WORD32 p_inp1_shape[kNnlibMaxDim]; for (int i = 0; i < out_dim; i++) - p_out_shape[i + off_o] = out.size(i); + p_out_shape[i] = out.size(i); for (int i = 0; i < b_dim; i++) - p_inp1_shape[i + off_b] = b.size(i); + p_inp1_shape[i] = b.size(i); xa_nn_broadcast_32_32(ptr1, p_out_shape, pin1, p_inp1_shape, out_dim); @@ -170,7 +144,7 @@ Tensor& atan2_out( (const FLOAT32* __restrict__)a.const_data_ptr(); const FLOAT32* __restrict__ p_inp2 = (const FLOAT32* __restrict__)ptr1; - vecatan2f(p_out, p_inp1, p_inp2, num_elm); + xa_nn_elm_atan2_f32(p_out, p_inp1, p_inp2, num_elm); free(ptr1); } else { @@ -181,7 +155,7 @@ Tensor& atan2_out( const FLOAT32* __restrict__ p_inp2 = (const FLOAT32* __restrict__)b.const_data_ptr(); - vecatan2f(p_out, p_inp1, p_inp2, num_elm); + xa_nn_elm_atan2_f32(p_out, p_inp1, p_inp2, num_elm); } return out; } diff --git a/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_atan2_f32.c b/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_atan2_f32.c new file mode 100644 index 0000000000..6f95360ed9 --- /dev/null +++ b/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_atan2_f32.c @@ -0,0 +1,882 @@ +/* ------------------------------------------------------------------------ */ +/* Copyright (c) 2018 by Cadence Design Systems, Inc. ALL RIGHTS RESERVED. */ +/* These coded instructions, statements, and computer programs ("Cadence */ +/* Libraries") are the copyrighted works of Cadence Design Systems Inc. */ +/* Cadence IP is licensed for use with Cadence processor cores only and */ +/* must not be used for any other processors and platforms. Your use of the */ +/* Cadence Libraries is subject to the terms of the license agreement you */ +/* have entered into with Cadence Design Systems, or a sublicense granted */ +/* to you by a direct Cadence licensee. */ +/* ------------------------------------------------------------------------ */ +/* IntegrIT, Ltd. www.integrIT.com, info@integrIT.com */ +/* */ +/* DSP Library */ +/* */ +/* This library contains copyrighted materials, trade secrets and other */ +/* proprietary information of IntegrIT, Ltd. This software is licensed for */ +/* use with Cadence processor cores only and must not be used for any other */ +/* processors and platforms. The license to use these sources was given to */ +/* Cadence, Inc. under Terms and Condition of a Software License Agreement */ +/* between Cadence, Inc. and IntegrIT, Ltd. */ +/* ------------------------------------------------------------------------ */ +/* Copyright (C) 2015-2018 IntegrIT, Limited. */ +/* All Rights Reserved. */ +/* ------------------------------------------------------------------------ */ +#include + +#include "../include/NatureDSP_Signal_math.h" +#include "NatureDSP_types.h" +#include "xa_nn_common.h" + +/* Common helper macros. */ +#include "xa_nnlib_common_fpu.h" + +#include "xa_nnlib_common.h" + +const union ufloat32uint32 xa_nnlib_plusInff ={0x7f800000}; +const union ufloat32uint32 xa_nnlib_qNaNf = { 0x7fc00000 }; +const union ufloat32uint32 pif ={0x40490fdb}; /* pi */ +const union ufloat32uint32 pi2f={0x3fc90fdb}; /* pi/2 */ + +const union ufloat32uint32 ALIGN(8) xa_nnlib_atanftbl1[8] = +{ + {0x3dbc14c0},/* 9.183645248413086e-002 */ + {0xbe30c39c},/*-1.726211905479431e-001 */ + {0x3b2791e4},/* 2.556913532316685e-003 */ + {0x3e4dac9d},/* 2.008537799119949e-001 */ + {0xb97d9a57},/*-2.418545627733693e-004 */ + {0xbeaaa7b5},/*-3.333107531070709e-001 */ + {0xb54f34c8},/*-7.719031600572635e-007 */ + {0x31cf3fa2} /* 6.031727117772334e-009 */ +}; + +const union ufloat32uint32 ALIGN(8) xa_nnlib_atanftbl2[8]= +{ + {0xbcccc037},/*-2.499399892985821e-002 */ + {0x3e217c35},/* 1.577003747224808e-001 */ + {0xbecf4163},/*-4.047957360744476e-001 */ + {0x3ef7b762},/* 4.838209748268127e-001 */ + {0xbdf35059},/*-1.188055947422981e-001 */ + {0xbe9b8b75},/*-3.037983477115631e-001 */ + {0xbb80ed5c},/*-3.934545442461968e-003 */ + {0x3956fc52} /* 2.050262701231986e-004 */ +}; + +#if !HAVE_VFPU && !HAVE_FPU +DISCARD_FUN(void, xa_nn_elm_atan2_f32,( FLOAT32 * z, const FLOAT32 * y, const FLOAT32 * x, int N )) +#elif HAVE_VFPU +#define sz_f32 (int)sizeof(FLOAT32) + +/*=========================================================================== + Vector matematics: + vec_atan2 full quadrant Arctangent +===========================================================================*/ + +/*------------------------------------------------------------------------- + Full-Quadrant Arc Tangent + The functions compute the arc tangent of the ratios y[N]/x[N] and store the + result to output vector z[N]. + Floating point functions output is in radians. Fixed point functions + scale its output by pi. + + NOTE: + 1. Scalar floating point function is compatible with standard ANSI C routines and set + errno and exception flags accordingly + 2. Scalar floating point function assigns EDOM to errno whenever y==0 and x==0. + + Accuracy: + 24 bit version: 768 (3.57e-7) + floating point: 2 ULP + + Special cases: + y | x | result | extra conditions + --------|-------|-----------|--------------------- + +/-0 | -0 | +/-pi | + +/-0 | +0 | +/-0 | + +/-0 | x | +/-pi | x<0 + +/-0 | x | +/-0 | x>0 + y | +/-0 | -pi/2 | y<0 + y | +/-0 | pi/2 | y>0 + +/-y | -inf | +/-pi | finite y>0 + +/-y | +inf | +/-0 | finite y>0 + +/-inf | x | +/-pi/2 | finite x + +/-inf | -inf | +/-3*pi/4 | + +/-inf | +inf | +/-pi/4 | + + Input: + y[N] vector of numerator values, Q31 or floating point + x[N] vector of denominator values, Q31 or floating point + N length of vectors + Output: + z[N] results, Q31 or floating point + +---------------------------------------------------------------------------*/ + +void xa_nn_elm_atan2_f32( FLOAT32 * z, const FLOAT32 * y, const FLOAT32 * x, WORD32 N ) +{ + /* + const union ufloat32uint32* p; + int sx,sy,big; + sx=takesignf(x); + sy=takesignf(y); + x=fabs(x); + y=fabs(y); + if(x==0.f && y==0.f) + { + // The actual result depends on input signs. + x = 1.f; + y = 0.f; + } + + big=x>y; + if(big) + { + x=y/x; + } + else + { + // compare x==y is necessary to support (+/-Inf, +/-Inf) cases + x = (x == y) ? 1.0f : x / y; + } + p = (x<0.5f) ? atanftbl1 : atanftbl2; + // approximate atan(x)/x-1 + y = p[0].f; + y = x*y + p[1].f; + y = x*y + p[2].f; + y = x*y + p[3].f; + y = x*y + p[4].f; + y = x*y + p[5].f; + y = x*y + p[6].f; + y = x*y + p[7].f; + // convert result to true atan(x) + y = x*y + x; + + if (!big) y = pi2f.f - y; + if (sx) y = pif.f - y; + if (sy) y = -y; + return y; + */ + + const xtfloatx2 * X; + const xtfloatx2 * Y; + xtfloatx2 * restrict Z; + const xtfloatx2 * S_rd; + xtfloatx2 * restrict S_wr; + + ae_valign X_va, Y_va, Z_va; + + /* Current block index; overall number of blocks; number of values in the current block */ + int blkIx, blkNum, blkLen; + /* Block size, blkLen <= blkSize */ + const int blkSize = MAX_ALLOCA_SZ/sz_f32; + /* Allocate a fixed-size scratch area on the stack. */ + FLOAT32 ALIGN(8) scr[blkSize]; + + int n; + + if ( N<=0 ) return; + + NASSERT_ALIGN8( scr ); + + /* + * Data are processed in blocks of scratch area size. Further, the algorithm + * implementation is splitted in order to feed the optimizing compiler with a + * few loops of managable size. + */ + + blkNum = ( N + blkSize-1 )/blkSize; + + for ( blkIx=0; blkIxy0 ) p0 = y0/x0; + * // Special case of x==y is necessary to support (+/-Inf, +/-Inf) cases. + * else p0 = ( x0==y0 ? 1.f : x0/y0 ); + * + * scr[n] = p0; + * } + * } + */ + + { + /* Input values */ + xtfloatx2 x0, y0; + /* Numerator; denominator; reciprocal; quotient */ + xtfloatx2 num, den, rcp, quo; + /* Scaling factor; error term */ + xtfloatx2 scl, eps; + /* Is NaN; Inf/Inf; x/Inf; 0/0; x and y are subnormal */ + xtbool2 b_nan, b_num_inf, b_den_inf, b_eqz, b_subn; + + X = (xtfloatx2*)( (uintptr_t)x + blkIx*blkSize*sz_f32 ); + Y = (xtfloatx2*)( (uintptr_t)y + blkIx*blkSize*sz_f32 ); + S_wr = (xtfloatx2*)scr; + + X_va = XT_LASX2PP( X ); + Y_va = XT_LASX2PP( Y ); + + __Pragma( "loop_count min=1" ); + for ( n=0; n<(blkLen+1)/2; n++ ) + { + XT_LASX2IP( x0, X_va, X ); + XT_LASX2IP( y0, Y_va, Y ); + + /* Replicate NaNs in both x and y to ensure NaN propagation. */ + b_nan = XT_UN_SX2( x0, y0 ); + XT_MOVT_SX2( x0, xa_nnlib_qNaNf.f, b_nan ); + XT_MOVT_SX2( y0, xa_nnlib_qNaNf.f, b_nan ); + + x0 = XT_ABS_SX2( x0 ); + y0 = XT_ABS_SX2( y0 ); + + /* num <= den */ + num = XT_MIN_SX2( x0, y0 ); + den = XT_MAX_SX2( y0, x0 ); + + /* Scale up numerator and denominator if BOTH are subnormal. */ + b_subn = XT_OLT_SX2( num, FLT_MIN ); + scl = (xtfloatx2)8388608.f; XT_MOVF_SX2( scl, (xtfloatx2)1.0f, b_subn ); + num = XT_MUL_SX2( num, scl ); + den = XT_MUL_SX2( den, scl ); + + /* Classify numerator and denominator. */ + b_num_inf = XT_OEQ_SX2( num, xa_nnlib_plusInff.f ); /* Inf/Inf */ + b_den_inf = XT_OEQ_SX2( den, xa_nnlib_plusInff.f ); /* x/Inf */ + b_eqz = XT_OEQ_SX2( den, (xtfloatx2)(xtfloatx2)(0.0f) ); /* 0/0 */ + + /* Initial appromimation for 1/den. */ + rcp = XT_RECIP0_SX2( den ); + /* Newton-Raphson iteration for 1/den. */ + eps = (xtfloatx2)1.0f; + XT_MSUB_SX2( eps, rcp, den ); + XT_MADD_SX2( rcp, rcp, eps ); + /* Approximation for the quotient num/den. */ + quo = XT_MUL_SX2( num, rcp ); + /* Refine the quotient by a modified Newton-Raphson iteration. */ + eps = num; + XT_MSUB_SX2( eps, quo, den ); + XT_MADD_SX2( quo, rcp, eps ); + + /* Force conventional results for special cases. */ + XT_MOVT_SX2( quo, (xtfloatx2)(0.0f), b_den_inf ); /* x/Inf -> 0 */ + XT_MOVT_SX2( quo, (xtfloatx2)1.0f, b_num_inf ); /* Inf/Inf -> 1 */ + XT_MOVT_SX2( quo, (xtfloatx2)(0.0f), b_eqz ); /* 0/0 -> 0 */ + + XT_SSX2IP( quo, S_wr, +2*sz_f32 ); + } + } + + __Pragma( "no_reorder" ); + + /* + * Part II, polynomial approximation and full quadrant restoration. + * Reference C code: + * + * { + * const union ufloat32uint32 * ptbl; + * float32_t x0, y0, z0, p0; + * int sx, sy; + * + * for ( n=0; n0 + y | +/-0 | -pi/2 | y<0 + y | +/-0 | pi/2 | y>0 + +/-y | -inf | +/-pi | finite y>0 + +/-y | +inf | +/-0 | finite y>0 + +/-inf | x | +/-pi/2 | finite x + +/-inf | -inf | +/-3*pi/4 | + +/-inf | +inf | +/-pi/4 | + +Input: + y[N] input data, Q15 or floating point + x[N] input data, Q15 or floating point + N length of vectors +Output: + z[N] result, Q15 or floating point + +Restrictions: +x, y, z should not overlap +---------------------------------------------------------------------------*/ + +// Taken from Fusion +void xa_nn_elm_atan2_f32( FLOAT32 * z, const FLOAT32 * y, const FLOAT32 * x, WORD32 N ) +{ + /* + * const union ufloat32uint32* p; + * int sx,sy,big; + * sx=takesignf(x); + * sy=takesignf(y); + * x=fabs(x); + * y=fabs(y); + * if(x==0.f && y==0.f) + * { + * // The actual result depends on input signs. + * x = 1.f; + * y = 0.f; + * } + * + * big=x>y; + * if(big) + * { + * x=y/x; + * } + * else + * { + * // compare x==y is necessary to support (+/-Inf, +/-Inf) cases + * x = (x == y) ? 1.0f : x / y; + * } + * p = (x<0.5f) ? atanftbl1 : atanftbl2; + * // approximate atan(x)/x-1 + * y = p[0].f; + * y = x*y + p[1].f; + * y = x*y + p[2].f; + * y = x*y + p[3].f; + * y = x*y + p[4].f; + * y = x*y + p[5].f; + * y = x*y + p[6].f; + * y = x*y + p[7].f; + * // convert result to true atan(x) + * y = x*y + x; + * + * if (!big) y = pi2f.f - y; + * if (sx) y = pif.f - y; + * if (sy) y = -y; + * return y; + */ + const xtfloat * restrict X; + const xtfloat * restrict Y; + int32_t * restrict Z; + const xtfloat * restrict S_rd; + xtfloat * restrict S_wr; + const xtfloat * restrict POLY_TBL1; + const xtfloat * restrict POLY_TBL2; + + /* Current block index; overall number of blocks; number of values in the current block */ + int blkIx, blkNum, blkLen; + /* Block size, blkLen <= blkSize */ + const int blkSize = MAX_ALLOCA_SZ / sz_f32; + /* Allocate a fixed-size scratch area on the stack. */ + float32_t ALIGN(8) scr[blkSize]; + + int n; + + if (N <= 0) return; + + NASSERT_ALIGN8(scr); + + /* + * Data are processed in blocks of scratch area size. Further, the algorithm + * implementation is splitted in order to feed the optimizing compiler with a + * few loops of managable size. + */ + + blkNum = (N + blkSize - 1) / blkSize; + POLY_TBL1 = (xtfloat*)xa_nnlib_atanftbl1; + POLY_TBL2 = (xtfloat*)xa_nnlib_atanftbl2; + for (blkIx = 0; blkIxy0 ) p0 = y0/x0; + * // Special case of x==y is necessary to support (+/-Inf, +/-Inf) cases. + * else p0 = ( x0==y0 ? 1.f : x0/y0 ); + * + * scr[n] = p0; + * } + * } + */ + + { + /* Input values */ + xtfloat x0, y0, i0; + /* Numerator; denominator; reciprocal; quotient */ + xtfloat num, den, rcp, quo; + /* Auxiliary vars */ + xtfloat s, eps; + /* Is NaN; Inf/Inf; x/Inf; 0/0; x and y are subnormal */ + xtbool b_nan, b_num_inf, b_den_inf, b_eqz, b_subn; + const xtfloat * pT; + + X = (xtfloat*)((uintptr_t)x + blkIx*blkSize*sz_f32); + Y = (xtfloat*)((uintptr_t)y + blkIx*blkSize*sz_f32); + S_wr = (xtfloat*)scr; + + static const uint32_t TAB[4] = { 0x7fc00000, 0x00800000, + 0x4b000000, 0x7f800000 + }; + pT = (xtfloat *)TAB; + __Pragma("loop_count min=1"); + for (n = 0; n 0 or x/Inf -> 0*/ + XT_MOVT_S(quo, XT_CONST_S(1), b_num_inf); /* Inf/Inf -> 1 */ + + XT_SSIP(quo, S_wr, sz_f32); + } + } + __Pragma("no_reorder"); + + /* + * Part II, polynomial approximation and full quadrant restoration. + * Reference C code: + * + * { + * const union ufloat32uint32 * ptbl; + * float32_t x0, y0, z0, p0; + * int sx, sy; + * + * for ( n=0; n