Skip to content

Commit

Permalink
Modified div mode kernels to cover truncate and floor modes
Browse files Browse the repository at this point in the history
  • Loading branch information
dijopaul committed Sep 4, 2024
1 parent f5a4e96 commit fe4290f
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 33 deletions.
2 changes: 1 addition & 1 deletion backends/cadence/hifi/kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ add_library(
${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/matmul_asym8uxasym8u_asym8u.cpp
${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_add_f32_broadcast.c
${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_div_f32_broadcast.c
${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_div_floor_f32_broadcast.c
${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_div_mode_f32_broadcast.c
${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_mul_f32_broadcast.c
)

Expand Down
21 changes: 12 additions & 9 deletions backends/cadence/hifi/kernels/kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,20 @@ extern "C" WORD32 xa_nn_elm_div_broadcast_4D_f32xf32_f32(FLOAT32 * __restrict__
const FLOAT32 * __restrict__ p_inp2,
const WORD32 *const p_inp2_shape);

extern "C" WORD32 xa_nn_elm_div_floor_f32xf32_f32(FLOAT32 * __restrict__ p_out,
extern "C" WORD32 xa_nn_elm_div_mode_f32xf32_f32(FLOAT32 * __restrict__ p_out,
const FLOAT32 * __restrict__ p_inp1,
const FLOAT32 * __restrict__ p_inp2,
WORD32 num_elm);

extern "C" WORD32 xa_nn_elm_div_floor_broadcast_4D_f32xf32_f32(FLOAT32 * __restrict__ p_out,
const WORD32 *const p_out_shape,
const FLOAT32 * __restrict__ p_inp1,
const WORD32 *const p_inp1_shape,
const FLOAT32 * __restrict__ p_inp2,
const WORD32 *const p_inp2_shape);
WORD32 num_elm,
WORD32 mode);

extern "C" WORD32 xa_nn_elm_div_mode_broadcast_4D_f32xf32_f32(
FLOAT32 * __restrict__ p_out,
const WORD32 *const p_out_shape,
const FLOAT32 * __restrict__ p_inp1,
const WORD32 *const p_inp1_shape,
const FLOAT32 * __restrict__ p_inp2,
const WORD32 *const p_inp2_shape,
WORD32 mode);

extern "C" WORD32 xa_nn_elm_mul_broadcast_4D_f32xf32_f32(FLOAT32 * __restrict__ p_out,
const WORD32 *const p_out_shape,
Expand Down
11 changes: 9 additions & 2 deletions backends/cadence/hifi/operators/op_div.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,13 @@ Tensor& div_out_mode(

if((broadcast == 1) && (max_dim > NNLIB_MAX_DIM))
fall_back = 1;
int mode_val = -1;
if (mode.has_value() && mode.value() == "trunc")
mode_val = 0;
else if (mode.has_value() && mode.value() == "floor")
mode_val = 1;
else
fall_back = 1;

if(!fall_back)
{
Expand Down Expand Up @@ -223,11 +230,11 @@ Tensor& div_out_mode(
for(int i = 0; i < b.dim(); i++)
inp2_shape[i+off_b] = b.size(i);

xa_nn_elm_div_floor_broadcast_4D_f32xf32_f32(out_data, out_shape, a_data, inp1_shape, b_data, inp2_shape);
xa_nn_elm_div_mode_broadcast_4D_f32xf32_f32(out_data, out_shape, a_data, inp1_shape, b_data, inp2_shape, mode_val);
}
else
{
xa_nn_elm_div_floor_f32xf32_f32(out_data, a_data, b_data, out.numel());
xa_nn_elm_div_mode_f32xf32_f32(out_data, a_data, b_data, out.numel(), mode_val);
}
}
else
Expand Down
Loading

0 comments on commit fe4290f

Please sign in to comment.