diff --git a/backends/cadence/hifi/kernels/CMakeLists.txt b/backends/cadence/hifi/kernels/CMakeLists.txt index c31f56fc54..e46aa745df 100644 --- a/backends/cadence/hifi/kernels/CMakeLists.txt +++ b/backends/cadence/hifi/kernels/CMakeLists.txt @@ -11,7 +11,7 @@ add_library( ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/matmul_asym8uxasym8u_asym8u.cpp ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_add_f32_broadcast.c ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_div_f32_broadcast.c - ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_div_floor_f32_broadcast.c + ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_div_mode_f32_broadcast.c ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_mul_f32_broadcast.c ) diff --git a/backends/cadence/hifi/kernels/kernels.h b/backends/cadence/hifi/kernels/kernels.h index 704d29760c..209bc192c8 100644 --- a/backends/cadence/hifi/kernels/kernels.h +++ b/backends/cadence/hifi/kernels/kernels.h @@ -30,17 +30,20 @@ extern "C" WORD32 xa_nn_elm_div_broadcast_4D_f32xf32_f32(FLOAT32 * __restrict__ const FLOAT32 * __restrict__ p_inp2, const WORD32 *const p_inp2_shape); -extern "C" WORD32 xa_nn_elm_div_floor_f32xf32_f32(FLOAT32 * __restrict__ p_out, +extern "C" WORD32 xa_nn_elm_div_mode_f32xf32_f32(FLOAT32 * __restrict__ p_out, const FLOAT32 * __restrict__ p_inp1, const FLOAT32 * __restrict__ p_inp2, - WORD32 num_elm); - -extern "C" WORD32 xa_nn_elm_div_floor_broadcast_4D_f32xf32_f32(FLOAT32 * __restrict__ p_out, - const WORD32 *const p_out_shape, - const FLOAT32 * __restrict__ p_inp1, - const WORD32 *const p_inp1_shape, - const FLOAT32 * __restrict__ p_inp2, - const WORD32 *const p_inp2_shape); + WORD32 num_elm, + WORD32 mode); + +extern "C" WORD32 xa_nn_elm_div_mode_broadcast_4D_f32xf32_f32( + FLOAT32 * __restrict__ p_out, + const WORD32 *const p_out_shape, + const FLOAT32 * __restrict__ p_inp1, + const WORD32 *const p_inp1_shape, + const FLOAT32 * __restrict__ p_inp2, + const WORD32 *const p_inp2_shape, + WORD32 mode); extern "C" WORD32 xa_nn_elm_mul_broadcast_4D_f32xf32_f32(FLOAT32 * __restrict__ p_out, const WORD32 *const p_out_shape, diff --git a/backends/cadence/hifi/operators/op_div.cpp b/backends/cadence/hifi/operators/op_div.cpp index 258ad10883..dc6a22ea4d 100644 --- a/backends/cadence/hifi/operators/op_div.cpp +++ b/backends/cadence/hifi/operators/op_div.cpp @@ -192,6 +192,13 @@ Tensor& div_out_mode( if((broadcast == 1) && (max_dim > NNLIB_MAX_DIM)) fall_back = 1; + int mode_val = -1; + if (mode.has_value() && mode.value() == "trunc") + mode_val = 0; + else if (mode.has_value() && mode.value() == "floor") + mode_val = 1; + else + fall_back = 1; if(!fall_back) { @@ -223,11 +230,11 @@ Tensor& div_out_mode( for(int i = 0; i < b.dim(); i++) inp2_shape[i+off_b] = b.size(i); - xa_nn_elm_div_floor_broadcast_4D_f32xf32_f32(out_data, out_shape, a_data, inp1_shape, b_data, inp2_shape); + xa_nn_elm_div_mode_broadcast_4D_f32xf32_f32(out_data, out_shape, a_data, inp1_shape, b_data, inp2_shape, mode_val); } else { - xa_nn_elm_div_floor_f32xf32_f32(out_data, a_data, b_data, out.numel()); + xa_nn_elm_div_mode_f32xf32_f32(out_data, a_data, b_data, out.numel(), mode_val); } } else diff --git a/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_div_floor_f32_broadcast.c b/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_div_mode_f32_broadcast.c similarity index 73% rename from backends/cadence/hifi/third-party/nnlib/xa_nn_elm_div_floor_f32_broadcast.c rename to backends/cadence/hifi/third-party/nnlib/xa_nn_elm_div_mode_f32_broadcast.c index ae2d3c0682..95b449f43f 100644 --- a/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_div_floor_f32_broadcast.c +++ b/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_div_mode_f32_broadcast.c @@ -7,19 +7,21 @@ #if !HAVE_VFPU DISCARD_FUN_FOR_NONVOID_RETURN( - WORD32, xa_nn_elm_floor_div_f32xf32_f32, + WORD32, xa_nn_elm_div_mode_f32xf32_f32, ( FLOAT32 *p_out, const FLOAT32 *p_inp1, const FLOAT32 *p_inp2, - WORD32 num_elm + WORD32 num_elm, + WORD32 mode ) ) #else -WORD32 xa_nn_elm_div_floor_f32xf32_f32(FLOAT32 * __restrict__ p_out, +WORD32 xa_nn_elm_div_mode_f32xf32_f32(FLOAT32 * __restrict__ p_out, const FLOAT32 * __restrict__ p_inp1, const FLOAT32 * __restrict__ p_inp2, - WORD32 num_elm) + WORD32 num_elm, + WORD32 mode) { /* NULL pointer checks */ XA_NNLIB_ARG_CHK_PTR(p_out, -1); @@ -31,6 +33,7 @@ WORD32 xa_nn_elm_div_floor_f32xf32_f32(FLOAT32 * __restrict__ p_out, XA_NNLIB_ARG_CHK_ALIGN(p_inp2, sizeof(FLOAT32), -1); /* Basic Parameter checks */ XA_NNLIB_ARG_CHK_COND((num_elm <= 0), -1); + XA_NNLIB_ARG_CHK_COND(((mode != 0) && (mode != 1)), -1); int i; xtfloatx2 *inp1 = (xtfloatx2 *)p_inp1; @@ -43,6 +46,20 @@ WORD32 xa_nn_elm_div_floor_f32xf32_f32(FLOAT32 * __restrict__ p_out, inp2_a = XT_LASX2PP(inp2); out_a = AE_ZALIGN64(); /* Each iteration of loop is independent so safe to use concurrent pragma */ + if(mode == 0) + { +#pragma concurrent /* Each iteration of loop is independent so safe to use concurrent pragma */ + for(i=0;i < num_elm>>1;i++) + { + XT_LASX2IP(x1, inp1_a, inp1); + XT_LASX2IP(x2, inp2_a, inp2); + y = XT_DIV_SX2(x1, x2); + y = FITRUNC_SX2(y); + XT_SASX2IP(y, out_a, out); + } + } + else + { #pragma concurrent for(i=0;i < num_elm>>1;i++) { @@ -52,6 +69,7 @@ WORD32 xa_nn_elm_div_floor_f32xf32_f32(FLOAT32 * __restrict__ p_out, y = FIFLOOR_SX2(y); XT_SASX2IP(y, out_a, out); } + } XT_SASX2POSFP(out_a, out); // Remainder Loop @@ -61,6 +79,9 @@ WORD32 xa_nn_elm_div_floor_f32xf32_f32(FLOAT32 * __restrict__ p_out, XT_LSIP(a1, (xtfloat *)inp1, 0); XT_LSIP(a2, (xtfloat *)inp2, 0); a = XT_DIV_S(a1, a2); + if(mode == 0) + a = FITRUNC_S(a); + else a = FIFLOOR_S(a); XT_SSI(a, (xtfloat *)out, 0); } @@ -70,12 +91,13 @@ WORD32 xa_nn_elm_div_floor_f32xf32_f32(FLOAT32 * __restrict__ p_out, #endif #if HAVE_VFPU -static void internal_elm_floor_div_broadcast_2D_f32xf32_f32(FLOAT32 * __restrict__ p_out, +static void internal_elm_div_mode_broadcast_2D_f32xf32_f32(FLOAT32 * __restrict__ p_out, const FLOAT32 * __restrict__ p_inp1, const FLOAT32 * __restrict__ p_inp2, WORD32 out_lc, WORD32 in_lc, - xtbool sign_flag) + xtbool sign_flag, + WORD32 mode) { int i, j; @@ -109,6 +131,19 @@ static void internal_elm_floor_div_broadcast_2D_f32xf32_f32(FLOAT32 * __restrict p_c = (xtfloatx2 *)&p_out[i * in_lc]; if(((((unsigned)p_a)&7) == 0) && ((((unsigned)p_b)&7) == 0) && ((((unsigned)p_c)&7) == 0)) { + if(mode == 0) + { + for(j = 0; j < num_simd2_ops; j++) + { + XT_LSX2IP(x1, p_a, 2 * sizeof(FLOAT32)); + XT_LSX2IP(x2, p_b, 2 * sizeof(FLOAT32)); + y = XT_DIV_SX2(x2, x1); + y = FITRUNC_SX2(y); + XT_SSX2IP(y, p_c, 2 * sizeof(FLOAT32)); + } + } + else + { for(j = 0; j < num_simd2_ops; j++) { XT_LSX2IP(x1, p_a, 2 * sizeof(FLOAT32)); @@ -118,11 +153,25 @@ static void internal_elm_floor_div_broadcast_2D_f32xf32_f32(FLOAT32 * __restrict XT_SSX2IP(y, p_c, 2 * sizeof(FLOAT32)); } } + } else { ae_valign vinp1, vinp2, out_a = AE_ZALIGN64(); vinp1 = XT_LASX2PP(p_a); vinp2 = XT_LASX2PP(p_b); + if(mode == 0) + { + for(j = 0; j < num_simd2_ops; j++) + { + XT_LASX2IP(x1, vinp1, p_a); + XT_LASX2IP(x2, vinp2, p_b); + y = XT_DIV_SX2(x2, x1); + y = FITRUNC_SX2(y); + XT_SASX2IP(y, out_a, p_c); + } + } + else + { for(j = 0; j < num_simd2_ops; j++) { XT_LASX2IP(x1, vinp1, p_a); @@ -131,6 +180,7 @@ static void internal_elm_floor_div_broadcast_2D_f32xf32_f32(FLOAT32 * __restrict y = FIFLOOR_SX2(y); XT_SASX2IP(y, out_a, p_c); } + } XT_SASX2POSFP(out_a, (xtfloatx2 *)p_c); } if(num_scalar_ops !=0) @@ -138,6 +188,9 @@ static void internal_elm_floor_div_broadcast_2D_f32xf32_f32(FLOAT32 * __restrict XT_LSIP(a0, (xtfloat *)p_a, sizeof(FLOAT32)); XT_LSIP(b0, (xtfloat *)p_b, sizeof(FLOAT32)); c0 = XT_DIV_S(b0, a0); + if(mode == 0) + c0 = FITRUNC_S(c0); + else c0 = FIFLOOR_S(c0); XT_SSI(c0, (xtfloat *)p_c, 0); } @@ -153,6 +206,19 @@ static void internal_elm_floor_div_broadcast_2D_f32xf32_f32(FLOAT32 * __restrict p_c = (xtfloatx2 *)&p_out[i * in_lc]; if(((((unsigned)p_a)&7) == 0) && ((((unsigned)p_b)&7) == 0) && ((((unsigned)p_c)&7) == 0)) { + if(mode == 0) + { + for(j = 0; j < num_simd2_ops; j++) + { + XT_LSX2IP(x1, p_a, 2 * sizeof(FLOAT32)); + XT_LSX2IP(x2, p_b, 2 * sizeof(FLOAT32)); + y = XT_DIV_SX2(x1, x2); + y = FITRUNC_SX2(y); + XT_SSX2IP(y, p_c, 2 * sizeof(FLOAT32)); + } + } + else + { for(j = 0; j < num_simd2_ops; j++) { XT_LSX2IP(x1, p_a, 2 * sizeof(FLOAT32)); @@ -162,12 +228,25 @@ static void internal_elm_floor_div_broadcast_2D_f32xf32_f32(FLOAT32 * __restrict XT_SSX2IP(y, p_c, 2 * sizeof(FLOAT32)); } } + }/* if(((((unsigned)p_a)&7) == 0) && ((((unsigned)p_b)&7) == 0) && ((((unsigned)p_c)&7) == 0))*/ else { ae_valign vinp1, vinp2, out_a = AE_ZALIGN64(); vinp1 = XT_LASX2PP(p_a); vinp2 = XT_LASX2PP(p_b); - + if(mode == 0) + { + for(j = 0; j < num_simd2_ops; j++) + { + XT_LASX2IP(x1, vinp1, p_a); + XT_LASX2IP(x2, vinp2, p_b); + y = XT_DIV_SX2(x1, x2); + y = FITRUNC_SX2(y); + XT_SASX2IP(y, out_a, p_c); + } + } + else + { for(j = 0; j < num_simd2_ops; j++) { XT_LASX2IP(x1, vinp1, p_a); @@ -176,6 +255,7 @@ static void internal_elm_floor_div_broadcast_2D_f32xf32_f32(FLOAT32 * __restrict y = FIFLOOR_SX2(y); XT_SASX2IP(y, out_a, p_c); } + } XT_SASX2POSFP(out_a, (xtfloatx2 *)p_c); } if(num_scalar_ops !=0) @@ -183,6 +263,9 @@ static void internal_elm_floor_div_broadcast_2D_f32xf32_f32(FLOAT32 * __restrict XT_LSIP(a0, (xtfloat *)p_a, sizeof(FLOAT32)); XT_LSIP(b0, (xtfloat *)p_b, sizeof(FLOAT32)); c0 = XT_DIV_S(a0, b0); + if(mode == 0) + c0 = FITRUNC_S(c0); + else c0 = FIFLOOR_S(c0); XT_SSI(c0, (xtfloat *)p_c, 0); } @@ -190,11 +273,12 @@ static void internal_elm_floor_div_broadcast_2D_f32xf32_f32(FLOAT32 * __restrict } } -static void internal_elm_floor_div_broadcast_f32xf32_f32(FLOAT32 * __restrict__ p_out, +static void internal_elm_div_mode_broadcast_f32xf32_f32(FLOAT32 * __restrict__ p_out, const FLOAT32 * __restrict__ p_inp1, const FLOAT32 * __restrict__ p_inp2, WORD32 num_elm, - xtbool sign_flag) + xtbool sign_flag, + WORD32 mode) { int i; xtfloatx2 * __restrict__ p_a = (xtfloatx2 *)p_inp1; @@ -212,6 +296,18 @@ static void internal_elm_floor_div_broadcast_f32xf32_f32(FLOAT32 * __restrict__ if(sign_flag){ if(((((unsigned)p_a)&7) == 0) && ((((unsigned)p_c)&7) == 0)) { + if(mode == 0) + { + for(i=0; i